Commit fe90258
src/llm/gemini.py
@@ -7,7 +7,7 @@ from pathlib import Path
from glom import glom
from google import genai
-from google.genai.types import ContentUnionDict, GenerateContentConfig, GoogleSearch, HttpOptions, Part, ThinkingConfig, Tool, UrlContext
+from google.genai import types
from loguru import logger
from PIL import Image
from pyrogram.client import Client
@@ -73,7 +73,7 @@ async def gemini_response(
await send2tg(client, message, texts="⚠️**未配置Gemini API, 请尝试其他模型", **kwargs)
response_modalities = ["TEXT", "IMAGE"] if modality == "image" else ["TEXT"]
thinking_budget = GEMINI.IMG_THINKING_BUDGET if modality == "image" else GEMINI.TEXT_THINKING_BUDGET
- tools = [Tool(url_context=UrlContext()), Tool(google_search=GoogleSearch())] if modality == "text" else None
+ tools = [types.Tool(url_context=types.UrlContext()), types.Tool(google_search=types.GoogleSearch())] if modality == "text" else None
# parse config from environment variable
genconfig = {}
with contextlib.suppress(Exception):
@@ -90,8 +90,8 @@ async def gemini_response(
genconfig |= {"system_instruction": f"请优先使用{GEMINI.PREFER_LANG}思考和回复"}
if thinking_budget is not None and not disable_thinking:
thinking_budget = min(round(float(thinking_budget)), GEMINI.MAX_THINKING_BUDGET)
- genconfig |= {"thinking_config": ThinkingConfig(include_thoughts=include_thoughts, thinking_budget=thinking_budget)}
- params = {"model": model, "conversations": conversations, "config": GenerateContentConfig(**genconfig)}
+ genconfig |= {"thinking_config": types.ThinkingConfig(include_thoughts=include_thoughts, thinking_budget=thinking_budget)}
+ params = {"model": model, "conversations": conversations, "config": types.GenerateContentConfig(**genconfig)}
logger.trace(params)
if modality == "image":
return await gemini_nonstream(client, message, model_name, params, clean_marks=True, append_grounding=append_grounding, **kwargs)
@@ -139,13 +139,19 @@ async def gemini_stream(
await modify_progress(message=init_status_msg, text=last_error, force_update=True)
return {"error": last_error}
api_key = kwargs.get("gemini_api_key", api_keys[retry])
- http_options = HttpOptions(base_url=GEMINI.BASE_URL, async_client_args={"proxy": GEMINI.PROXY})
+ http_options = types.HttpOptions(base_url=GEMINI.BASE_URL, async_client_args={"proxy": GEMINI.PROXY})
http_options = hook_gemini_httpoptions(http_options, message)
app = genai.Client(api_key=api_key, http_options=http_options)
# Construct the request params
if "contents" not in params and "conversations" in params: # convert conversations to contents
params["contents"] = await get_conversation_contexts(client, params.pop("conversations"), ctx_format="gemini", app=app)
gemini_logging(params["contents"])
+ tokens = await app.aio.models.count_tokens(model=params["model"], contents=params["contents"]) # type: ignore
+ num_tokens = tokens.total_tokens or 0
+ if num_tokens > GEMINI.TEXT_MAX_TOKEN:
+ logger.warning(f"[Gemini] Content is too long: {num_tokens} tokens, fallback to {GEMINI.TEXT_TOKENS_FALLBACK_MODEL}")
+ params["model"] = GEMINI.TEXT_TOKENS_FALLBACK_MODEL
+ params["config"].thinking_config = None
sent_messages = []
is_reasoning = False
is_reasoning_conversation = None # to indicate whether it is a reasoning conversation
@@ -260,7 +266,7 @@ async def gemini_nonstream(
if retry > len(api_keys) - 1:
return {}
api_key = kwargs.get("gemini_api_key", api_keys[retry])
- http_options = HttpOptions(base_url=GEMINI.BASE_URL, async_client_args={"proxy": GEMINI.PROXY})
+ http_options = types.HttpOptions(base_url=GEMINI.BASE_URL, async_client_args={"proxy": GEMINI.PROXY})
http_options = hook_gemini_httpoptions(http_options, message)
app = genai.Client(api_key=api_key, http_options=http_options)
# Construct the request params
@@ -268,6 +274,13 @@ async def gemini_nonstream(
params["contents"] = await get_conversation_contexts(client, params.pop("conversations"), ctx_format="gemini")
if clean_marks:
clean_gemini_sourcemarks(params["contents"])
+ tokens = await app.aio.models.count_tokens(model=params["model"], contents=params["contents"]) # type: ignore
+ num_tokens = tokens.total_tokens or 0
+ if num_tokens > GEMINI.TEXT_MAX_TOKEN:
+ logger.warning(f"[Gemini] Content is too long: {num_tokens} tokens, fallback to {GEMINI.TEXT_TOKENS_FALLBACK_MODEL}")
+ params["model"] = GEMINI.TEXT_TOKENS_FALLBACK_MODEL
+ params["config"].thinking_config = None
+ params["config"].response_modalities = ["TEXT"]
response = await app.aio.models.generate_content(**params)
prefix = f"🤖**{model_name}**:{BOT_TIPS}\n"
res = parse_response(response.model_dump(), append_grounding=append_grounding)
@@ -340,6 +353,14 @@ def gemini_logging(contexts: list):
msg = ""
with contextlib.suppress(Exception):
for item in contexts:
+ if isinstance(item, str):
+ msg += f"{item}\n"
+ continue
+ if isinstance(item, types.File):
+ msg += f"[{item.mime_type}]: {item.name}\n"
+ continue
+ if not isinstance(item, dict):
+ continue
role = item.get("role", "").upper() or "MODEL"
# Request
@@ -354,10 +375,10 @@ def gemini_logging(contexts: list):
if item.get("inline_data", ""):
msg += f"[{role}]: Blob_Data "
- logger.debug(f"{msg!r}")
+ logger.debug(f"{msg!r}")
-def openai_context_to_gemini(context: dict, *, keep_marks: bool = True) -> ContentUnionDict:
+def openai_context_to_gemini(context: dict, *, keep_marks: bool = True) -> types.ContentUnionDict:
r"""(Deprecated) Convert OpenAI context to Gemini format.
Not needed anymore.
@@ -380,16 +401,16 @@ def openai_context_to_gemini(context: dict, *, keep_marks: bool = True) -> Conte
]
}
"""
- parts: list[Part] = []
+ parts: list[types.Part] = []
role = "model" if context["role"] == "assistant" else "user"
for item in context["content"]:
if item["type"] == "text":
if keep_marks:
- parts.append(Part.from_text(text=item["text"]))
+ parts.append(types.Part.from_text(text=item["text"]))
else:
- parts.append(Part.from_text(text=clean_source_marks(item["text"])))
+ parts.append(types.Part.from_text(text=clean_source_marks(item["text"])))
elif item["type"] == "image_url":
data = item["image_url"]["url"].split(";base64,")
mime = data[0].removeprefix("data:")
- parts.append(Part.from_bytes(mime_type=mime, data=data[1]))
+ parts.append(types.Part.from_bytes(mime_type=mime, data=data[1]))
return {"role": role, "parts": parts} # type: ignore
src/config.py
@@ -336,6 +336,8 @@ class GEMINI: # Official Gemini
TEXT_MODEL_NAME = os.getenv("GEMINI_TEXT_MODEL_NAME", "Gemini-2.5-Flash")
TEXT_THINKING_BUDGET = os.getenv("GEMINI_TEXT_THINKING_BUDGET", None) # 0 to disable thinking. DO NOT set this if the model is not a thinking model
TEXT_CONFIG = os.getenv("GEMINI_TEXT_CONFIG", "{}") # default config passed to GenerateContentConfig. Should be a json string: '{"key": "value"}'
+    TEXT_MAX_TOKEN = int(os.getenv("GEMINI_TEXT_MAX_TOKEN", "250000"))  # max input tokens (default 250K) before falling back to TEXT_TOKENS_FALLBACK_MODEL
+    TEXT_TOKENS_FALLBACK_MODEL = os.getenv("GEMINI_TEXT_TOKENS_FALLBACK_MODEL", "gemini-2.0-flash")  # model id used when the counted tokens exceed GEMINI.TEXT_MAX_TOKEN
# response modality: image
IMG_MODEL = os.getenv("GEMINI_IMG_MODEL", "gemini-2.0-flash-exp")