#!/usr/bin/env python
# -*- coding: utf-8 -*-
import asyncio
import contextlib
import json
from pathlib import Path

from glom import glom
from google import genai
from google.genai.types import File, GenerateContentConfig, HttpOptions, UploadFileConfig
from loguru import logger

from ai.utils import literal_eval
from asr.groq import merge_transcripts
from asr.utils import GEMINI_AUDIO_EXT, audio_chunk_to_path, audio_duration, convert_single_channel, downsampe_audio, load_audio
from config import AI, ASR, DOWNLOAD_DIR, PROXY
from utils import guess_mime, rand_string, seconds_to_time, strings_list, zhcn

19SYSTEM_PROMPT = """你是一个专业的音频分析与转录助手。你的任务是处理音频输入,并输出结构化的转录结果。
20
21请严格遵循以下步骤进行处理:
22
231. **全局理解与总结**:
24 - 首先听取整个音频片段。
25 - 用中文生成一段简明扼要的摘要,概括对话的核心主题和结论。
26
272. **说话人区分 (Diarization)**:
28 - 准确区分不同的说话人。使用 "Speaker 1", "Speaker 2" 或根据内容推断的角色名(如 "Interviewer", "Guest")。
29 - 保持同一说话人在整个对话中标签的一致性。
30
313. **高精度转录**:
32 - 逐句转录内容,包含正确的标点符号。
33 - 时间戳:
34 - 每个句子的开始时间(start)和结束时间(end),单位为秒。
35 - 确保 start < end,且不同句子的时间戳尽量不要重叠。
36
374. **噪音处理**:
38 - 忽略背景音乐、背景噪音、填充词(如 uh, um)和非语言声音,只保留有意义的对话内容。
39
40请输出符合 Schema 的 JSON 输出。
41""" # noqa: RUF001


async def gemini_asr(path: str | Path, prompt: str = "", *, delete_gemini_file: bool = True) -> dict:
    """Gemini streaming ASR entry point: routes to single-file or chunked transcription.

    https://ai.google.dev/gemini-api/docs/audio

    Args:
        path: Path to a local audio file.
        prompt: Optional user prompt; defaults to a generic transcription request.
        delete_gemini_file: If True, delete the uploaded file(s) from Gemini after transcription.
    """
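    # Usage sketch (hypothetical path; assumes no event loop is already running):
    #     result = asyncio.run(gemini_asr("~/Downloads/meeting.m4a"))
    #     print(result.get("texts", ""))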
    path = Path(path).expanduser().resolve()
    if not path.is_file():
        return {"texts": "", "error": "File not found."}
    prompt = prompt or "Please transcribe this audio."
    audio_path = path if path.suffix.lower() in GEMINI_AUDIO_EXT else await downsampe_audio(path, ext="wav", codec="pcm_s16le")
    audio_path = await convert_single_channel(audio_path, ext="wav", codec="pcm_s16le")
    duration = audio_duration(audio_path)
    if duration < ASR.GEMINI_CHUNK_SECONDS:
        return await gemini_single_file(audio_path, prompt=prompt, delete_gemini_file=delete_gemini_file)
    return await gemini_file_chunks(audio_path, prompt=prompt, delete_gemini_file=delete_gemini_file)


async def gemini_single_file(
    path: str | Path,
    prompt: str = "",
    *,
    start_seconds: int = 0,
    delete_local_file: bool = False,
    delete_gemini_file: bool = True,
) -> dict:
    """Transcribe a single audio file with Gemini streaming.

    https://ai.google.dev/gemini-api/docs/audio

    Returns:
        {"texts": str, "raw_texts": str, "segments": list[dict]}
    """
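    # Illustrative result shape (hypothetical values; assumes seconds_to_time
    # renders HH:MM:SS):
    #     {"segments": [{"start": 12, "end": 15, "text": "大家好。"}],
    #      "raw_texts": "大家好。",
    #      "texts": "[00:00:12] 大家好。"}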
    path = Path(path).expanduser().resolve()
    if not path.is_file():
        return {"texts": "", "raw_texts": "", "segments": [], "error": "File not found."}
    res = {}
    for api_key in strings_list(AI.GEMINI_API_KEYS, shuffle=True):
        logger.debug(f"ASR via {ASR.GEMINI_MODEL}: {path.as_posix()}, proxy={PROXY.GOOGLE}")
        app = genai.Client(
            api_key=api_key,
            http_options=HttpOptions(
                base_url=AI.GEMINI_BASE_URL,
                headers=literal_eval(AI.GEMINI_DEFAULT_HEADERS),
                async_client_args={"proxy": PROXY.GOOGLE},
            ),
        )
        uploaded_audio = File()
        try:
            uploaded_audio = await app.aio.files.upload(file=path, config=UploadFileConfig(mime_type=guess_mime(path)))
            genconfig = {}
            if ASR.GEMINI_CONFIG:
                genconfig |= literal_eval(ASR.GEMINI_CONFIG)
            genconfig |= {
                "system_instruction": SYSTEM_PROMPT,
                "response_mime_type": "application/json",
                "response_schema": {
                    "title": "Audio Transcription",
                    "type": "array",
                    "description": "List of transcribed segments with speaker and timestamp.",
                    "items": {
                        "type": "object",
                        "title": "Transcription Segment",
                        "properties": {
                            "speaker": {"type": "string", "description": "Speaker label (e.g., 'Speaker 1', 'Interviewer')"},
                            "start": {"type": "integer", "description": "Start time in seconds of the segment in the audio"},
                            "end": {"type": "integer", "description": "End time in seconds of the segment in the audio"},
                            "content": {"type": "string", "description": "Verbatim transcription text with punctuation"},
                        },
                        "required": ["speaker", "start", "end", "content"],
                    },
                },
            }
            contents = [prompt, uploaded_audio] if prompt else [uploaded_audio]
            params = {"model": ASR.GEMINI_MODEL, "contents": contents, "config": GenerateContentConfig(**genconfig)}
            answers = ""  # all model responses
            async for chunk in await app.aio.models.generate_content_stream(**params):
                text = glom(chunk.model_dump(), "candidates.0.content.parts.0.text", default="") or ""
                logger.trace(f"{text!r}")
                answers += text
            try:
                transcriptions = post_process_transcription(json.loads(answers))
            except json.JSONDecodeError as e:
                logger.error(f"{e}\n{answers}")
                continue
            res["segments"] = [
                {
                    "start": start_seconds + seg["start"],
                    "end": start_seconds + seg["end"],
                    "text": zhcn(seg["content"]),
                }
                for seg in transcriptions
            ]
            res["raw_texts"] = " ".join(x["text"] for x in res["segments"])
            res["texts"] = "\n".join(f"[{seconds_to_time(x['start'])}] {x['text'].lstrip()}" for x in res["segments"])  # with timestamp
            break
        except Exception as e:
            logger.error(e)
        finally:
            with contextlib.suppress(Exception):
                if uploaded_audio.name:
                    if delete_gemini_file:
                        await app.aio.files.delete(name=uploaded_audio.name)
                    else:
                        res["gemini_file"] = uploaded_audio
            # Close the client only after the file cleanup, so files.delete
            # never runs on a closed client.
            with contextlib.suppress(Exception):
                await app.aio.aclose()
    if delete_local_file:  # unlink once, after the retry loop, so retries can still read the file
        path.unlink(missing_ok=True)
    return res


async def gemini_file_chunks(
    path: str | Path,
    chunk_seconds: float = ASR.GEMINI_CHUNK_SECONDS,
    overlap_seconds: float = ASR.GEMINI_OVERLAP_SECONDS,
    prompt: str = "",
    *,
    delete_gemini_file: bool = True,
) -> dict:
    """Transcribe audio in chunks with overlap.

    Args:
        path: Path to audio file.
        chunk_seconds: Length of each chunk in seconds.
        overlap_seconds: Overlap between adjacent chunks in seconds.
        prompt: Optional user prompt passed to each chunk request.
        delete_gemini_file: If True, delete each uploaded file from Gemini afterwards.

    Returns:
        dict: {"texts": str, "raw_texts": str, "segments": list[dict]} or {"error": str}
    """
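    # Worked example of the chunk layout (assuming chunk_seconds=600 and
    # overlap_seconds=30): chunk i starts at i * 570 s, so a 1500 s file yields
    # starts at 0 s, 570 s and 1140 s -- int(1500 // 570) + 1 == 3 chunks, each
    # sharing 30 s of audio with its neighbour so no sentence is cut in half.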
    path = Path(path).expanduser().resolve()
    wav_path = path if path.suffix.lower() == ".wav" else await downsampe_audio(path, ext="wav", codec="pcm_s16le")
    audio, duration, sr = load_audio(wav_path)
    if sr == 0:
        return {"error": "Failed to load audio."}
    transcription = {}
    try:
        # Number of chunks, each advancing by (chunk_seconds - overlap_seconds)
        total_chunks = int(duration // (chunk_seconds - overlap_seconds)) + 1
        chunk_paths: list[Path] = []
        tasks = []
        offset_list = []
        # Loop through each chunk, extracting the current chunk from the audio
        for i in range(total_chunks):
            start = int(i * (chunk_seconds - overlap_seconds) * sr)
            end = int(min(start + chunk_seconds * sr, duration * sr))
            logger.trace(f"Processing chunk {i + 1}/{total_chunks}, Time range: {start / sr:.0f}s - {end / sr:.0f}s")
            chunk = audio[start:end]
            if chunk.shape[0] == 0:  # empty chunk
                continue
            # Only allocate a path for non-empty chunks so paths, write tasks and
            # offsets stay aligned for the strict zip below.
            chunk_paths.append(Path(DOWNLOAD_DIR) / f"{rand_string()}.wav")
            tasks.append(audio_chunk_to_path(chunk, sr, chunk_paths[-1]))
            offset_list.append(int(start / sr))
        await asyncio.gather(*tasks)  # write the chunks to disk
        tasks = [
            gemini_single_file(
                audio_path,
                prompt=prompt,
                start_seconds=offset,
                delete_local_file=False,
                delete_gemini_file=delete_gemini_file,
            )
            for audio_path, offset in zip(chunk_paths, offset_list, strict=True)
        ]
        results = await asyncio.gather(*tasks)
        results = [r for r in results if r.get("segments")]
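        # merge_transcripts (from asr.groq) is assumed to deduplicate text in the
        # overlapping spans between adjacent chunks before concatenation.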
        transcription = merge_transcripts(sorted(results, key=lambda x: x["segments"][0]["start"]))
        for p in chunk_paths:
            p.unlink(missing_ok=True)
    except Exception as e:
        logger.error(e)
        return {"error": str(e)}
    return transcription


def post_process_transcription(transcriptions: list[dict]) -> list[dict]:
    """Post-process transcription segments.

    Prefixes each segment's content with a per-speaker emoji (no prefix when the
    whole transcript has a single speaker) and normalizes the text via zhcn.

    Args:
        transcriptions: list[dict], each {"speaker": str, "start": int, "end": int, "content": str}

    Returns:
        list[dict], each {"start": int, "end": int, "content": str}
    """
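    # Illustrative mapping (hypothetical data; assumes zhcn leaves the text as-is):
    #     [{"speaker": "Speaker 1", "start": 0, "end": 3, "content": "你好。"},
    #      {"speaker": "Speaker 2", "start": 3, "end": 6, "content": "早上好。"}]
    # ->  [{"start": 0, "end": 3, "content": "🔴你好。"},
    #      {"start": 3, "end": 6, "content": "🔵早上好。"}]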
    emoji_map = {
        "Speaker 1": "🔴",
        "Speaker 2": "🔵",
        "Speaker 3": "🟡",
        "Speaker 4": "🟢",
        "Speaker 5": "🟣",
        "Speaker 6": "🟠",
        "Speaker 7": "🟤",
        "Speaker 8": "⚫",
        "Speaker 9": "⚪",
    }

    speakers = {seg["speaker"] for seg in transcriptions}
    single_speaker = len(speakers) == 1
    for seg in transcriptions:
        seg["speaker"] = "" if single_speaker else emoji_map.get(seg["speaker"], f"{seg['speaker']}: ")

    return [
        {
            "start": seg["start"],
            "end": seg["end"],
            "content": seg["speaker"] + zhcn(seg["content"]),
        }
        for seg in transcriptions
    ]
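

if __name__ == "__main__":
    # Minimal smoke-test sketch: "sample.wav" is a hypothetical local file, and
    # valid AI.GEMINI_API_KEYS plus network access are assumed.
    print(asyncio.run(gemini_asr("sample.wav")))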