Commit fe90258

benny-dou <60535774+benny-dou@users.noreply.github.com>
2025-06-02 15:36:06
fix(gemini): fallback to `gemini-2.0-flash` when token count is too large
1 parent d933bfd
Changed files (2)
src/llm/gemini.py
@@ -7,7 +7,7 @@ from pathlib import Path
 
 from glom import glom
 from google import genai
-from google.genai.types import ContentUnionDict, GenerateContentConfig, GoogleSearch, HttpOptions, Part, ThinkingConfig, Tool, UrlContext
+from google.genai import types
 from loguru import logger
 from PIL import Image
 from pyrogram.client import Client
@@ -73,7 +73,7 @@ async def gemini_response(
         await send2tg(client, message, texts="⚠️**未配置Gemini API, 请尝试其他模型", **kwargs)
     response_modalities = ["TEXT", "IMAGE"] if modality == "image" else ["TEXT"]
     thinking_budget = GEMINI.IMG_THINKING_BUDGET if modality == "image" else GEMINI.TEXT_THINKING_BUDGET
-    tools = [Tool(url_context=UrlContext()), Tool(google_search=GoogleSearch())] if modality == "text" else None
+    tools = [types.Tool(url_context=types.UrlContext()), types.Tool(google_search=types.GoogleSearch())] if modality == "text" else None
     # parse config from environment variable
     genconfig = {}
     with contextlib.suppress(Exception):
@@ -90,8 +90,8 @@ async def gemini_response(
             genconfig |= {"system_instruction": f"请优先使用{GEMINI.PREFER_LANG}思考和回复"}
         if thinking_budget is not None and not disable_thinking:
             thinking_budget = min(round(float(thinking_budget)), GEMINI.MAX_THINKING_BUDGET)
-            genconfig |= {"thinking_config": ThinkingConfig(include_thoughts=include_thoughts, thinking_budget=thinking_budget)}
-        params = {"model": model, "conversations": conversations, "config": GenerateContentConfig(**genconfig)}
+            genconfig |= {"thinking_config": types.ThinkingConfig(include_thoughts=include_thoughts, thinking_budget=thinking_budget)}
+        params = {"model": model, "conversations": conversations, "config": types.GenerateContentConfig(**genconfig)}
         logger.trace(params)
         if modality == "image":
             return await gemini_nonstream(client, message, model_name, params, clean_marks=True, append_grounding=append_grounding, **kwargs)
@@ -139,13 +139,19 @@ async def gemini_stream(
             await modify_progress(message=init_status_msg, text=last_error, force_update=True)
             return {"error": last_error}
         api_key = kwargs.get("gemini_api_key", api_keys[retry])
-        http_options = HttpOptions(base_url=GEMINI.BASE_URL, async_client_args={"proxy": GEMINI.PROXY})
+        http_options = types.HttpOptions(base_url=GEMINI.BASE_URL, async_client_args={"proxy": GEMINI.PROXY})
         http_options = hook_gemini_httpoptions(http_options, message)
         app = genai.Client(api_key=api_key, http_options=http_options)
         # Construct the request params
         if "contents" not in params and "conversations" in params:  # convert conversations to contents
             params["contents"] = await get_conversation_contexts(client, params.pop("conversations"), ctx_format="gemini", app=app)
         gemini_logging(params["contents"])
+        tokens = await app.aio.models.count_tokens(model=params["model"], contents=params["contents"])  # type: ignore
+        num_tokens = tokens.total_tokens or 0
+        if num_tokens > GEMINI.TEXT_MAX_TOKEN:
+            logger.warning(f"[Gemini] Content is too long: {num_tokens} tokens, fallback to {GEMINI.TEXT_TOKENS_FALLBACK_MODEL}")
+            params["model"] = GEMINI.TEXT_TOKENS_FALLBACK_MODEL
+            params["config"].thinking_config = None
         sent_messages = []
         is_reasoning = False
         is_reasoning_conversation = None  # to indicate whether it is a reasoning conversation
@@ -260,7 +266,7 @@ async def gemini_nonstream(
         if retry > len(api_keys) - 1:
             return {}
         api_key = kwargs.get("gemini_api_key", api_keys[retry])
-        http_options = HttpOptions(base_url=GEMINI.BASE_URL, async_client_args={"proxy": GEMINI.PROXY})
+        http_options = types.HttpOptions(base_url=GEMINI.BASE_URL, async_client_args={"proxy": GEMINI.PROXY})
         http_options = hook_gemini_httpoptions(http_options, message)
         app = genai.Client(api_key=api_key, http_options=http_options)
         # Construct the request params
@@ -268,6 +274,13 @@ async def gemini_nonstream(
             params["contents"] = await get_conversation_contexts(client, params.pop("conversations"), ctx_format="gemini")
         if clean_marks:
             clean_gemini_sourcemarks(params["contents"])
+        tokens = await app.aio.models.count_tokens(model=params["model"], contents=params["contents"])  # type: ignore
+        num_tokens = tokens.total_tokens or 0
+        if num_tokens > GEMINI.TEXT_MAX_TOKEN:
+            logger.warning(f"[Gemini] Content is too long: {num_tokens} tokens, fallback to {GEMINI.TEXT_TOKENS_FALLBACK_MODEL}")
+            params["model"] = GEMINI.TEXT_TOKENS_FALLBACK_MODEL
+            params["config"].thinking_config = None
+            params["config"].response_modalities = ["TEXT"]
         response = await app.aio.models.generate_content(**params)
         prefix = f"🤖**{model_name}**:{BOT_TIPS}\n"
         res = parse_response(response.model_dump(), append_grounding=append_grounding)
@@ -340,6 +353,14 @@ def gemini_logging(contexts: list):
     msg = ""
     with contextlib.suppress(Exception):
         for item in contexts:
+            if isinstance(item, str):
+                msg += f"{item}\n"
+                continue
+            if isinstance(item, types.File):
+                msg += f"[{item.mime_type}]: {item.name}\n"
+                continue
+            if not isinstance(item, dict):
+                continue
             role = item.get("role", "").upper() or "MODEL"
 
             # Request
@@ -354,10 +375,10 @@ def gemini_logging(contexts: list):
             if item.get("inline_data", ""):
                 msg += f"[{role}]: Blob_Data  "
 
-        logger.debug(f"{msg!r}")
+    logger.debug(f"{msg!r}")
 
 
-def openai_context_to_gemini(context: dict, *, keep_marks: bool = True) -> ContentUnionDict:
+def openai_context_to_gemini(context: dict, *, keep_marks: bool = True) -> types.ContentUnionDict:
     r"""(Deprecated) Convert OpenAI context to Gemini format.
 
     Not needed anymore.
@@ -380,16 +401,16 @@ def openai_context_to_gemini(context: dict, *, keep_marks: bool = True) -> Conte
             ]
         }
     """
-    parts: list[Part] = []
+    parts: list[types.Part] = []
     role = "model" if context["role"] == "assistant" else "user"
     for item in context["content"]:
         if item["type"] == "text":
             if keep_marks:
-                parts.append(Part.from_text(text=item["text"]))
+                parts.append(types.Part.from_text(text=item["text"]))
             else:
-                parts.append(Part.from_text(text=clean_source_marks(item["text"])))
+                parts.append(types.Part.from_text(text=clean_source_marks(item["text"])))
         elif item["type"] == "image_url":
             data = item["image_url"]["url"].split(";base64,")
             mime = data[0].removeprefix("data:")
-            parts.append(Part.from_bytes(mime_type=mime, data=data[1]))
+            parts.append(types.Part.from_bytes(mime_type=mime, data=data[1]))
     return {"role": role, "parts": parts}  # type: ignore
src/config.py
@@ -336,6 +336,8 @@ class GEMINI:  # Official Gemini
     TEXT_MODEL_NAME = os.getenv("GEMINI_TEXT_MODEL_NAME", "Gemini-2.5-Flash")
     TEXT_THINKING_BUDGET = os.getenv("GEMINI_TEXT_THINKING_BUDGET", None)  # 0 to disable thinking. DO NOT set this if the model is not a thinking model
     TEXT_CONFIG = os.getenv("GEMINI_TEXT_CONFIG", "{}")  # default config passed to GenerateContentConfig. Should be a json string: '{"key": "value"}'
+    TEXT_MAX_TOKEN = int(os.getenv("GEMINI_TEXT_MAX_TOKEN", "250000"))  # 250K
+    TEXT_TOKENS_FALLBACK_MODEL = os.getenv("GEMINI_TEXT_TOKENS_FALLBACK_MODEL", "gemini-2.0-flash")  # model id when the token count is larger than GEMINI.TEXT_MAX_TOKEN
 
     # response modality: image
     IMG_MODEL = os.getenv("GEMINI_IMG_MODEL", "gemini-2.0-flash-exp")