main
  1#!/usr/bin/env python
  2# -*- coding: utf-8 -*-
  3import asyncio
  4import contextlib
  5import json
  6from pathlib import Path
  7
  8from glom import glom
  9from google import genai
 10from google.genai.types import File, GenerateContentConfig, HttpOptions, UploadFileConfig
 11from loguru import logger
 12
 13from ai.utils import literal_eval
 14from asr.groq import merge_transcripts
 15from asr.utils import GEMINI_AUDIO_EXT, audio_chunk_to_path, audio_duration, convert_single_channel, downsampe_audio, load_audio
 16from config import AI, ASR, DOWNLOAD_DIR, PROXY
 17from utils import guess_mime, rand_string, seconds_to_time, strings_list, zhcn
 18
# System instruction sent with every Gemini transcription request (used by gemini_single_file).
# In Chinese, it asks the model to:
#   1. listen to the whole clip and write a concise Chinese summary;
#   2. diarize speakers ("Speaker 1", "Speaker 2", or inferred roles) with consistent labels;
#   3. transcribe sentence by sentence with punctuation, giving start/end times in seconds
#      (start < end, avoiding overlaps);
#   4. ignore background music/noise, filler words and non-speech sounds;
# and to answer with JSON matching the response schema.
# NOTE(review): step 1 requests a summary, but the response_schema used by gemini_single_file
# only describes an array of segments, so there is no field for the summary — confirm intent.
SYSTEM_PROMPT = """你是一个专业的音频分析与转录助手。你的任务是处理音频输入,并输出结构化的转录结果。

请严格遵循以下步骤进行处理:

1. **全局理解与总结**:
   - 首先听取整个音频片段。
   - 用中文生成一段简明扼要的摘要,概括对话的核心主题和结论。

2. **说话人区分 (Diarization)**:
   - 准确区分不同的说话人。使用 "Speaker 1", "Speaker 2" 或根据内容推断的角色名(如 "Interviewer", "Guest")。
   - 保持同一说话人在整个对话中标签的一致性。

3. **高精度转录**:
   - 逐句转录内容,包含正确的标点符号。
   - 时间戳:
     - 每个句子的开始时间(start)和结束时间(end),单位为秒。
     - 确保 start < end,且不同句子的时间戳尽量不要重叠。

4. **噪音处理**:
   - 忽略背景音乐、背景噪音、填充词(如 uh, um)和非语言声音,只保留有意义的对话内容。

请输出符合 Schema 的 JSON 输出。
"""  # noqa: RUF001
 42
 43
 44async def gemini_asr(path: str | Path, prompt: str = "", *, delete_gemini_file: bool = True) -> dict:
 45    """Gemini stream ASR.
 46
 47    https://ai.google.dev/gemini-api/docs/audio
 48
 49    Args:
 50        silent (bool, optional): If Ture, do not update the status, return all results in the end.
 51    """
 52    path = Path(path).expanduser().resolve()
 53    if not path.is_file():
 54        return {"texts": "", "error": "File not found."}
 55    prompt = prompt or "请转录这段音频"
 56    audio_path = path if path.suffix.lower() in GEMINI_AUDIO_EXT else await downsampe_audio(path, ext="wav", codec="pcm_s16le")
 57    audio_path = await convert_single_channel(audio_path, ext="wav", codec="pcm_s16le")
 58    duration = audio_duration(audio_path)
 59    if duration < ASR.GEMINI_CHUNK_SECONDS:
 60        return await gemini_single_file(audio_path, prompt=prompt, delete_gemini_file=delete_gemini_file)
 61    return await gemini_file_chunks(audio_path, prompt=prompt, delete_gemini_file=delete_gemini_file)
 62
 63
 64async def gemini_single_file(
 65    path: str | Path,
 66    prompt: str = "",
 67    *,
 68    start_seconds: int = 0,
 69    delete_local_file: bool = False,
 70    delete_gemini_file: bool = True,
 71) -> dict:
 72    """Gemini stream ASR.
 73
 74    https://ai.google.dev/gemini-api/docs/audio
 75
 76    Returns:
 77        {"texts": str, "raw_texts": str, "segments": list[dict]}
 78    """
 79    path = Path(path).expanduser().resolve()
 80    if not path.is_file():
 81        return {"texts": "", "raw_texts": "", "segments": [], "error": "File not found."}
 82    res = {}
 83    for api_key in strings_list(AI.GEMINI_API_KEYS, shuffle=True):
 84        logger.debug(f"ASR via {ASR.GEMINI_MODEL}: {path.as_posix()} , proxy={PROXY.GOOGLE}")
 85        app = genai.Client(
 86            api_key=api_key,
 87            http_options=HttpOptions(
 88                base_url=AI.GEMINI_BASE_URL,
 89                headers=literal_eval(AI.GEMINI_DEFAULT_HEADERS),
 90                async_client_args={"proxy": PROXY.GOOGLE},
 91            ),
 92        )
 93        uploaded_audio = File()
 94        try:
 95            uploaded_audio = await app.aio.files.upload(file=path, config=UploadFileConfig(mime_type=guess_mime(path)))
 96            genconfig = {}
 97            if ASR.GEMINI_CONFIG:
 98                genconfig |= literal_eval(ASR.GEMINI_CONFIG)
 99            genconfig |= {
100                "system_instruction": SYSTEM_PROMPT,
101                "response_mime_type": "application/json",
102                "response_schema": {
103                    "title": "Audio Transcription",
104                    "type": "array",
105                    "description": "List of transcribed segments with speaker and timestamp.",
106                    "items": {
107                        "type": "object",
108                        "title": "Transcription Segment",
109                        "properties": {
110                            "speaker": {"type": "string", "description": "Speaker label (e.g., 'Speaker 1', 'Interviewer')"},
111                            "start": {"type": "integer", "description": "Start time in seconds of the segment in the audio"},
112                            "end": {"type": "integer", "description": "End time in seconds of the segment in the audio"},
113                            "content": {"type": "string", "description": "Verbatim transcription text with punctuation"},
114                        },
115                        "required": ["speaker", "start", "end", "content"],
116                    },
117                },
118            }
119            contents = [prompt, uploaded_audio] if prompt else [uploaded_audio]
120            params = {"model": ASR.GEMINI_MODEL, "contents": contents, "config": GenerateContentConfig(**genconfig)}
121            answers = ""  # all model responses
122            async for chunk in await app.aio.models.generate_content_stream(**params):
123                text = glom(chunk.model_dump(), "candidates.0.content.parts.0.text", default="") or ""
124                logger.trace(f"{text!r}")
125                answers += text
126            await app.aio.aclose()
127            try:
128                transcriptions = post_process_transcription(json.loads(answers))
129            except json.JSONDecodeError as e:
130                logger.error(f"{e}\n{answers}")
131                continue
132            res["segments"] = [
133                {
134                    "start": start_seconds + seg["start"],
135                    "end": start_seconds + seg["end"],
136                    "text": zhcn(seg["content"]),
137                }
138                for seg in transcriptions
139            ]
140            res["raw_texts"] = " ".join(x["text"] for x in res["segments"])
141            res["texts"] = "\n".join(f"[{seconds_to_time(x['start'])}] {x['text'].lstrip()}" for x in res["segments"])  # with timestamp
142            break
143        except Exception as e:
144            logger.error(e)
145        finally:
146            if delete_local_file:
147                path.unlink(missing_ok=True)
148            with contextlib.suppress(Exception):
149                if uploaded_audio.name:
150                    if delete_gemini_file:
151                        await app.aio.files.delete(name=uploaded_audio.name)
152                    else:
153                        res["gemini_file"] = uploaded_audio
154    return res
155
156
def _chunk_bounds(duration: float, sr: int, chunk_seconds: float, overlap_seconds: float) -> list[tuple[int, int]]:
    """Return ``(start, end)`` sample indices of overlapping chunks covering ``duration`` seconds.

    A new chunk starts every ``chunk_seconds - overlap_seconds`` seconds. Generation stops
    as soon as a chunk reaches the end of the audio, so no empty or fully redundant
    trailing chunk is produced.

    Raises:
        ValueError: If ``chunk_seconds`` is not greater than ``overlap_seconds``.
    """
    step = chunk_seconds - overlap_seconds
    if step <= 0:
        raise ValueError("chunk_seconds must be greater than overlap_seconds")
    total = int(duration * sr)
    bounds: list[tuple[int, int]] = []
    i = 0
    while (start := int(i * step * sr)) < total:
        end = int(min(start + chunk_seconds * sr, duration * sr))
        bounds.append((start, end))
        if end >= total:  # this chunk already reaches the end of the audio
            break
        i += 1
    return bounds


async def gemini_file_chunks(
    path: str | Path,
    chunk_seconds: float = 600,
    overlap_seconds: float = ASR.GEMINI_OVERLAP_SECONDS,
    prompt: str = "",
    *,
    delete_gemini_file: bool = True,
) -> dict:
    """Transcribe audio in overlapping chunks and merge the results.

    Non-WAV input is first converted with ``downsampe_audio``. Each chunk is written
    to a temporary WAV under ``DOWNLOAD_DIR`` and transcribed concurrently with
    ``gemini_single_file``, using the chunk's start (in seconds) as the timestamp offset.
    The temporary chunk files are always removed, even on failure.

    Args:
        path: Path to audio file
        chunk_seconds: Length of each chunk in seconds
        overlap_seconds: Overlap between chunks in seconds
        prompt: Optional user prompt forwarded to every chunk request.
        delete_gemini_file: If True, delete uploaded chunk files from Gemini afterwards.

    Returns:
        dict: ``merge_transcripts`` output for the chunks that produced segments
        ({"texts": str, "raw_texts": str, "segments": list[dict]}), or
        ``{"error": str}`` on failure.
    """
    path = Path(path).expanduser().resolve()
    wav_path = path if path.suffix.lower() == ".wav" else await downsampe_audio(path, ext="wav", codec="pcm_s16le")
    audio, duration, sr = load_audio(wav_path)
    if sr == 0:
        return {"error": "Failed to load audio."}
    chunk_paths: list[Path] = []
    try:
        bounds = _chunk_bounds(duration, sr, chunk_seconds, overlap_seconds)
        writes = []
        offsets: list[int] = []  # kept index-aligned with chunk_paths
        for i, (start, end) in enumerate(bounds):
            logger.trace(f"Processing chunk {i + 1}/{len(bounds)}, Time range: {start / sr:.0f}s - {end / sr:.0f}s")
            chunk = audio[start:end]
            if chunk.shape[0] == 0:  # empty chunk
                continue
            chunk_path = Path(DOWNLOAD_DIR) / f"{rand_string()}.wav"
            chunk_paths.append(chunk_path)
            offsets.append(int(start / sr))
            writes.append(audio_chunk_to_path(chunk, sr, chunk_path))
        await asyncio.gather(*writes)  # convert chunks to paths
        results = await asyncio.gather(
            *(
                gemini_single_file(
                    chunk_path,
                    prompt=prompt,
                    start_seconds=offset,
                    delete_local_file=False,
                    delete_gemini_file=delete_gemini_file,
                )
                for chunk_path, offset in zip(chunk_paths, offsets, strict=True)
            )
        )
        results = [r for r in results if r.get("segments")]
        return merge_transcripts(sorted(results, key=lambda r: r["segments"][0]["start"]))
    except Exception as e:
        logger.error(e)
        return {"error": str(e)}
    finally:
        for chunk_path in chunk_paths:
            chunk_path.unlink(missing_ok=True)
217
218
def post_process_transcription(transcriptions: list[dict]) -> list[dict]:
    """Attach speaker markers to transcribed segments and normalize their text.

    Args:
        transcriptions: list[dict], {"speaker": str, "start":int, "end": int, "content": str}
            as parsed from the model's JSON output. The dicts are not modified.

    Returns:
        list[dict]: New ``{"start", "end", "content"}`` dicts. "content" is passed
        through ``zhcn`` and prefixed with a marker:
        - nothing, when only one distinct speaker label appears;
        - a colored circle, for "Speaker 1" through "Speaker 9";
        - ``"<label>: "``, for any other label.
    """
    emoji_map = {
        "Speaker 1": "🔴",
        "Speaker 2": "🔵",
        "Speaker 3": "🟡",
        "Speaker 4": "🟢",
        "Speaker 5": "🟣",
        "Speaker 6": "🟠",
        "Speaker 7": "🟤",
        # Remaining colored circles so speakers 8 and 9 stay distinguishable.
        "Speaker 8": "⚫",
        "Speaker 9": "⚪",
    }

    speakers = {seg["speaker"] for seg in transcriptions}
    single_speaker = len(speakers) == 1

    def marker(speaker: str) -> str:
        """Return the prefix for ``speaker`` (empty when there is a single speaker)."""
        if single_speaker:
            return ""
        return emoji_map.get(speaker, f"{speaker}: ")

    return [
        {
            "start": seg["start"],
            "end": seg["end"],
            "content": marker(seg["speaker"]) + zhcn(seg["content"]),
        }
        for seg in transcriptions
    ]
250    ]