Commit 9a33290
Changed files (1)
src
src/asr/gemini_asr.py
@@ -21,7 +21,19 @@ from messages.progress import modify_progress
from messages.utils import blockquote, count_without_entities, smart_split
-async def gemini_stream_asr(client: Client, message: Message, path: str | Path, voice_format: str, *, prompt: str = "请转录这段音频", slient: bool = False, **kwargs) -> dict:
+async def gemini_stream_asr(
+ client: Client,
+ message: Message,
+ path: str | Path,
+ voice_format: str,
+ prompt: str = "请转录这段音频",
+ *,
+ slient: bool = False,
+ retry: int = 0,
+ max_retry: int = 2,
+ last_error: str = "",
+ **kwargs,
+) -> dict:
"""Gemini stream ASR.
https://ai.google.dev/gemini-api/docs/audio
@@ -59,7 +71,9 @@ Example-2:
Notes:
- Focus on accuracy in capturing both the timing and the spoken content.
- Maintain consistent formatting to ensure clarity and readability."""
-
+ if retry > max_retry:
+ logger.error(f"[GeminiASR] Failed after {retry} retries")
+ return {"error": last_error}
path = Path(path)
api_keys = [x.strip() for x in ASR.GEMINI_API_KEY.split(",") if x.strip()]
transcriptions = ""
@@ -109,7 +123,11 @@ Notes:
await app.aio.files.delete(name=uploaded_audio.name)
except Exception as e:
logger.error(e)
- return {"error": str(e)}
+ with contextlib.suppress(Exception):
+ [await modify_progress(msg, del_status=True) for msg in sent_messages]
+ if "uploaded_audio" in locals() and uploaded_audio.name:
+ await app.aio.files.delete(name=uploaded_audio.name)
+ return await gemini_stream_asr(client, message, path, voice_format, prompt, slient=slient, retry=retry + 1, max_retry=max_retry, last_error=str(e))
return {"texts": transcriptions, "sent_messages": sent_messages}