Commit d4feaa6

benny-dou <60535774+benny-dou@users.noreply.github.com>
2025-04-28 11:43:44
feat(asr): add stream mode
1 parent d519f14
Changed files (3)
src/asr/gemini_asr.py
@@ -1,5 +1,6 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
+import io
 import random
 from pathlib import Path
 
@@ -8,8 +9,15 @@ from google import genai
 from google.genai.types import GenerateContentConfig, HttpOptions, UploadFileConfig
 from loguru import logger
 from pydantic import BaseModel
+from pyrogram.client import Client
+from pyrogram.parser.markdown import BLOCKQUOTE_EXPANDABLE_DELIM, BLOCKQUOTE_EXPANDABLE_END_DELIM
+from pyrogram.types import Message, ReplyParameters
 
-from config import ASR
+from config import ASR, TEXT_LENGTH
+from llm.gemini import parse_response
+from llm.utils import beautify_llm_response
+from messages.progress import modify_progress
+from messages.utils import count_without_entities, smart_split
 
 
 class Transcription(BaseModel):
@@ -39,7 +47,6 @@ async def gemini_asr(path: str | Path, voice_format: str) -> str:
                 config=GenerateContentConfig(
                     response_mime_type="application/json",
                     response_schema=list[Transcription],
-                    temperature=0,
                 ),
             )
             if parsed := glom(response.model_dump(), "parsed"):
@@ -65,3 +72,69 @@ def generate_transcription(items: list[dict]) -> str:
         else:
             res += sentence
     return res.strip()
+
+
+async def gemini_stream_asr(client: Client, message: Message, path: str | Path, voice_format: str, *, slient: bool = False, **kwargs) -> dict:
+    """Gemini stream ASR.
+
+    https://ai.google.dev/gemini-api/docs/audio
+
+    Args:
+        slient (bool, optional): If True, do not post progress updates; return all results at the end.
+    """
+    prompt = """请转录这段音频, 要求:
+    1. 以 `[mm:ss] sentence` 格式输出句子内容, 包括标点符号。其中mm:ss为此句话开始时间的分钟和秒
+    2. 请使用简体中文输出
+    3. 直接输出音频转录内容, 不要输出任何与音频内容无关的寒暄问候
+
+    输出实例:
+    [00:02] 大家好, 我是小明, 欢迎来到我的频道。
+    [00:08] 今天要和大家聊一个一直以来都很有争议的话题。
+    """
+
+    def warp(s: str) -> str:
+        return BLOCKQUOTE_EXPANDABLE_DELIM + s + BLOCKQUOTE_EXPANDABLE_END_DELIM
+
+    path = Path(path)
+    api_keys = [x.strip() for x in ASR.GEMINI_API_KEY.split(",") if x.strip()]
+    transcriptions = ""
+    runtime_texts = ""
+    sent_messages = []
+    if status := kwargs.get("progress"):
+        sent_messages.append(status)
+    if slient:
+        status = None
+    try:
+        logger.debug(f"ASR via {ASR.GEMINI_MODEL}: {path.as_posix()} , proxy={ASR.GEMINI_PROXY}")
+        app = genai.Client(api_key=random.choice(api_keys), http_options=HttpOptions(base_url=ASR.GEMINI_BASR_URL, async_client_args={"proxy": ASR.GEMINI_PROXY}))
+        uploaded_audio = await app.aio.files.upload(file=path, config=UploadFileConfig(mime_type=f"audio/{voice_format}"))
+        logger.debug(uploaded_audio)
+        async for chunk in await app.aio.models.generate_content_stream(
+            model=ASR.GEMINI_MODEL,
+            contents=[prompt, uploaded_audio],
+        ):
+            resp = parse_response(chunk.model_dump())
+            sentence = resp.get("texts", "")
+            transcriptions += sentence
+            runtime_texts += sentence
+            runtime_texts = beautify_llm_response(runtime_texts)
+            if await count_without_entities(runtime_texts) <= TEXT_LENGTH:
+            if len(runtime_texts) > 5:  # only start updating once more than 5 chars have accumulated
+                    await modify_progress(message=status, text=runtime_texts, detail_progress=True)
+            else:  # transcriptions is too long, split it into multiple messages
+                parts = await smart_split(runtime_texts)
+                await modify_progress(message=status, text=warp(parts[0]), force_update=True)  # force send the first part
+                runtime_texts = parts[-1]  # keep the last part
+                if not slient:
+                    status = await client.send_message(message.chat.id, runtime_texts)  # the new message
+                    sent_messages.append(status)
+
+        # all chunks are processed
+        await modify_progress(message=status, text=warp(beautify_llm_response(runtime_texts)), force_update=True)
+        if len(sent_messages) > 1:
+            with io.BytesIO(transcriptions.encode("utf-8")) as f:
+                await client.send_document(message.chat.id, f, file_name="语音识别结果.txt", reply_parameters=ReplyParameters(message_id=message.id))
+            [await modify_progress(msg, del_status=True) for msg in sent_messages]
+    except Exception as e:
+        logger.error(e)
+    return {"texts": transcriptions} if slient else {}
src/asr/voice_recognition.py
@@ -12,7 +12,7 @@ from pyrogram.client import Client
 from pyrogram.parser.markdown import BLOCKQUOTE_EXPANDABLE_DELIM, BLOCKQUOTE_EXPANDABLE_END_DELIM
 from pyrogram.types import Message
 
-from asr.gemini_asr import gemini_asr
+from asr.gemini_asr import gemini_asr, gemini_stream_asr
 from asr.tecent_asr import create_async_asr, flash_asr, query_async_asr, single_sentence_asr
 from asr.utils import get_asr_method
 from config import CAPTION_LENGTH, FILE_SERVER, PREFIX, TEXT_LENGTH
@@ -120,7 +120,7 @@ async def voice_to_text(
     msg = f"[ASR] 收到消息: {trigger_info['mtype']}, 开始识别..."
     logger.info(msg)
     if kwargs.get("show_progress"):
-        res = await send2tg(client, message, texts=msg, **kwargs)
+        res = await send2tg(client, trigger_message, texts=msg, **kwargs)
         kwargs["progress"] = res[0]
 
     path: str | Path = await trigger_message.download()  # type: ignore
@@ -131,7 +131,7 @@ async def voice_to_text(
         await modify_progress(text=msg, force_update=True, **kwargs)
         return
 
-    res = await asr_file(path, engine=force_engine, duration=trigger_info["duration"], language=asr_language)
+    res = await asr_file(path, engine=force_engine, duration=trigger_info["duration"], language=asr_language, client=client, message=trigger_message, **kwargs)
     if error := res.get("error"):
         await modify_progress(text=error, force_update=True, **kwargs)
         return
@@ -161,6 +161,9 @@ async def asr_file(
     engine: str = "",
     duration: int = 0,
     language: str = "16k_zh-PY",
+    *,
+    gemini_stream_mode: bool = True,
+    **kwargs,
 ) -> dict:
     """Get ASR results of an audio file."""
     res = {}
@@ -216,6 +219,8 @@ async def asr_file(
             else:
                 texts = glom(result, "Response.Data.ErrorMsg")
                 res["error"] = texts
+        elif asr_method == "gemini" and gemini_stream_mode:
+            return await gemini_stream_asr(path=path, voice_format=voice_format, **kwargs)
         elif asr_method == "gemini":
             texts = await gemini_asr(path, voice_format)
         res["texts"] = texts
src/preview/ytdlp.py
@@ -240,7 +240,7 @@ async def preview_ytdlp(
                 append_transcription = False  # disable asr transcription
 
     if any(x in info["extractor"] for x in ["youtube", "bilibili"]) and append_transcription and audio_path.is_file():
-        asr_res = await asr_file(audio_path, ytdlp_transcription_engine, duration)
+        asr_res = await asr_file(audio_path, ytdlp_transcription_engine, duration, client=client, message=message, slient=True)
         if texts := asr_res.get("texts"):
             caption = f"{emoji}[{info['author']}]({info['author_url']})\n🕒{create_time}\n📝[{info['title']}]({url})\n字符数: {len(texts)}\n阅读时长: {len(texts) / READING_SPEED:.1f}分钟"
             with io.BytesIO(texts.encode("utf-8")) as f: