Commit a88e54f
Changed files (1)
src
src/asr/gemini_asr.py
@@ -21,7 +21,7 @@ from messages.progress import modify_progress
from messages.utils import blockquote, count_without_entities, smart_split
-async def gemini_stream_asr(client: Client, message: Message, path: str | Path, voice_format: str, *, slient: bool = False, **kwargs) -> dict:
+async def gemini_stream_asr(client: Client, message: Message, path: str | Path, voice_format: str, *, prompt: str = "请转录这段音频", slient: bool = False, **kwargs) -> dict:
"""Gemini stream ASR.
https://ai.google.dev/gemini-api/docs/audio
@@ -84,7 +84,8 @@ Notes:
if ASR.GEMINI_THINKING_BUDGET is not None:
thinking_budget = min(round(float(ASR.GEMINI_THINKING_BUDGET)), GEMINI.MAX_THINKING_BUDGET)
genconfig |= {"thinking_config": ThinkingConfig(include_thoughts=False, thinking_budget=thinking_budget)}
- params = {"model": ASR.GEMINI_MODEL, "contents": ["请转录这段音频", uploaded_audio], "config": GenerateContentConfig(**genconfig)}
+ contents = [prompt, uploaded_audio] if prompt else [uploaded_audio]
+ params = {"model": ASR.GEMINI_MODEL, "contents": contents, "config": GenerateContentConfig(**genconfig)}
async for chunk in await app.aio.models.generate_content_stream(**params):
resp = parse_response(chunk.model_dump())
sentence = resp.get("texts", "")
@@ -134,7 +135,7 @@ def generate_transcription(items: list[dict]) -> str:
return res.strip()
-async def gemini_nonstream_asr(path: str | Path, voice_format: str) -> str:
+async def gemini_nonstream_asr(path: str | Path, voice_format: str, *, prompt: str = "请转录这段音频") -> str:
"""(Deprecated) Gemini ASR.
This function is deprecated and will be removed in the future.
@@ -152,9 +153,10 @@ async def gemini_nonstream_asr(path: str | Path, voice_format: str) -> str:
client = genai.Client(api_key=key, http_options=HttpOptions(base_url=ASR.GEMINI_BASR_URL, async_client_args={"proxy": ASR.GEMINI_PROXY}))
uploaded_audio = await client.aio.files.upload(file=path, config=UploadFileConfig(mime_type=f"audio/{voice_format}"))
logger.debug(uploaded_audio)
+ contents = [prompt, uploaded_audio] if prompt else [uploaded_audio]
response = await client.aio.models.generate_content(
model=ASR.GEMINI_MODEL,
- contents=["请转录这段音频", uploaded_audio],
+ contents=contents, # type: ignore
config=GenerateContentConfig(
response_mime_type="application/json",
response_schema=list[Transcription],