Commit 6490751
Changed files (3)
src
podcast
src/asr/gemini.py
@@ -9,7 +9,6 @@ from glom import glom
from google import genai
from google.genai.types import File, GenerateContentConfig, HttpOptions, UploadFileConfig
from loguru import logger
-from pydantic import BaseModel, Field
from ai.utils import literal_eval
from asr.groq import merge_transcripts
@@ -17,14 +16,32 @@ from asr.utils import GEMINI_AUDIO_EXT, audio_chunk_to_path, audio_duration, con
from config import AI, ASR, DOWNLOAD_DIR, PROXY
from utils import guess_mime, rand_string, seconds_to_time, strings_list, zhcn
+SYSTEM_PROMPT = """你是一个专业的音频分析与转录助手。你的任务是处理音频输入,并输出结构化的转录结果。
-class Transcription(BaseModel):
- start: int = Field(description="start time in seconds of the sentence in the audio")
- sentence: str = Field(description="transcription sentence with punctuation")
- end: int = Field(description="end time in seconds of the sentence in the audio")
+请严格遵循以下步骤进行处理:
+1. **全局理解与总结**:
+ - 首先听取整个音频片段。
+ - 用中文生成一段简明扼要的摘要,概括对话的核心主题和结论。
-async def gemini_asr(path: str | Path, prompt: str = "请转录这段音频", *, delete_gemini_file: bool = True) -> dict:
+2. **说话人区分 (Diarization)**:
+ - 准确区分不同的说话人。使用 "Speaker 1", "Speaker 2" 或根据内容推断的角色名(如 "Interviewer", "Guest")。
+ - 保持同一说话人在整个对话中标签的一致性。
+
+3. **高精度转录**:
+ - 逐句转录内容,包含正确的标点符号。
+ - 时间戳:
+ - 每个句子的开始时间(start)和结束时间(end),单位为秒。
+ - 确保 start < end,且不同句子的时间戳尽量不要重叠。
+
+4. **噪音处理**:
+ - 忽略背景音乐、背景噪音、填充词(如 uh, um)和非语言声音,只保留有意义的对话内容。
+
+请输出符合 Schema 的 JSON 输出。
+""" # noqa: RUF001
+
+
+async def gemini_asr(path: str | Path, prompt: str = "", *, delete_gemini_file: bool = True) -> dict:
"""Gemini stream ASR.
https://ai.google.dev/gemini-api/docs/audio
@@ -35,6 +52,7 @@ async def gemini_asr(path: str | Path, prompt: str = "请转录这段音频", *,
path = Path(path).expanduser().resolve()
if not path.is_file():
return {"texts": "", "error": "File not found."}
+ prompt = prompt or "请转录这段音频"
audio_path = path if path.suffix.lower() in GEMINI_AUDIO_EXT else await downsampe_audio(path, ext="wav", codec="pcm_s16le")
audio_path = await convert_single_channel(audio_path, ext="wav", codec="pcm_s16le")
duration = audio_duration(audio_path)
@@ -78,7 +96,26 @@ async def gemini_single_file(
genconfig = {}
if ASR.GEMINI_CONFIG:
genconfig |= literal_eval(ASR.GEMINI_CONFIG)
- genconfig |= {"response_mime_type": "application/json", "response_schema": list[Transcription]}
+ genconfig |= {
+ "system_instruction": SYSTEM_PROMPT,
+ "response_mime_type": "application/json",
+ "response_schema": {
+ "title": "Audio Transcription",
+ "type": "array",
+ "description": "List of transcribed segments with speaker and timestamp.",
+ "items": {
+ "type": "object",
+ "title": "Transcription Segment",
+ "properties": {
+ "speaker": {"type": "string", "description": "Speaker label (e.g., 'Speaker 1', 'Interviewer')"},
+ "start": {"type": "integer", "description": "Start time in seconds of the segment in the audio"},
+ "end": {"type": "integer", "description": "End time in seconds of the segment in the audio"},
+ "content": {"type": "string", "description": "Verbatim transcription text with punctuation"},
+ },
+ "required": ["speaker", "start", "end", "content"],
+ },
+ },
+ }
contents = [prompt, uploaded_audio] if prompt else [uploaded_audio]
params = {"model": ASR.GEMINI_MODEL, "contents": contents, "config": GenerateContentConfig(**genconfig)}
answers = "" # all model responses
@@ -88,7 +125,7 @@ async def gemini_single_file(
answers += text
await app.aio.aclose()
try:
- transcriptions = json.loads(answers)
+ transcriptions = post_process_transcription(json.loads(answers))
except json.JSONDecodeError as e:
logger.error(f"{e}\n{answers}")
continue
@@ -96,7 +133,7 @@ async def gemini_single_file(
{
"start": start_seconds + seg["start"],
"end": start_seconds + seg["end"],
- "text": zhcn(seg["sentence"]),
+ "text": zhcn(seg["content"]),
}
for seg in transcriptions
]
@@ -177,3 +214,37 @@ async def gemini_file_chunks(
logger.error(e)
return {"error": str(e)}
return transcription
+
+
+def post_process_transcription(transcriptions: list[dict]) -> list[dict]:
+    """Prefix each segment's text with a speaker marker and drop the speaker field.
+
+    When every segment carries the same speaker label, no marker is added.
+    Otherwise the labels "Speaker 1" .. "Speaker 9" are replaced by a colored
+    emoji, and any other non-empty label (e.g. a role name such as
+    "Interviewer") is kept as "<label>: ". Segment text is passed through
+    ``zhcn`` before the marker is prepended.
+
+    Args:
+        transcriptions: Segments shaped like
+            {"speaker": str, "start": int, "end": int, "content": str}.
+            A missing or empty "speaker" is treated as an unlabeled segment.
+
+    Returns:
+        A new list of {"start", "end", "content"} dicts, in input order.
+        The input segments are not modified.
+
+    Raises:
+        KeyError: If a segment lacks "start", "end" or "content".
+    """
+    emoji_map = {
+        "Speaker 1": "🔴",
+        "Speaker 2": "🔵",
+        "Speaker 3": "🟡",
+        "Speaker 4": "🟢",
+        "Speaker 5": "🟣",
+        "Speaker 6": "🟠",
+        "Speaker 7": "🟤",
+        "Speaker 8": "⚫",
+        "Speaker 9": "⚪",
+    }
+
+    # Read labels once, without touching the caller's dicts. A missing/None
+    # label becomes "" instead of raising KeyError or rendering as "None: ".
+    labels = [seg.get("speaker") or "" for seg in transcriptions]
+    single_speaker = len(set(labels)) == 1
+
+    def marker(label: str) -> str:
+        # No marker for monologues or unlabeled segments (avoids a bare ": ").
+        if single_speaker or not label:
+            return ""
+        return emoji_map.get(label, f"{label}: ")
+
+    return [
+        {
+            "start": seg["start"],
+            "end": seg["end"],
+            "content": marker(label) + zhcn(seg["content"]),
+        }
+        for seg, label in zip(transcriptions, labels)
+    ]
src/asr/voice_recognition.py
@@ -188,7 +188,7 @@ async def voice_to_text(
async def asr_file(
path: str | Path,
engine: str = "",
- asr_prompt: str = "请转录这段音频",
+ prompt: str = "",
*,
tencent_language: str = "16k_zh-PY",
delete_local_file: bool = True,
@@ -214,11 +214,11 @@ async def asr_file(
elif engine == "deepgram":
res = await deepgram_asr(path)
elif engine == "gemini":
- res = await gemini_asr(path=path, prompt=asr_prompt, delete_gemini_file=delete_gemini_file)
+ res = await gemini_asr(path=path, prompt=prompt, delete_gemini_file=delete_gemini_file)
elif engine == "cloudflare":
- res = await cloudflare_asr(path, duration, model=kwargs.get("cf_asr_model"), prompt=kwargs.get("cf_asr_prompt"))
+ res = await cloudflare_asr(path, duration, model=kwargs.get("cf_asr_model"), prompt=prompt)
elif engine == "groq":
- res = await groq_asr(path=path, model=kwargs.get("groq_asr_model", ""), prompt=kwargs.get("groq_asr_prompt", ""))
+ res = await groq_asr(path=path, model=kwargs.get("groq_asr_model", ""), prompt=prompt)
else:
return {"error": "ASR method not supported"}
if res.get("texts"):
src/podcast/asr.py
@@ -40,7 +40,7 @@ async def get_transcripts(
desc = glom(entry, Coalesce("content.0.value", "summary"), default="")
prompt = f"请转录播客栏目《{feed_title}》的一期节目的音频。\n该期节目标题: {entry['title']}\n节目时长: {readable_time(duration)}\n节目简介: {desc}"
engine = get_asr_engine(feed_title, feed_url)
- asr_res = await asr_file(tmp_path, asr_prompt=prompt, engine=engine, silent=True)
+ asr_res = await asr_file(tmp_path, prompt=prompt, engine=engine, silent=True)
Path(tmp_path).unlink(missing_ok=True)
return asr_res.get("texts", "")