Commit 62949a1
Changed files (1)
src
llm
src/llm/gemini.py
@@ -35,7 +35,7 @@ HELP = f"""🌠**AI生图**
"""
-async def gemini_response(client: Client, message: Message, conversations: list[Message], modality: str = "image", *, force_nonstream: bool = False, **kwargs):
+async def gemini_response(client: Client, message: Message, conversations: list[Message], modality: str = "image", **kwargs):
r"""Get Gemini response.
Args:
@@ -73,8 +73,8 @@ async def gemini_response(client: Client, message: Message, conversations: list[
genconfig |= {"thinking_config": ThinkingConfig(thinking_budget=thinking_budget)}
params = {"model": model, "contents": contexts, "config": GenerateContentConfig(**genconfig)}
logger.trace(params)
- if modality == "image" or force_nonstream:
- return await gemini_nonstream(client, message, model_name, params, **kwargs)
+ if modality == "image":
+ return await gemini_nonstream(client, message, model_name, params, clean_marks=True, **kwargs)
return await gemini_stream(client, message, model_name, params, **kwargs)
except Exception as e:
logger.error(e)
@@ -149,10 +149,13 @@ async def gemini_nonstream(
model_name: str,
params: dict,
retry: int = 0,
+ *,
+ clean_marks: bool = False, # useful in image generation
**kwargs,
):
try:
- clean_gemini_sourcemarks(params["contents"])
+ if clean_marks:
+ clean_gemini_sourcemarks(params["contents"])
api_keys = [x.strip() for x in GEMINI.API_KEYS.split(",") if x.strip()]
if kwargs.get("gemini_api_keys"):
api_keys = [x.strip() for x in kwargs["gemini_api_keys"].split(",") if x.strip()]