Commit c4b1548

benny-dou <60535774+benny-dou@users.noreply.github.com>
2025-05-12 11:25:46
refactor(asr): refactor gemini asr config
1 parent 9a33290
Changed files (3)
src/asr/gemini_asr.py
@@ -13,7 +13,7 @@ from pydantic import BaseModel
 from pyrogram.client import Client
 from pyrogram.types import Message
 
-from config import ASR, GEMINI, TEXT_LENGTH
+from config import GEMINI, TEXT_LENGTH
 from llm.gemini import parse_response
 from llm.hooks import hook_gemini_httpoptions
 from llm.utils import beautify_llm_response
@@ -75,7 +75,7 @@ Notes:
         logger.error(f"[GeminiASR] Failed after {retry} retries")
         return {"error": last_error}
     path = Path(path)
-    api_keys = [x.strip() for x in ASR.GEMINI_API_KEY.split(",") if x.strip()]
+    api_keys = [x.strip() for x in GEMINI.API_KEYS.split(",") if x.strip()]
     transcriptions = ""
     runtime_texts = ""
     sent_messages = []
@@ -84,22 +84,22 @@ Notes:
     if slient:
         status = None
     try:
-        logger.debug(f"ASR via {ASR.GEMINI_MODEL}: {path.as_posix()} , proxy={ASR.GEMINI_PROXY}")
-        http_options = HttpOptions(base_url=ASR.GEMINI_BASR_URL, async_client_args={"proxy": ASR.GEMINI_PROXY})
+        logger.debug(f"ASR via {GEMINI.ASR_MODEL}: {path.as_posix()} , proxy={GEMINI.PROXY}")
+        http_options = HttpOptions(base_url=GEMINI.BASR_URL, async_client_args={"proxy": GEMINI.PROXY})
         http_options = hook_gemini_httpoptions(http_options, message)
         app = genai.Client(api_key=random.choice(api_keys), http_options=http_options)
         uploaded_audio = await app.aio.files.upload(file=path, config=UploadFileConfig(mime_type=f"audio/{voice_format}"))
         logger.debug(uploaded_audio)
         genconfig = {}
         with contextlib.suppress(Exception):
-            genconfig = json.loads(ASR.GEMINI_CONFIG)
+            genconfig = json.loads(GEMINI.ASR_CONFIG)
         genconfig |= {"response_modalities": ["TEXT"]}  # force text response
         genconfig |= {"system_instruction": system_instruction}  # pin system instruction
-        if ASR.GEMINI_THINKING_BUDGET is not None:
-            thinking_budget = min(round(float(ASR.GEMINI_THINKING_BUDGET)), GEMINI.MAX_THINKING_BUDGET)
+        if GEMINI.ASR_THINKING_BUDGET is not None:
+            thinking_budget = min(round(float(GEMINI.ASR_THINKING_BUDGET)), GEMINI.MAX_THINKING_BUDGET)
             genconfig |= {"thinking_config": ThinkingConfig(include_thoughts=False, thinking_budget=thinking_budget)}
         contents = [prompt, uploaded_audio] if prompt else [uploaded_audio]
-        params = {"model": ASR.GEMINI_MODEL, "contents": contents, "config": GenerateContentConfig(**genconfig)}
+        params = {"model": GEMINI.ASR_MODEL, "contents": contents, "config": GenerateContentConfig(**genconfig)}
         async for chunk in await app.aio.models.generate_content_stream(**params):
             resp = parse_response(chunk.model_dump())
             sentence = resp.get("texts", "")
@@ -163,18 +163,18 @@ async def gemini_nonstream_asr(path: str | Path, voice_format: str, *, prompt: s
     https://ai.google.dev/gemini-api/docs/audio
     """
     path = Path(path)
-    api_keys = [x.strip() for x in ASR.GEMINI_API_KEY.split(",") if x.strip()]
+    api_keys = [x.strip() for x in GEMINI.API_KEYS.split(",") if x.strip()]
     random.shuffle(api_keys)
     res = ""
     for key in api_keys:
         try:
-            logger.debug(f"ASR via {ASR.GEMINI_MODEL}: {path.as_posix()} , proxy={ASR.GEMINI_PROXY}")
-            client = genai.Client(api_key=key, http_options=HttpOptions(base_url=ASR.GEMINI_BASR_URL, async_client_args={"proxy": ASR.GEMINI_PROXY}))
+            logger.debug(f"ASR via {GEMINI.ASR_MODEL}: {path.as_posix()} , proxy={GEMINI.PROXY}")
+            client = genai.Client(api_key=key, http_options=HttpOptions(base_url=GEMINI.BASR_URL, async_client_args={"proxy": GEMINI.PROXY}))
             uploaded_audio = await client.aio.files.upload(file=path, config=UploadFileConfig(mime_type=f"audio/{voice_format}"))
             logger.debug(uploaded_audio)
             contents = [prompt, uploaded_audio] if prompt else [uploaded_audio]
             response = await client.aio.models.generate_content(
-                model=ASR.GEMINI_MODEL,
+                model=GEMINI.ASR_MODEL,
                 contents=contents,  # type: ignore
                 config=GenerateContentConfig(
                     response_mime_type="application/json",
src/asr/utils.py
@@ -3,7 +3,7 @@
 import random
 from pathlib import Path
 
-from config import ASR
+from config import ASR, GEMINI
 from multimedia import convert_to_audio
 
 
@@ -62,10 +62,10 @@ def get_tencent_asr_method(duration: float, file_size: int) -> tuple[str, list[s
 
 
 def get_gemini_asr_method(duration: float) -> tuple[str, list[str]]:
-    if duration > ASR.GEMINI_MAX_DURATION:
-        return f"无法识别时长超过{ASR.GEMINI_MAX_DURATION}秒的音频, 当前音频时长: {duration}秒", []
-    if not ASR.GEMINI_API_KEY:
-        return "请设置`ASR_GEMINI_API_KEY`环境变量", []
+    if duration > GEMINI.ASR_MAX_DURATION:
+        return f"无法识别时长超过{GEMINI.ASR_MAX_DURATION}秒的音频, 当前音频时长: {duration}秒", []
+    if not GEMINI.API_KEYS:
+        return "请设置`GEMINI_API_KEYS`环境变量", []
     return "gemini", ["aac", "aiff", "flac", "mp3", "oga", "ogg", "opus", "wav"]
 
 
src/config.py
@@ -257,13 +257,7 @@ class ASR:
     MIDDLE_ENGINE = os.getenv("ASR_MIDDLE_ENGINE", "tencent,ali")
     MIDDLE_DURATION = int(os.getenv("ASR_MIDDLE_DURATION", "600"))
     LONG_ENGINE = os.getenv("ASR_LONG_ENGINE", "gemini")
-    GEMINI_BASR_URL = os.getenv("ASR_GEMINI_BASR_URL", "https://generativelanguage.googleapis.com/")
-    GEMINI_API_KEY = os.getenv("ASR_GEMINI_API_KEY", "")  # comma separated keys for load balance. e.g. "key1,key2,key3"
-    GEMINI_MAX_DURATION = int(os.getenv("ASR_GEMINI_MAX_DURATION", "34200"))  # 9.5 hour
-    GEMINI_MODEL = os.getenv("ASR_GEMINI_MODEL", "gemini-2.0-flash")
-    GEMINI_PROXY = os.getenv("ASR_GEMINI_PROXY", None)
-    GEMINI_THINKING_BUDGET = os.getenv("ASR_GEMINI_THINKING_BUDGET", None)  # 0 to disable thinking. DO NOT set this if the model is not a thinking model
-    GEMINI_CONFIG = os.getenv("ASR_GEMINI_CONFIG", "{}")  # default config passed to GenerateContentConfig. Should be a json string: '{"key": "value"}'
+
     TENCENT_APPID = os.getenv("ASR_TENCENT_APPID", "")
     TENCENT_PROXY = os.getenv("ASR_TENCENT_PROXY", None)  # Banned oversea IP, need a back to China proxy
     TENCENT_SECRET_ID = os.getenv("ASR_TENCENT_SECRET_ID", "")
@@ -297,3 +291,9 @@ class GEMINI:  # Official Gemini
     IMG_MODEL_NAME = os.getenv("GEMINI_IMG_MODEL_NAME", "Gemini-2.0-Flash")
     IMG_THINKING_BUDGET = os.getenv("GEMINI_IMG_THINKING_BUDGET", None)  # 0 to disable thinking. DO NOT set this if the model is not a thinking model
     IMG_CONFIG = os.getenv("GEMINI_IMG_CONFIG", "{}")  # default config passed to GenerateContentConfig. Should be a json string: '{"key": "value"}'
+
+    # ASR related
+    ASR_MAX_DURATION = int(os.getenv("GEMINI_ASR_MAX_DURATION", "34200"))  # 9.5 hour
+    ASR_MODEL = os.getenv("GEMINI_ASR_MODEL", "gemini-2.5-flash-preview-04-17")
+    ASR_THINKING_BUDGET = os.getenv("GEMINI_ASR_THINKING_BUDGET", None)  # 0 to disable thinking. DO NOT set this if the model is not a thinking model
+    ASR_CONFIG = os.getenv("GEMINI_ASR_CONFIG", "{}")  # default config passed to GenerateContentConfig. Should be a json string: '{"key": "value"}'