Commit 6490751

benny-dou <60535774+benny-dou@users.noreply.github.com>
2026-01-22 01:34:13
chore(asr): improve ASR
1 parent 6899887
Changed files (3)
src/asr/gemini.py
@@ -9,7 +9,6 @@ from glom import glom
 from google import genai
 from google.genai.types import File, GenerateContentConfig, HttpOptions, UploadFileConfig
 from loguru import logger
-from pydantic import BaseModel, Field
 
 from ai.utils import literal_eval
 from asr.groq import merge_transcripts
@@ -17,14 +16,32 @@ from asr.utils import GEMINI_AUDIO_EXT, audio_chunk_to_path, audio_duration, con
 from config import AI, ASR, DOWNLOAD_DIR, PROXY
 from utils import guess_mime, rand_string, seconds_to_time, strings_list, zhcn
 
+SYSTEM_PROMPT = """你是一个专业的音频分析与转录助手。你的任务是处理音频输入,并输出结构化的转录结果。
 
-class Transcription(BaseModel):
-    start: int = Field(description="start time in seconds of the sentence in the audio")
-    sentence: str = Field(description="transcription sentence with punctuation")
-    end: int = Field(description="end time in seconds of the sentence in the audio")
+请严格遵循以下步骤进行处理:
 
+1. **全局理解与总结**:
+   - 首先听取整个音频片段。
+   - 用中文生成一段简明扼要的摘要,概括对话的核心主题和结论。
 
-async def gemini_asr(path: str | Path, prompt: str = "请转录这段音频", *, delete_gemini_file: bool = True) -> dict:
+2. **说话人区分 (Diarization)**:
+   - 准确区分不同的说话人。使用 "Speaker 1", "Speaker 2" 或根据内容推断的角色名(如 "Interviewer", "Guest")。
+   - 保持同一说话人在整个对话中标签的一致性。
+
+3. **高精度转录**:
+   - 逐句转录内容,包含正确的标点符号。
+   - 时间戳:
+     - 每个句子的开始时间(start)和结束时间(end),单位为秒。
+     - 确保 start < end,且不同句子的时间戳尽量不要重叠。
+
+4. **噪音处理**:
+   - 忽略背景音乐、背景噪音、填充词(如 uh, um)和非语言声音,只保留有意义的对话内容。
+
+请输出符合 Schema 的 JSON 输出。
+"""  # noqa: RUF001
+
+
+async def gemini_asr(path: str | Path, prompt: str = "", *, delete_gemini_file: bool = True) -> dict:
     """Gemini stream ASR.
 
     https://ai.google.dev/gemini-api/docs/audio
@@ -35,6 +52,7 @@ async def gemini_asr(path: str | Path, prompt: str = "请转录这段音频", *,
     path = Path(path).expanduser().resolve()
     if not path.is_file():
         return {"texts": "", "error": "File not found."}
+    prompt = prompt or "请转录这段音频"
     audio_path = path if path.suffix.lower() in GEMINI_AUDIO_EXT else await downsampe_audio(path, ext="wav", codec="pcm_s16le")
     audio_path = await convert_single_channel(audio_path, ext="wav", codec="pcm_s16le")
     duration = audio_duration(audio_path)
@@ -78,7 +96,26 @@ async def gemini_single_file(
             genconfig = {}
             if ASR.GEMINI_CONFIG:
                 genconfig |= literal_eval(ASR.GEMINI_CONFIG)
-            genconfig |= {"response_mime_type": "application/json", "response_schema": list[Transcription]}
+            genconfig |= {
+                "system_instruction": SYSTEM_PROMPT,
+                "response_mime_type": "application/json",
+                "response_schema": {
+                    "title": "Audio Transcription",
+                    "type": "array",
+                    "description": "List of transcribed segments with speaker and timestamp.",
+                    "items": {
+                        "type": "object",
+                        "title": "Transcription Segment",
+                        "properties": {
+                            "speaker": {"type": "string", "description": "Speaker label (e.g., 'Speaker 1', 'Interviewer')"},
+                            "start": {"type": "integer", "description": "Start time in seconds of the segment in the audio"},
+                            "end": {"type": "integer", "description": "End time in seconds of the segment in the audio"},
+                            "content": {"type": "string", "description": "Verbatim transcription text with punctuation"},
+                        },
+                        "required": ["speaker", "start", "end", "content"],
+                    },
+                },
+            }
             contents = [prompt, uploaded_audio] if prompt else [uploaded_audio]
             params = {"model": ASR.GEMINI_MODEL, "contents": contents, "config": GenerateContentConfig(**genconfig)}
             answers = ""  # all model responses
@@ -88,7 +125,7 @@ async def gemini_single_file(
                 answers += text
             await app.aio.aclose()
             try:
-                transcriptions = json.loads(answers)
+                transcriptions = post_process_transcription(json.loads(answers))
             except json.JSONDecodeError as e:
                 logger.error(f"{e}\n{answers}")
                 continue
@@ -96,7 +133,7 @@ async def gemini_single_file(
                 {
                     "start": start_seconds + seg["start"],
                     "end": start_seconds + seg["end"],
-                    "text": zhcn(seg["sentence"]),
+                    "text": zhcn(seg["content"]),
                 }
                 for seg in transcriptions
             ]
@@ -177,3 +214,37 @@ async def gemini_file_chunks(
         logger.error(e)
         return {"error": str(e)}
     return transcription
+
+
def post_process_transcription(transcriptions: list[dict]) -> list[dict]:
    """Post-process transcription segments returned by Gemini.

    Maps the generic speaker labels ("Speaker 1" ... "Speaker 9") produced by
    the response schema to colored emoji markers and prepends the marker to
    each segment's text. If the whole audio has exactly one distinct speaker,
    the label is dropped entirely. Unlike a naive in-place rewrite, the input
    list and its dicts are left unmodified.

    Args:
        transcriptions: list of segment dicts shaped like
            {"speaker": str, "start": int, "end": int, "content": str}.

    Returns:
        New list of {"start": int, "end": int, "content": str} dicts, where
        ``content`` is the (possibly empty) speaker prefix followed by the
        zh-CN normalised transcription text. Empty input yields ``[]``.

    """
    # Colored markers keep speaker turns visually distinct without verbose labels.
    emoji_map = {
        "Speaker 1": "🔴",
        "Speaker 2": "🔵",
        "Speaker 3": "🟡",
        "Speaker 4": "🟢",
        "Speaker 5": "🟣",
        "Speaker 6": "🟠",
        "Speaker 7": "🟤",
        "Speaker 8": "⚫",
        "Speaker 9": "⚪",
    }

    # Single-speaker audio needs no per-segment labels at all.
    distinct_speakers = {seg["speaker"] for seg in transcriptions}
    single_speaker = len(distinct_speakers) == 1

    def _prefix(speaker: str) -> str:
        # Known generic labels become a bare emoji; inferred role names
        # (e.g. "Interviewer") fall back to the raw name plus ": ".
        if single_speaker:
            return ""
        return emoji_map.get(speaker, f"{speaker}: ")

    return [
        {
            "start": seg["start"],
            "end": seg["end"],
            "content": _prefix(seg["speaker"]) + zhcn(seg["content"]),
        }
        for seg in transcriptions
    ]
src/asr/voice_recognition.py
@@ -188,7 +188,7 @@ async def voice_to_text(
 async def asr_file(
     path: str | Path,
     engine: str = "",
-    asr_prompt: str = "请转录这段音频",
+    prompt: str = "",
     *,
     tencent_language: str = "16k_zh-PY",
     delete_local_file: bool = True,
@@ -214,11 +214,11 @@ async def asr_file(
         elif engine == "deepgram":
             res = await deepgram_asr(path)
         elif engine == "gemini":
-            res = await gemini_asr(path=path, prompt=asr_prompt, delete_gemini_file=delete_gemini_file)
+            res = await gemini_asr(path=path, prompt=prompt, delete_gemini_file=delete_gemini_file)
         elif engine == "cloudflare":
-            res = await cloudflare_asr(path, duration, model=kwargs.get("cf_asr_model"), prompt=kwargs.get("cf_asr_prompt"))
+            res = await cloudflare_asr(path, duration, model=kwargs.get("cf_asr_model"), prompt=prompt)
         elif engine == "groq":
-            res = await groq_asr(path=path, model=kwargs.get("groq_asr_model", ""), prompt=kwargs.get("groq_asr_prompt", ""))
+            res = await groq_asr(path=path, model=kwargs.get("groq_asr_model", ""), prompt=prompt)
         else:
             return {"error": "ASR method not supported"}
         if res.get("texts"):
src/podcast/asr.py
@@ -40,7 +40,7 @@ async def get_transcripts(
     desc = glom(entry, Coalesce("content.0.value", "summary"), default="")
     prompt = f"请转录播客栏目《{feed_title}》的一期节目的音频。\n该期节目标题: {entry['title']}\n节目时长: {readable_time(duration)}\n节目简介: {desc}"
     engine = get_asr_engine(feed_title, feed_url)
-    asr_res = await asr_file(tmp_path, asr_prompt=prompt, engine=engine, silent=True)
+    asr_res = await asr_file(tmp_path, prompt=prompt, engine=engine, silent=True)
     Path(tmp_path).unlink(missing_ok=True)
     return asr_res.get("texts", "")