Commit a88e54f

benny-dou <60535774+benny-dou@users.noreply.github.com>
2025-05-07 14:08:08
feat(asr): support custom Gemini ASR prompt
1 parent bf0294a
Changed files (1)
src/asr/gemini_asr.py
@@ -21,7 +21,7 @@ from messages.progress import modify_progress
 from messages.utils import blockquote, count_without_entities, smart_split
 
 
-async def gemini_stream_asr(client: Client, message: Message, path: str | Path, voice_format: str, *, slient: bool = False, **kwargs) -> dict:
+async def gemini_stream_asr(client: Client, message: Message, path: str | Path, voice_format: str, *, prompt: str = "请转录这段音频", slient: bool = False, **kwargs) -> dict:
     """Gemini stream ASR.
 
     https://ai.google.dev/gemini-api/docs/audio
@@ -84,7 +84,8 @@ Notes:
         if ASR.GEMINI_THINKING_BUDGET is not None:
             thinking_budget = min(round(float(ASR.GEMINI_THINKING_BUDGET)), GEMINI.MAX_THINKING_BUDGET)
             genconfig |= {"thinking_config": ThinkingConfig(include_thoughts=False, thinking_budget=thinking_budget)}
-        params = {"model": ASR.GEMINI_MODEL, "contents": ["请转录这段音频", uploaded_audio], "config": GenerateContentConfig(**genconfig)}
+        contents = [prompt, uploaded_audio] if prompt else [uploaded_audio]
+        params = {"model": ASR.GEMINI_MODEL, "contents": contents, "config": GenerateContentConfig(**genconfig)}
         async for chunk in await app.aio.models.generate_content_stream(**params):
             resp = parse_response(chunk.model_dump())
             sentence = resp.get("texts", "")
@@ -134,7 +135,7 @@ def generate_transcription(items: list[dict]) -> str:
     return res.strip()
 
 
-async def gemini_nonstream_asr(path: str | Path, voice_format: str) -> str:
+async def gemini_nonstream_asr(path: str | Path, voice_format: str, *, prompt: str = "请转录这段音频") -> str:
     """(Deprecated) Gemini ASR.
 
     This function is deprecated and will be removed in the future.
@@ -152,9 +153,10 @@ async def gemini_nonstream_asr(path: str | Path, voice_format: str) -> str:
             client = genai.Client(api_key=key, http_options=HttpOptions(base_url=ASR.GEMINI_BASR_URL, async_client_args={"proxy": ASR.GEMINI_PROXY}))
             uploaded_audio = await client.aio.files.upload(file=path, config=UploadFileConfig(mime_type=f"audio/{voice_format}"))
             logger.debug(uploaded_audio)
+            contents = [prompt, uploaded_audio] if prompt else [uploaded_audio]
             response = await client.aio.models.generate_content(
                 model=ASR.GEMINI_MODEL,
-                contents=["请转录这段音频", uploaded_audio],
+                contents=contents,  # type: ignore
                 config=GenerateContentConfig(
                     response_mime_type="application/json",
                     response_schema=list[Transcription],