Commit 2bbe68b

benny-dou <60535774+benny-dou@users.noreply.github.com>
2025-06-30 16:53:17
perf(asr): refactor `gemini_asr` to support chunked audio transcription
1 parent 39e439b
src/asr/gemini.py
@@ -0,0 +1,299 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+import asyncio
+import contextlib
+import json
+import tempfile
+from pathlib import Path
+
+import soundfile as sf
+from glom import glom
+from google import genai
+from google.genai.types import GenerateContentConfig, GoogleSearch, HttpOptions, ThinkingConfig, Tool, UploadFileConfig, UrlContext
+from loguru import logger
+from pydantic import BaseModel, Field
+from pyrogram.client import Client
+from pyrogram.types import Message
+
+from asr.groq import merge_transcripts
+from asr.utils import GEMINI_AUDIO_EXT, audio_duration, convert_single_channel, downsampe_audio
+from config import ASR, GEMINI
+from llm.gemini import gemini_stream
+from llm.hooks import hook_gemini_httpoptions
+from llm.utils import shuffle_keys
+from messages.progress import modify_progress
+from utils import count_subtitles, guess_mime, seconds_to_time, strings_list, zhcn
+
+
class Transcription(BaseModel):
    """Structured-output schema for one transcribed sentence.

    Used as ``response_schema`` (``list[Transcription]``) so Gemini returns
    machine-parseable JSON instead of free-form text.
    """

    start: int = Field(description="start time in seconds of the sentence in the audio")
    sentence: str = Field(description="transcription sentence with punctuation")
    end: int = Field(description="end time in seconds of the sentence in the audio")
+
+
async def gemini_asr(
    message: Message,
    path: str | Path,
    model_id: str = "",
    prompt: str = "请转录这段音频",
    *,
    delete_gemini_file: bool = True,
) -> dict:
    """Transcribe an audio file with Gemini, chunking long audio automatically.

    https://ai.google.dev/gemini-api/docs/audio

    Args:
        message: Telegram message, passed through for HTTP-option hooks.
        path: Local audio file to transcribe.
        model_id: Gemini model id; empty string falls back to ``GEMINI.ASR_MODEL``.
        prompt: Prompt sent alongside the uploaded audio.
        delete_gemini_file: If True, delete the uploaded Gemini file once done.

    Returns:
        dict: ``{"texts": str, "raw_texts": str, "segments": list[dict]}`` on
        success, or a dict containing an ``"error"`` key on failure.
    """
    path = Path(path).expanduser().resolve()
    if not path.is_file():
        return {"texts": "", "error": "File not found."}
    # Re-encode to a Gemini-supported container (mono ogg/opus) when needed.
    audio_path = path if path.suffix.lower() in GEMINI_AUDIO_EXT else await downsampe_audio(path, ext="ogg", codec="libopus")
    audio_path = await convert_single_channel(audio_path)
    duration = audio_duration(audio_path)
    # Short audio fits in one request; long audio is split into overlapping chunks.
    if duration < ASR.GEMINI_CHUNK_SECONDS:
        return await gemini_single_file(message, audio_path, model_id=model_id, prompt=prompt, delete_gemini_file=delete_gemini_file)
    return await gemini_file_chunks(message, audio_path, model_id=model_id, prompt=prompt, delete_gemini_file=delete_gemini_file)
+
+
async def gemini_single_file(
    message: Message,
    path: str | Path,
    model_id: str = "",
    prompt: str = "",
    *,
    start_seconds: int = 0,
    delete_local_file: bool = False,
    delete_gemini_file: bool = True,
) -> dict:
    """Transcribe a single (short) audio file via Gemini structured JSON output.

    Rotates through the configured API keys until one attempt produces valid
    JSON segments.

    https://ai.google.dev/gemini-api/docs/audio

    Args:
        message: Telegram message, passed through for HTTP-option hooks.
        path: Local audio file to upload and transcribe.
        model_id: Gemini model id; empty string falls back to ``GEMINI.ASR_MODEL``.
        prompt: Optional prompt sent alongside the uploaded audio.
        start_seconds: Offset added to every segment timestamp (used by chunked transcription).
        delete_local_file: If True, delete the local file after all attempts finish.
        delete_gemini_file: If True, delete the uploaded Gemini file per attempt;
            otherwise the upload handle is kept in ``res["gemini_file"]``.

    Returns:
        {"texts": str, "raw_texts": str, "segments": list[dict]} on success;
        an empty dict when every API key failed.
    """
    path = Path(path).expanduser().resolve()
    if not path.is_file():
        return {"texts": "", "raw_texts": "", "segments": [], "error": "File not found."}
    res = {}
    if not model_id:
        model_id = GEMINI.ASR_MODEL
    try:
        for api_key in strings_list(GEMINI.API_KEY, shuffle=True):
            uploaded_audio = None  # reset per attempt so cleanup never targets a stale upload
            try:
                logger.debug(f"ASR via {model_id}: {path.as_posix()} , proxy={GEMINI.PROXY}")
                http_options = HttpOptions(base_url=GEMINI.BASE_URL, async_client_args={"proxy": GEMINI.PROXY})
                http_options = hook_gemini_httpoptions(http_options, message)
                app = genai.Client(api_key=api_key, http_options=http_options)
                uploaded_audio = await app.aio.files.upload(file=path, config=UploadFileConfig(mime_type=guess_mime(path)))
                genconfig = {}
                with contextlib.suppress(Exception):
                    genconfig = json.loads(GEMINI.ASR_CONFIG)
                # Force structured JSON output matching the Transcription schema.
                genconfig |= {"response_mime_type": "application/json", "response_schema": list[Transcription]}
                if GEMINI.ASR_THINKING_BUDGET is not None:
                    thinking_budget = min(round(float(GEMINI.ASR_THINKING_BUDGET)), GEMINI.MAX_THINKING_BUDGET)
                    genconfig |= {"thinking_config": ThinkingConfig(include_thoughts=False, thinking_budget=thinking_budget)}
                contents = [prompt, uploaded_audio] if prompt else [uploaded_audio]
                params = {"model": model_id, "contents": contents, "config": GenerateContentConfig(**genconfig)}
                answers = ""  # accumulated streamed model response
                async for chunk in await app.aio.models.generate_content_stream(**params):
                    text = glom(chunk.model_dump(), "candidates.0.content.parts.0.text", default="") or ""
                    logger.trace(f"{text!r}")
                    answers += text
                try:
                    transcriptions = json.loads(answers)
                except json.JSONDecodeError as e:
                    logger.error(f"{e}\n{answers}")
                    continue  # malformed JSON: retry with the next API key
                res["segments"] = [
                    {
                        "start": start_seconds + seg["start"],
                        "end": start_seconds + seg["end"],
                        "text": zhcn(seg["sentence"]),
                    }
                    for seg in transcriptions
                ]
                res["raw_texts"] = " ".join(x["text"] for x in res["segments"])
                res["texts"] = "\n".join(f"[{seconds_to_time(x['start'])}] {x['text'].lstrip()}" for x in res["segments"])  # with timestamp
                break
            except Exception as e:
                logger.error(e)
            finally:
                # Clean up (or retain) the per-attempt Gemini upload, best-effort.
                with contextlib.suppress(Exception):
                    if uploaded_audio is not None and uploaded_audio.name:
                        if delete_gemini_file:
                            await app.aio.files.delete(name=uploaded_audio.name)
                        else:
                            res["gemini_file"] = uploaded_audio
    finally:
        # BUGFIX: the original unlinked the local file inside the per-key `finally`,
        # so a failed first attempt deleted the file and broke every retry with the
        # remaining API keys. Delete it once, after all attempts finish.
        if delete_local_file:
            path.unlink(missing_ok=True)
    return res
+
+
async def gemini_file_chunks(
    message: Message,
    path: str | Path,
    chunk_seconds: float = ASR.GEMINI_CHUNK_SECONDS,
    overlap_seconds: float = ASR.GEMINI_OVERLAP_SECONDS,
    model_id: str = "",
    prompt: str = "",
    *,
    delete_gemini_file: bool = True,
) -> dict:
    """Transcribe long audio in overlapping chunks, concurrently.

    Each chunk is written to a temporary ogg/opus file and transcribed with
    :func:`gemini_single_file` (which deletes the temp file when it finishes);
    the per-chunk results are then merged into one transcript.

    Args:
        message: Telegram message, passed through for HTTP-option hooks.
        path: Path to audio file.
        chunk_seconds: Length of each chunk in seconds.
        overlap_seconds: Overlap between chunks in seconds.
        model_id: Gemini model id; empty string falls back to ``GEMINI.ASR_MODEL``.
        prompt: Optional prompt sent alongside each chunk.
        delete_gemini_file: If True, delete each uploaded Gemini chunk file once done.

    Returns:
        dict: {"texts": str, "raw_texts": str, "segments": list[dict]} or {"error": str}.
    """
    path = Path(path).expanduser().resolve()
    with sf.SoundFile(path, "r") as f:
        sr = f.samplerate
        audio = f.read(dtype="float32")
        duration_seconds = len(audio) / sr
        logger.trace(f"音频时长: {duration_seconds:.2f}s, 采样率: {sr} Hz")

    transcription = {}
    try:
        # Chunks advance by (chunk - overlap) seconds each step.
        step_seconds = chunk_seconds - overlap_seconds
        total_chunks = int(duration_seconds // step_seconds) + 1
        tasks = []
        # Extract each chunk from the decoded audio and schedule its transcription.
        for i in range(total_chunks):
            start = int(i * step_seconds * sr)
            end = int(min(start + chunk_seconds * sr, duration_seconds * sr))
            logger.trace(f"Processing chunk {i + 1}/{total_chunks}, Time range: {start / sr:.0f}s - {end / sr:.0f}s")
            chunk = audio[start:end]
            if chunk.shape[0] == 0:  # empty chunk
                continue
            # delete=False: the file must outlive this `with`; gemini_single_file
            # removes it via delete_local_file=True.
            with tempfile.NamedTemporaryFile(suffix=".ogg", delete=False) as chunk_file:
                sf.write(chunk_file.name, chunk, sr, format="ogg", subtype="OPUS")
                tasks.append(
                    gemini_single_file(
                        message,
                        chunk_file.name,
                        model_id=model_id,
                        prompt=prompt,
                        start_seconds=int(start / sr),
                        delete_local_file=True,
                        delete_gemini_file=delete_gemini_file,
                    )
                )
        results = await asyncio.gather(*tasks)
        # BUGFIX: drop failed/empty chunk results before sorting — the original
        # indexed x["segments"][0] unconditionally, so a single failed chunk
        # raised and discarded every successful chunk via the broad `except`.
        results = [x for x in results if x.get("segments")]
        if not results:
            return {"error": "All chunks failed to transcribe."}
        transcription = merge_transcripts(sorted(results, key=lambda x: x["segments"][0]["start"]))
    except Exception as e:
        logger.error(e)
        return {"error": str(e)}
    return transcription
+
+
async def gemini_stream_asr(
    client: Client,
    message: Message,
    path: str | Path,
    voice_format: str,
    model_id: str | None = None,
    prompt: str = "请转录这段音频",
    *,
    silent: bool = False,
    delete_gemini_file: bool = True,
    **kwargs,
) -> dict:
    """(Deprecated) Gemini stream ASR.

    Streams a timestamped plain-text transcription through `gemini_stream`,
    rotating through the configured API keys until one attempt yields at
    least one subtitle line.

    https://ai.google.dev/gemini-api/docs/audio

    Args:
        client: Pyrogram client, forwarded to `gemini_stream`.
        message: The Telegram message being processed.
        path: Local audio file to upload.
        voice_format: Container/codec name used to build the MIME type ``audio/{voice_format}``.
        model_id: Gemini model id; ``None`` falls back to ``GEMINI.ASR_MODEL``.
        prompt: Prompt sent alongside the uploaded audio.
        silent (bool, optional): If True, do not update the status, return all results in the end.
        delete_gemini_file: If True, delete the uploaded Gemini file after each attempt;
            otherwise keep the upload handle in ``res["gemini_file"]``.
        **kwargs: Forwarded to `gemini_stream` (e.g. ``progress``).

    Returns:
        dict: the `gemini_stream` result with ``sent_messages`` prepended by the
        status message (or ``None`` when silent).
    """
    # System prompt pinning the `[hh:mm:ss] sentence` output format.
    system_instruction = """You are a transcription assistant tasked with converting audio files into text.

Your output must follow these requirements:
- Format each sentence as `[hh:mm:ss] sentence` with punctuation included, where `hh:mm:ss` is the start time of the sentence in the audio.
- Omit the hour (`hh`) if it is zero, displaying only `mm:ss`.
- Directly transcribe the audio content without any greetings or content unrelated to the audio itself.

Steps:
1. Listen to the audio file carefully and identify the start time of each sentence.
2. Transcribe the audio content word-for-word, including punctuation, according to the specified format.
3. Ensure that all time codes (hh:mm:ss or mm:ss) are precise.

Output Format:
- Each sentence should be formatted in a line as `[hh:mm:ss] sentence`.
- Exclude any hour segment that equals zero, converting `[00:mm:ss]` to `[mm:ss]`.
- Do not include any additional commentary or greetings.

Example-1:
- Input: Audio with content starting at 2 seconds.
- Output: [00:02] 大家好, 我是小明, 欢迎来到我的频道。

Example-2:
- Input: Audio with content at 8 seconds and 1 hour, 12 minutes, and 32 seconds.
- Output: [00:08] 今天要和大家聊一个一直以来都很有争议的话题。
[01:12:32] 谢谢大家收听。


Notes:
- Focus on accuracy in capturing both the timing and the spoken content.
- Maintain consistent formatting to ensure clarity and readability."""
    path = Path(path)
    res = {}
    sent_messages = []
    # Status message to prepend to sent_messages (None when running silently).
    status = None if silent else kwargs.get("progress")
    api_keys = shuffle_keys(GEMINI.API_KEY)
    if model_id is None:
        model_id = GEMINI.ASR_MODEL
    for api_key in api_keys.split(","):
        try:
            logger.debug(f"ASR via {model_id}: {path.as_posix()} , proxy={GEMINI.PROXY}")
            http_options = HttpOptions(base_url=GEMINI.BASE_URL, async_client_args={"proxy": GEMINI.PROXY})
            http_options = hook_gemini_httpoptions(http_options, message)
            app = genai.Client(api_key=api_key, http_options=http_options)
            uploaded_audio = await app.aio.files.upload(file=path, config=UploadFileConfig(mime_type=f"audio/{voice_format}"))
            genconfig = {}
            with contextlib.suppress(Exception):
                genconfig = json.loads(GEMINI.ASR_CONFIG)
            genconfig |= {"response_modalities": ["TEXT"]}  # force text response
            genconfig |= {"system_instruction": system_instruction}  # pin system instruction
            if GEMINI.ASR_THINKING_BUDGET is not None:
                thinking_budget = min(round(float(GEMINI.ASR_THINKING_BUDGET)), GEMINI.MAX_THINKING_BUDGET)
                genconfig |= {"thinking_config": ThinkingConfig(include_thoughts=False, thinking_budget=thinking_budget)}
            if GEMINI.ASR_USE_GROUNDING:
                genconfig |= {"tools": [Tool(url_context=UrlContext()), Tool(google_search=GoogleSearch())]}
            contents = [prompt, uploaded_audio] if prompt else [uploaded_audio]
            params = {"model": model_id, "contents": contents, "config": GenerateContentConfig(**genconfig)}
            res = await gemini_stream(
                client,
                message,
                model_name="ASR",
                params=params,
                prefix="",
                silent=silent,
                max_retry=0,
                gemini_api_key=api_key,
                append_grounding=False,
                **kwargs,
            )
            # Retry with the next key when the attempt errored or produced no subtitles.
            if res.get("error") or count_subtitles(res.get("texts", "")) == 0:
                continue
            sent_messages = res.get("sent_messages", [])
            break
        except Exception as e:
            logger.error(e)
            # Best-effort: clear the progress marker on any messages already sent.
            with contextlib.suppress(Exception):
                [await modify_progress(msg, del_status=True) for msg in res.get("sent_messages", [])]
        finally:
            # Best-effort cleanup (or retention) of the uploaded Gemini file.
            with contextlib.suppress(Exception):
                if "uploaded_audio" in locals() and uploaded_audio.name:
                    if delete_gemini_file:
                        await app.aio.files.delete(name=uploaded_audio.name)
                    else:
                        res["gemini_file"] = uploaded_audio
    res["sent_messages"] = [status, *sent_messages]
    return res
src/asr/gemini_asr.py
@@ -1,187 +0,0 @@
-#!/usr/bin/env python
-# -*- coding: utf-8 -*-
-import contextlib
-import json
-import random
-from pathlib import Path
-
-from glom import glom
-from google import genai
-from google.genai.types import GenerateContentConfig, GoogleSearch, HttpOptions, ThinkingConfig, Tool, UploadFileConfig, UrlContext
-from loguru import logger
-from pydantic import BaseModel
-from pyrogram.client import Client
-from pyrogram.types import Message
-
-from config import GEMINI
-from llm.gemini import gemini_stream
-from llm.hooks import hook_gemini_httpoptions
-from llm.utils import shuffle_keys
-from messages.progress import modify_progress
-from utils import count_subtitles
-
-
-async def gemini_stream_asr(
-    client: Client,
-    message: Message,
-    path: str | Path,
-    voice_format: str,
-    model_id: str | None = None,
-    prompt: str = "请转录这段音频",
-    *,
-    silent: bool = False,
-    delete_gemini_file: bool = True,
-    **kwargs,
-) -> dict:
-    """Gemini stream ASR.
-
-    https://ai.google.dev/gemini-api/docs/audio
-
-    Args:
-        silent (bool, optional): If Ture, do not update the status, return all results in the end.
-    """
-    system_instruction = """You are a transcription assistant tasked with converting audio files into text.
-
-Your output must follow these requirements:
-- Format each sentence as `[hh:mm:ss] sentence` with punctuation included, where `hh:mm:ss` is the start time of the sentence in the audio.
-- Omit the hour (`hh`) if it is zero, displaying only `mm:ss`.
-- Directly transcribe the audio content without any greetings or content unrelated to the audio itself.
-
-Steps:
-1. Listen to the audio file carefully and identify the start time of each sentence.
-2. Transcribe the audio content word-for-word, including punctuation, according to the specified format.
-3. Ensure that all time codes (hh:mm:ss or mm:ss) are precise.
-
-Output Format:
-- Each sentence should be formatted in a line as `[hh:mm:ss] sentence`.
-- Exclude any hour segment that equals zero, converting `[00:mm:ss]` to `[mm:ss]`.
-- Do not include any additional commentary or greetings.
-
-Example-1:
-- Input: Audio with content starting at 2 seconds.
-- Output: [00:02] 大家好, 我是小明, 欢迎来到我的频道。
-
-Example-2:
-- Input: Audio with content at 8 seconds and 1 hour, 12 minutes, and 32 seconds.
-- Output: [00:08] 今天要和大家聊一个一直以来都很有争议的话题。
-[01:12:32] 谢谢大家收听。
-
-
-Notes:
-- Focus on accuracy in capturing both the timing and the spoken content.
-- Maintain consistent formatting to ensure clarity and readability."""
-    path = Path(path)
-    res = {}
-    sent_messages = []
-    status = None if silent else kwargs.get("progress")
-    api_keys = shuffle_keys(GEMINI.API_KEY)
-    if model_id is None:
-        model_id = GEMINI.ASR_MODEL
-    for api_key in api_keys.split(","):
-        try:
-            logger.debug(f"ASR via {model_id}: {path.as_posix()} , proxy={GEMINI.PROXY}")
-            http_options = HttpOptions(base_url=GEMINI.BASE_URL, async_client_args={"proxy": GEMINI.PROXY})
-            http_options = hook_gemini_httpoptions(http_options, message)
-            app = genai.Client(api_key=api_key, http_options=http_options)
-            uploaded_audio = await app.aio.files.upload(file=path, config=UploadFileConfig(mime_type=f"audio/{voice_format}"))
-            genconfig = {}
-            with contextlib.suppress(Exception):
-                genconfig = json.loads(GEMINI.ASR_CONFIG)
-            genconfig |= {"response_modalities": ["TEXT"]}  # force text response
-            genconfig |= {"system_instruction": system_instruction}  # pin system instruction
-            if GEMINI.ASR_THINKING_BUDGET is not None:
-                thinking_budget = min(round(float(GEMINI.ASR_THINKING_BUDGET)), GEMINI.MAX_THINKING_BUDGET)
-                genconfig |= {"thinking_config": ThinkingConfig(include_thoughts=False, thinking_budget=thinking_budget)}
-            if GEMINI.ASR_USE_GROUNDING:
-                genconfig |= {"tools": [Tool(url_context=UrlContext()), Tool(google_search=GoogleSearch())]}
-            contents = [prompt, uploaded_audio] if prompt else [uploaded_audio]
-            params = {"model": model_id, "contents": contents, "config": GenerateContentConfig(**genconfig)}
-            res = await gemini_stream(
-                client,
-                message,
-                model_name="ASR",
-                params=params,
-                prefix="",
-                silent=silent,
-                max_retry=0,
-                gemini_api_key=api_key,
-                append_grounding=False,
-                **kwargs,
-            )
-            if res.get("error") or count_subtitles(res.get("texts", "")) == 0:
-                continue
-            sent_messages = res.get("sent_messages", [])
-            break
-        except Exception as e:
-            logger.error(e)
-            with contextlib.suppress(Exception):
-                [await modify_progress(msg, del_status=True) for msg in res.get("sent_messages", [])]
-        finally:
-            with contextlib.suppress(Exception):
-                if "uploaded_audio" in locals() and uploaded_audio.name:
-                    if delete_gemini_file:
-                        await app.aio.files.delete(name=uploaded_audio.name)
-                    else:
-                        res["gemini_file"] = uploaded_audio
-    res["sent_messages"] = [status, *sent_messages]
-    return res
-
-
-class Transcription(BaseModel):
-    start_minute: int
-    start_second: int
-    sentence_with_punctuation: str
-
-
-def generate_transcription(items: list[dict]) -> str:
-    res = ""
-    show_timestamp = False
-    for idx, item in enumerate(items):
-        sentence: str = item["sentence_with_punctuation"]
-        if not sentence:
-            continue
-
-        if idx == 0 or res.endswith((".", "。")):
-            show_timestamp = True
-        if show_timestamp:
-            res += f"\n[{item['start_minute']}:{item['start_second']:02d}] {sentence}"
-        else:
-            res += sentence
-    return res.strip()
-
-
-async def gemini_nonstream_asr(path: str | Path, voice_format: str, *, prompt: str = "请转录这段音频") -> str:
-    """(Deprecated) Gemini ASR.
-
-    This function is deprecated and will be removed in the future.
-    Use `gemini_stream_asr` instead.
-
-    https://ai.google.dev/gemini-api/docs/audio
-    """
-    path = Path(path)
-    api_keys = [x.strip() for x in GEMINI.API_KEY.split(",") if x.strip()]
-    random.shuffle(api_keys)
-    res = ""
-    for key in api_keys:
-        try:
-            logger.debug(f"ASR via {GEMINI.ASR_MODEL}: {path.as_posix()} , proxy={GEMINI.PROXY}")
-            client = genai.Client(api_key=key, http_options=HttpOptions(base_url=GEMINI.BASE_URL, async_client_args={"proxy": GEMINI.PROXY}))
-            uploaded_audio = await client.aio.files.upload(file=path, config=UploadFileConfig(mime_type=f"audio/{voice_format}"))
-            logger.debug(uploaded_audio)
-            contents = [prompt, uploaded_audio] if prompt else [uploaded_audio]
-            response = await client.aio.models.generate_content(
-                model=GEMINI.ASR_MODEL,
-                contents=contents,  # type: ignore
-                config=GenerateContentConfig(
-                    response_mime_type="application/json",
-                    response_schema=list[Transcription],
-                ),
-            )
-            if uploaded_audio.name:  # delete file once finished
-                client.files.delete(name=uploaded_audio.name)
-            if parsed := glom(response.model_dump(), "parsed"):
-                return generate_transcription(parsed)
-        except Exception as e:
-            logger.error(e)
-            res = str(e)
-    return res
src/asr/voice_recognition.py
@@ -13,7 +13,7 @@ from pyrogram.types import Message
 from asr.ali_asr import ali_asr
 from asr.cloudflare import cloudflare_asr
 from asr.deepgram import deepgram_asr
-from asr.gemini_asr import gemini_stream_asr
+from asr.gemini import gemini_asr
 from asr.groq import groq_asr
 from asr.tecent_asr import tencent_async_asr, tencent_flash_asr, tencent_single_asr
 from asr.utils import get_asr_method
@@ -212,8 +212,6 @@ async def asr_file(
     if asr_method.startswith("tencent") and voice_format in ogg_names:
         voice_format = "ogg-opus"
         path = path.rename(path.with_stem(rand_string()))  # sanitize filename. (for Tencent Signature v3)
-    if asr_method == "gemini" and voice_format in ogg_names:
-        voice_format = "ogg"
 
     logger.debug(f"[{asr_method}] Recognizing {voice_format} audio [{duration}s] by {language}: {path.as_posix()}")
     try:
@@ -229,7 +227,13 @@ async def asr_file(
         elif asr_method == "deepgram":
             res = await deepgram_asr(path)
         elif asr_method == "gemini":
-            res = await gemini_stream_asr(path=path, voice_format=voice_format, delete_gemini_file=delete_gemini_file, **kwargs)
+            res = await gemini_asr(
+                message=kwargs["message"],
+                path=path,
+                model_id=kwargs.get("gemini_asr_model_id", ""),
+                prompt=kwargs.get("gemini_asr_prompt", ""),
+                delete_gemini_file=delete_gemini_file,
+            )
         elif asr_method == "cloudflare":
             res = await cloudflare_asr(path=path, model=kwargs.get("cf_asr_model", ""), prompt=kwargs.get("cf_asr_prompt", ""))
         elif asr_method == "groq":
src/config.py
@@ -258,6 +258,8 @@ class ASR:
     DEEPGRAM_API = os.getenv("ASR_DEEPGRAM_API", "")  # comma separated keys for load balance. e.g. "key1,key2,key3"
     CLOUDFLARE_MODEL = os.getenv("ASR_CLOUDFLARE_MODEL", "@cf/openai/whisper-large-v3-turbo")
     CLOUDFLARE_KEYS = os.getenv("ASR_CLOUDFLARE_KEYS", "")  # comma separated keys for load balance. e.g. "AccountID:API_TOKEN, AccountID:API_TOKEN, ..."
+    GEMINI_CHUNK_SECONDS = float(os.getenv("ASR_GEMINI_CHUNK_SECONDS", "900"))  # split long audio file into chunks
+    GEMINI_OVERLAP_SECONDS = float(os.getenv("ASR_GEMINI_OVERLAP_SECONDS", "10"))  # overlap seconds between chunks
     GROQ_PROXY = os.getenv("ASR_GROQ_PROXY", None)  # Ban CN & HK IP
     GROQ_MAX_BYTES = int(os.getenv("ASR_GROQ_MAX_BYTES", "26214400"))  # 25MB (max file bytes for single file)
     GROQ_CHUNK_SECONDS = float(os.getenv("ASR_GROQ_CHUNK_SECONDS", "1800"))  # split long audio file into chunks