Commit 4eadac6

benny-dou <60535774+benny-dou@users.noreply.github.com>
2025-05-12 19:09:16
refactor(asr): refactor `gemini_stream_asr`
1 parent c4b1548
src/asr/gemini_asr.py
@@ -13,12 +13,11 @@ from pydantic import BaseModel
 from pyrogram.client import Client
 from pyrogram.types import Message
 
-from config import GEMINI, TEXT_LENGTH
-from llm.gemini import parse_response
+from config import GEMINI
+from llm.gemini import gemini_stream
 from llm.hooks import hook_gemini_httpoptions
-from llm.utils import beautify_llm_response
+from llm.utils import shuffle_keys
 from messages.progress import modify_progress
-from messages.utils import blockquote, count_without_entities, smart_split
 
 
 async def gemini_stream_asr(
@@ -29,9 +28,6 @@ async def gemini_stream_asr(
     prompt: str = "请转录这段音频",
     *,
     slient: bool = False,
-    retry: int = 0,
-    max_retry: int = 2,
-    last_error: str = "",
     **kwargs,
 ) -> dict:
     """Gemini stream ASR.
@@ -71,64 +67,41 @@ Example-2:
 Notes:
 - Focus on accuracy in capturing both the timing and the spoken content.
 - Maintain consistent formatting to ensure clarity and readability."""
-    if retry > max_retry:
-        logger.error(f"[GeminiASR] Failed after {retry} retries")
-        return {"error": last_error}
     path = Path(path)
-    api_keys = [x.strip() for x in GEMINI.API_KEYS.split(",") if x.strip()]
-    transcriptions = ""
-    runtime_texts = ""
-    sent_messages = []
-    if status := kwargs.get("progress"):
-        sent_messages.append(status)
-    if slient:
-        status = None
-    try:
-        logger.debug(f"ASR via {GEMINI.ASR_MODEL}: {path.as_posix()} , proxy={GEMINI.PROXY}")
-        http_options = HttpOptions(base_url=GEMINI.BASR_URL, async_client_args={"proxy": GEMINI.PROXY})
-        http_options = hook_gemini_httpoptions(http_options, message)
-        app = genai.Client(api_key=random.choice(api_keys), http_options=http_options)
-        uploaded_audio = await app.aio.files.upload(file=path, config=UploadFileConfig(mime_type=f"audio/{voice_format}"))
-        logger.debug(uploaded_audio)
-        genconfig = {}
-        with contextlib.suppress(Exception):
-            genconfig = json.loads(GEMINI.ASR_CONFIG)
-        genconfig |= {"response_modalities": ["TEXT"]}  # force text response
-        genconfig |= {"system_instruction": system_instruction}  # pin system instruction
-        if GEMINI.ASR_THINKING_BUDGET is not None:
-            thinking_budget = min(round(float(GEMINI.ASR_THINKING_BUDGET)), GEMINI.MAX_THINKING_BUDGET)
-            genconfig |= {"thinking_config": ThinkingConfig(include_thoughts=False, thinking_budget=thinking_budget)}
-        contents = [prompt, uploaded_audio] if prompt else [uploaded_audio]
-        params = {"model": GEMINI.ASR_MODEL, "contents": contents, "config": GenerateContentConfig(**genconfig)}
-        async for chunk in await app.aio.models.generate_content_stream(**params):
-            resp = parse_response(chunk.model_dump())
-            sentence = resp.get("texts", "")
-            transcriptions += sentence
-            runtime_texts += sentence
-            runtime_texts = beautify_llm_response(runtime_texts)
-            if await count_without_entities(runtime_texts) <= TEXT_LENGTH:
-                if len(runtime_texts) > 5:  # start response if sentence is not empty
-                    await modify_progress(message=status, text=runtime_texts, detail_progress=True, ttl=10)
-            else:  # transcriptions is too long, split it into multiple messages
-                parts = await smart_split(runtime_texts)
-                await modify_progress(message=status, text=blockquote(parts[0]), force_update=True)  # force send the first part
-                runtime_texts = parts[-1]  # keep the last part
-                if not slient:
-                    status = await client.send_message(message.chat.id, runtime_texts)  # the new message
-                    sent_messages.append(status)
-
-        # all chunks are processed
-        await modify_progress(message=status, text=blockquote(beautify_llm_response(runtime_texts)), force_update=True)
-        if uploaded_audio.name:  # delete file once finished
-            await app.aio.files.delete(name=uploaded_audio.name)
-    except Exception as e:
-        logger.error(e)
-        with contextlib.suppress(Exception):
-            [await modify_progress(msg, del_status=True) for msg in sent_messages]
-            if "uploaded_audio" in locals() and uploaded_audio.name:
-                await app.aio.files.delete(name=uploaded_audio.name)
-        return await gemini_stream_asr(client, message, path, voice_format, prompt, slient=slient, retry=retry + 1, max_retry=max_retry, last_error=str(e))
-    return {"texts": transcriptions, "sent_messages": sent_messages}
+    status = None if slient else kwargs.get("progress")
+    api_keys = shuffle_keys(GEMINI.API_KEYS)
+    for api_key in api_keys.split(","):
+        try:
+            logger.debug(f"ASR via {GEMINI.ASR_MODEL}: {path.as_posix()} , proxy={GEMINI.PROXY}")
+            http_options = HttpOptions(base_url=GEMINI.BASR_URL, async_client_args={"proxy": GEMINI.PROXY})
+            http_options = hook_gemini_httpoptions(http_options, message)
+            app = genai.Client(api_key=api_key, http_options=http_options)
+            uploaded_audio = await app.aio.files.upload(file=path, config=UploadFileConfig(mime_type=f"audio/{voice_format}"))
+            genconfig = {}
+            with contextlib.suppress(Exception):
+                genconfig = json.loads(GEMINI.ASR_CONFIG)
+            genconfig |= {"response_modalities": ["TEXT"]}  # force text response
+            genconfig |= {"system_instruction": system_instruction}  # pin system instruction
+            if GEMINI.ASR_THINKING_BUDGET is not None:
+                thinking_budget = min(round(float(GEMINI.ASR_THINKING_BUDGET)), GEMINI.MAX_THINKING_BUDGET)
+                genconfig |= {"thinking_config": ThinkingConfig(include_thoughts=False, thinking_budget=thinking_budget)}
+            contents = [prompt, uploaded_audio] if prompt else [uploaded_audio]
+            params = {"model": GEMINI.ASR_MODEL, "contents": contents, "config": GenerateContentConfig(**genconfig)}
+            res = await gemini_stream(client, message, model_name="ASR", params=params, prefix="", slient=slient, max_retry=0, gemini_api_key=api_key, **kwargs)
+            if res.get("error"):
+                continue
+            sent_messages = res.get("sent_messages", [])
+            break
+        except Exception as e:
+            logger.error(e)
+            with contextlib.suppress(Exception):
+                [await modify_progress(msg, del_status=True) for msg in res.get("sent_messages", [])]
+        finally:
+            with contextlib.suppress(Exception):
+                if "uploaded_audio" in locals() and uploaded_audio.name:
+                    await app.aio.files.delete(name=uploaded_audio.name)
+    res["sent_messages"] = [status, *sent_messages]
+    return res
 
 
 class Transcription(BaseModel):
src/asr/voice_recognition.py
@@ -206,23 +206,22 @@ async def asr_file(
     logger.debug(f"[{asr_method}] Recognizing {voice_format} audio [{duration}s] by {language}: {path.as_posix()}")
     try:
         if asr_method == "tencent_single_asr":
-            texts = await tencent_single_asr(path, language, voice_format)
+            res["texts"] = await tencent_single_asr(path, language, voice_format)
         elif asr_method == "tencent_flash_asr":
-            texts = await tencent_flash_asr(path, language, voice_format)
+            res["texts"] = await tencent_flash_asr(path, language, voice_format)
         elif asr_method == "tencent_async_asr":
-            texts = await tencent_async_asr(path, language)
+            res["texts"] = await tencent_async_asr(path, language)
         elif asr_method == "ali":
-            texts = await ali_asr(path)
+            res["texts"] = await ali_asr(path)
         elif asr_method == "deepgram":
-            texts = await deepgram_asr(path)
+            res["texts"] = await deepgram_asr(path)
         elif asr_method == "gemini":
-            return await gemini_stream_asr(path=path, voice_format=voice_format, **kwargs)
-        res["texts"] = texts
-        logger.success(f"{texts!r}")
+            res |= await gemini_stream_asr(path=path, voice_format=voice_format, **kwargs)
+        logger.success(f"{res['texts']!r}")
     except Exception as e:
         error = f"Failed to recognize audio: {e}"
         logger.error(error)
-        res["error"] = error
+        res["error"] = res.get("error", error)
     finally:
         path.unlink(missing_ok=True)
     return res
src/llm/gemini.py
@@ -1,6 +1,5 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
-
 import contextlib
 import json
 from io import BytesIO
@@ -17,7 +16,7 @@ from pyrogram.types import Message, ReplyParameters
 from config import CAPTION_LENGTH, DOWNLOAD_DIR, GEMINI, GPT, PREFIX, TEXT_LENGTH
 from llm.contexts import get_conversation_contexts
 from llm.hooks import hook_gemini_httpoptions
-from llm.utils import BOT_TIPS, beautify_llm_response, clean_cmd_prefix, clean_gemini_sourcemarks, clean_source_marks
+from llm.utils import BOT_TIPS, beautify_llm_response, clean_cmd_prefix, clean_gemini_sourcemarks, clean_source_marks, shuffle_keys
 from messages.parser import parse_msg
 from messages.progress import modify_progress
 from messages.sender import send2tg
@@ -85,23 +84,35 @@ async def gemini_stream(
     message: Message,
     model_name: str,
     params: dict,
+    prefix: str | None = None,
     retry: int = 0,
+    max_retry: int | None = None,
+    last_error: str = "",
+    *,
+    silent: bool = False,
     **kwargs,
-):
-    prefix = f"🤖**{model_name}**:{BOT_TIPS}\n"
-    answers = ""
+) -> dict:
+    if prefix is None:
+        prefix = f"🤖**{model_name}**:{BOT_TIPS}\n"
+    answers = ""  # all model responses
+    runtime_texts = ""  # for a single telegram message
+    init_status_msg = None if silent else kwargs.get("progress")
+    status_msg = init_status_msg
+    status_mid = status_msg.id if isinstance(status_msg, Message) else message.id
+    if not kwargs.get("gemini_api_keys"):
+        kwargs["gemini_api_keys"] = shuffle_keys(GEMINI.API_KEYS)
+    api_keys = [x.strip() for x in kwargs["gemini_api_keys"].split(",") if x.strip()]
+    max_retry = len(api_keys) - 1 if max_retry is None else max_retry
     try:
-        status: Message = kwargs.get("progress")  # type: ignore
-        api_keys = [x.strip() for x in GEMINI.API_KEYS.split(",") if x.strip()]
-        if kwargs.get("gemini_api_keys"):
-            api_keys = [x.strip() for x in kwargs["gemini_api_keys"].split(",") if x.strip()]
-        if retry > len(api_keys) - 1:
-            return None
+        if retry > min(len(api_keys) - 1, max_retry):
+            logger.error(f"[Gemini] Failed after {retry} retries")
+            await modify_progress(message=init_status_msg, text=last_error, force_update=True)
+            return {"error": last_error}
         api_key = kwargs.get("gemini_api_key", api_keys[retry])
         http_options = HttpOptions(base_url=GEMINI.BASR_URL, async_client_args={"proxy": GEMINI.PROXY})
         http_options = hook_gemini_httpoptions(http_options, message)
         app = genai.Client(api_key=api_key, http_options=http_options)
-        runtime_texts = ""
+        sent_messages = []
         async for chunk in await app.aio.models.generate_content_stream(**params):
             resp = parse_response(chunk.model_dump())
             answer = resp.get("texts", "")
@@ -111,36 +122,42 @@ async def gemini_stream(
             length = await count_without_entities(prefix + runtime_texts)
             if length <= TEXT_LENGTH:
                 if len(runtime_texts.removeprefix(prefix)) > 10:  # start response if answer is not empty
-                    await modify_progress(message=status, text=prefix + runtime_texts, detail_progress=True)
+                    await modify_progress(message=status_msg, text=prefix + runtime_texts, detail_progress=True)
             else:  # answers is too long, split it into multiple messages
                 parts = await smart_split(prefix + runtime_texts)
                 if len(parts) == 1:
                     continue
-                await modify_progress(message=status, text=blockquote(parts[0]), force_update=True)  # force send the first part
+                await modify_progress(message=status_msg, text=blockquote(parts[0]), force_update=True)  # force send the first part
                 runtime_texts = parts[-1]  # keep the last part
-                status = await client.send_message(message.chat.id, text=prefix + runtime_texts, reply_parameters=ReplyParameters(message_id=status.id))  # the new message
+                if not silent:
+                    status_msg = await client.send_message(message.chat.id, text=prefix + runtime_texts, reply_parameters=ReplyParameters(message_id=status_mid))  # the new message
+                    sent_messages.append(status_msg)
+                    status_mid = status_msg.id
 
         # all chunks are processed
         if not answers.strip():  # empty response
-            return await gemini_stream(client, message, model_name, params, retry + 1, **kwargs)  # type: ignore
+            return await gemini_stream(client, message, model_name, params, prefix=prefix, retry=retry + 1, last_error=last_error, **kwargs)
 
         if await count_without_entities(prefix + answers) <= TEXT_LENGTH:  # short answer in single msg
             if length > GPT.COLLAPSE_LENGTH:  # collapse the response if the answer is too long
-                await modify_progress(message=status, text=f"{prefix}{blockquote(runtime_texts)}", force_update=True)
+                await modify_progress(message=status_msg, text=f"{prefix}{blockquote(runtime_texts)}", force_update=True)
             else:
-                await modify_progress(message=status, text=f"{prefix}{runtime_texts}", force_update=True)
+                await modify_progress(message=status_msg, text=f"{prefix}{runtime_texts}", force_update=True)
         elif length > GPT.COLLAPSE_LENGTH:
-            await modify_progress(message=status, text=prefix + blockquote(runtime_texts), force_update=True)
+            await modify_progress(message=status_msg, text=prefix + blockquote(runtime_texts), force_update=True)
         else:
-            await modify_progress(message=status, text=prefix + runtime_texts, force_update=True)
+            await modify_progress(message=status_msg, text=prefix + runtime_texts, force_update=True)
 
     except Exception as e:
-        logger.error(e)
         error = str(e)
         if "resp" in locals():
             error += f"\n{resp}"
-        await modify_progress(text=error, force_update=True, **kwargs)
-        return await gemini_stream(client, message, model_name, params, retry + 1, **kwargs)  # type: ignore
+        logger.error(error)
+        with contextlib.suppress(Exception):
+            await modify_progress(message=init_status_msg, text=error, force_update=True)
+            [await modify_progress(msg, del_status=True) for msg in sent_messages]
+        return await gemini_stream(client, message, model_name, params, prefix=prefix, retry=retry + 1, last_error=error, **kwargs)
+    return {"texts": answers, "prefix": prefix, "model_name": model_name, "sent_messages": sent_messages}
 
 
 async def gemini_nonstream(
src/llm/utils.py
@@ -1,5 +1,6 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
+import random
 import re
 import tempfile
 from pathlib import Path
@@ -242,3 +243,15 @@ def split_reasoning(text: str) -> tuple[str, str]:
     if matched := re.search(rf"{REASONING_BEGIN}(.*?){REASONING_END}", text, flags=re.DOTALL):
         reasoning = REASONING_BEGIN + matched.group(1) + REASONING_END
     return reasoning.strip(), content.strip()
+
+
+def shuffle_keys(keys: str | list[str]) -> str:
+    """Shuffle a comma-separated string of keys."""
+    if isinstance(keys, str):
+        keys = [x.strip() for x in keys.split(",") if x.strip()]
+    elif isinstance(keys, list):
+        keys = [x.strip() for x in keys if x.strip()]
+    else:
+        return ""
+    random.shuffle(keys)
+    return ",".join(keys)