Commit f2ba19e

benny-dou <60535774+benny-dou@users.noreply.github.com>
2025-05-05 09:47:14
feat(gemini): support setting `thinking_budget`
1 parent e75ef07
Changed files (3)
src/asr/gemini_asr.py
@@ -5,13 +5,13 @@ from pathlib import Path
 
 from glom import glom
 from google import genai
-from google.genai.types import GenerateContentConfig, HttpOptions, UploadFileConfig
+from google.genai.types import GenerateContentConfig, HttpOptions, ThinkingConfig, UploadFileConfig
 from loguru import logger
 from pydantic import BaseModel
 from pyrogram.client import Client
 from pyrogram.types import Message
 
-from config import ASR, TEXT_LENGTH
+from config import ASR, GEMINI, TEXT_LENGTH
 from llm.gemini import parse_response
 from llm.utils import beautify_llm_response
 from messages.progress import modify_progress
@@ -49,10 +49,13 @@ async def gemini_stream_asr(client: Client, message: Message, path: str | Path,
         app = genai.Client(api_key=random.choice(api_keys), http_options=HttpOptions(base_url=ASR.GEMINI_BASR_URL, async_client_args={"proxy": ASR.GEMINI_PROXY}))
         uploaded_audio = await app.aio.files.upload(file=path, config=UploadFileConfig(mime_type=f"audio/{voice_format}"))
         logger.debug(uploaded_audio)
-        async for chunk in await app.aio.models.generate_content_stream(
-            model=ASR.GEMINI_MODEL,
-            contents=[prompt, uploaded_audio],
-        ):
+        genconfig = {}
+        genconfig |= {"response_modalities": ["TEXT"]}
+        if ASR.GEMINI_THINKING_BUDGET is not None:
+            thinking_budget = min(round(float(ASR.GEMINI_THINKING_BUDGET)), GEMINI.MAX_THINKING_BUDGET)
+            genconfig |= {"thinking_config": ThinkingConfig(thinking_budget=thinking_budget)}
+        params = {"model": ASR.GEMINI_MODEL, "contents": [prompt, uploaded_audio], "config": GenerateContentConfig(**genconfig)}
+        async for chunk in await app.aio.models.generate_content_stream(**params):
             resp = parse_response(chunk.model_dump())
             sentence = resp.get("texts", "")
             transcriptions += sentence
src/llm/gemini.py
@@ -7,7 +7,7 @@ from pathlib import Path
 
 from glom import glom
 from google import genai
-from google.genai.types import ContentUnionDict, GenerateContentConfig, GoogleSearch, HttpOptions, Part, Tool
+from google.genai.types import ContentUnionDict, GenerateContentConfig, GoogleSearch, HttpOptions, Part, ThinkingConfig, Tool
 from loguru import logger
 from PIL import Image
 from pyrogram.client import Client
@@ -48,21 +48,26 @@ async def gemini_response(client: Client, message: Message, conversations: list[
     if not GEMINI.API_KEYS:
         await send2tg(client, message, texts="⚠️**未配置Gemini API, 请尝试其他模型", **kwargs)
     response_modalities = ["TEXT", "IMAGE"] if modality == "image" else ["TEXT"]
+    thinking_budget = GEMINI.IMG_THINKING_BUDGET if modality == "image" else GEMINI.TEXT_THINKING_BUDGET
     tools = [Tool(google_search=GoogleSearch())] if modality == "text" else None
+
     try:
         msg = f"🤖**{model_name}**: 思考中...\n👤**[{info['full_name']}](tg://user?id={info['uid']})**: “{clean_cmd_prefix(info['text'])}”"
         status_msg = (await send2tg(client, message, texts=msg, **kwargs))[0]
         kwargs["progress"] = status_msg
         contexts = await get_conversation_contexts(client, conversations, ctx_format="gemini")
         gemini_logging(contexts)
-        params = {"model": model, "contents": contexts}
         genconfig = {}
         genconfig |= {"response_modalities": response_modalities}
         if tools:
             genconfig |= {"tools": tools}
         if GEMINI.PREFER_LANG and modality == "text":
             genconfig |= {"system_instruction": f"请优先使用{GEMINI.PREFER_LANG}回复"}
-        params |= {"config": GenerateContentConfig(**genconfig)}
+        if thinking_budget is not None:
+            thinking_budget = min(round(float(thinking_budget)), GEMINI.MAX_THINKING_BUDGET)
+            genconfig |= {"thinking_config": ThinkingConfig(thinking_budget=thinking_budget)}
+        params = {"model": model, "contents": contexts, "config": GenerateContentConfig(**genconfig)}
+        logger.trace(params)
         if modality == "image":
             return await gemini_nonstream(client, message, model_name, params, **kwargs)
         return await gemini_stream(client, message, model_name, params, **kwargs)
src/config.py
@@ -252,6 +252,7 @@ class ASR:
     GEMINI_MAX_DURATION = int(os.getenv("ASR_GEMINI_MAX_DURATION", "34200"))  # 9.5 hour
     GEMINI_MODEL = os.getenv("ASR_GEMINI_MODEL", "gemini-2.0-flash")
     GEMINI_PROXY = os.getenv("ASR_GEMINI_PROXY", None)
+    GEMINI_THINKING_BUDGET = os.getenv("ASR_GEMINI_THINKING_BUDGET", None)  # 0 to disable thinking. DO NOT set this if the model is not a thinking model
     TENCENT_APPID = os.getenv("ASR_TENCENT_APPID", "")
     TENCENT_MAX_DURATION = int(os.getenv("ASR_TENCENT_MAX_DURATION", "3600"))  # 1 hour
     TENCENT_PROXY = os.getenv("ASR_TENCENT_PROXY", None)  # Banned oversea IP, need a back to China proxy
@@ -265,11 +266,14 @@ class GEMINI:  # Official Gemini
     API_KEYS = os.getenv("GEMINI_API_KEYS", "")  # comma separated keys for load balance. e.g. "key1,key2,key3"
     PROXY = os.getenv("GEMINI_PROXY", None)
     PREFER_LANG = os.getenv("GEMINI_PREFER_LANG", "")  # Set a prefer response language for Gemini
+    MAX_THINKING_BUDGET = int(os.getenv("GEMINI_MAX_THINKING_BUDGET", "24576"))  # 24K
 
     # response modality: text
     TEXT_MODEL = os.getenv("GEMINI_TEXT_MODEL", "gemini-2.5-pro-exp-03-25")
     TEXT_MODEL_NAME = os.getenv("GEMINI_TEXT_MODEL_NAME", "Gemini-2.5-Pro")
+    TEXT_THINKING_BUDGET = os.getenv("GEMINI_TEXT_THINKING_BUDGET", None)  # 0 to disable thinking. DO NOT set this if the model is not a thinking model
 
     # response modality: image
     IMG_MODEL = os.getenv("GEMINI_IMG_MODEL", "gemini-2.0-flash-exp")
     IMG_MODEL_NAME = os.getenv("GEMINI_IMG_MODEL_NAME", "Gemini-2.0-Flash")
+    IMG_THINKING_BUDGET = os.getenv("GEMINI_IMG_THINKING_BUDGET", None)  # 0 to disable thinking. DO NOT set this if the model is not a thinking model