Commit 9a33290

benny-dou <60535774+benny-dou@users.noreply.github.com>
2025-05-12 11:12:28
feat(asr): add retry mechanism for Gemini
1 parent 254f6e8
Changed files (1)
src/asr/gemini_asr.py
@@ -21,7 +21,19 @@ from messages.progress import modify_progress
 from messages.utils import blockquote, count_without_entities, smart_split
 
 
-async def gemini_stream_asr(client: Client, message: Message, path: str | Path, voice_format: str, *, prompt: str = "请转录这段音频", slient: bool = False, **kwargs) -> dict:
+async def gemini_stream_asr(
+    client: Client,
+    message: Message,
+    path: str | Path,
+    voice_format: str,
+    prompt: str = "请转录这段音频",
+    *,
+    slient: bool = False,
+    retry: int = 0,
+    max_retry: int = 2,
+    last_error: str = "",
+    **kwargs,
+) -> dict:
     """Gemini stream ASR.
 
     https://ai.google.dev/gemini-api/docs/audio
@@ -59,7 +71,9 @@ Example-2:
 Notes:
 - Focus on accuracy in capturing both the timing and the spoken content.
 - Maintain consistent formatting to ensure clarity and readability."""
-
+    if retry > max_retry:
+        logger.error(f"[GeminiASR] Failed after {retry} retries")
+        return {"error": last_error}
     path = Path(path)
     api_keys = [x.strip() for x in ASR.GEMINI_API_KEY.split(",") if x.strip()]
     transcriptions = ""
@@ -109,7 +123,11 @@ Notes:
             await app.aio.files.delete(name=uploaded_audio.name)
     except Exception as e:
         logger.error(e)
-        return {"error": str(e)}
+        with contextlib.suppress(Exception):
+            [await modify_progress(msg, del_status=True) for msg in sent_messages]
+            if "uploaded_audio" in locals() and uploaded_audio.name:
+                await app.aio.files.delete(name=uploaded_audio.name)
+        return await gemini_stream_asr(client, message, path, voice_format, prompt, slient=slient, retry=retry + 1, max_retry=max_retry, last_error=str(e))
     return {"texts": transcriptions, "sent_messages": sent_messages}