Commit 2bbe68b
Changed files (4)
src/asr/gemini.py
@@ -0,0 +1,299 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+import asyncio
+import contextlib
+import json
+import tempfile
+from pathlib import Path
+
+import soundfile as sf
+from glom import glom
+from google import genai
+from google.genai.types import GenerateContentConfig, GoogleSearch, HttpOptions, ThinkingConfig, Tool, UploadFileConfig, UrlContext
+from loguru import logger
+from pydantic import BaseModel, Field
+from pyrogram.client import Client
+from pyrogram.types import Message
+
+from asr.groq import merge_transcripts
+from asr.utils import GEMINI_AUDIO_EXT, audio_duration, convert_single_channel, downsampe_audio
+from config import ASR, GEMINI
+from llm.gemini import gemini_stream
+from llm.hooks import hook_gemini_httpoptions
+from llm.utils import shuffle_keys
+from messages.progress import modify_progress
+from utils import count_subtitles, guess_mime, seconds_to_time, strings_list, zhcn
+
+
+class Transcription(BaseModel):
+ start: int = Field(description="start time in seconds of the sentence in the audio")
+ sentence: str = Field(description="transcription sentence with punctuation")
+ end: int = Field(description="end time in seconds of the sentence in the audio")
+
+
+async def gemini_asr(
+ message: Message,
+ path: str | Path,
+ model_id: str = "",
+ prompt: str = "请转录这段音频",
+ *,
+ delete_gemini_file: bool = True,
+) -> dict:
+ """Gemini stream ASR.
+
+ https://ai.google.dev/gemini-api/docs/audio
+
+ Args:
+        delete_gemini_file (bool, optional): If True, delete the uploaded file from Gemini once transcription finishes.
+ """
+ path = Path(path).expanduser().resolve()
+ if not path.is_file():
+ return {"texts": "", "error": "File not found."}
+ audio_path = path if path.suffix.lower() in GEMINI_AUDIO_EXT else await downsampe_audio(path, ext="ogg", codec="libopus")
+ audio_path = await convert_single_channel(audio_path)
+ duration = audio_duration(audio_path)
+ if duration < ASR.GEMINI_CHUNK_SECONDS:
+ return await gemini_single_file(message, audio_path, model_id=model_id, prompt=prompt, delete_gemini_file=delete_gemini_file)
+ return await gemini_file_chunks(message, audio_path, model_id=model_id, prompt=prompt, delete_gemini_file=delete_gemini_file)
+
+
+async def gemini_single_file(
+ message: Message,
+ path: str | Path,
+ model_id: str = "",
+ prompt: str = "",
+ *,
+ start_seconds: int = 0,
+ delete_local_file: bool = False,
+ delete_gemini_file: bool = True,
+) -> dict:
+ """Gemini stream ASR.
+
+ https://ai.google.dev/gemini-api/docs/audio
+
+ Returns:
+ {"texts": str, "raw_texts": str, "segments": list[dict]}
+ """
+ path = Path(path).expanduser().resolve()
+ if not path.is_file():
+ return {"texts": "", "raw_texts": "", "segments": [], "error": "File not found."}
+ res = {}
+ if not model_id:
+ model_id = GEMINI.ASR_MODEL
+ for api_key in strings_list(GEMINI.API_KEY, shuffle=True):
+ try:
+ logger.debug(f"ASR via {model_id}: {path.as_posix()} , proxy={GEMINI.PROXY}")
+ http_options = HttpOptions(base_url=GEMINI.BASE_URL, async_client_args={"proxy": GEMINI.PROXY})
+ http_options = hook_gemini_httpoptions(http_options, message)
+ app = genai.Client(api_key=api_key, http_options=http_options)
+ uploaded_audio = await app.aio.files.upload(file=path, config=UploadFileConfig(mime_type=guess_mime(path)))
+ genconfig = {}
+ with contextlib.suppress(Exception):
+ genconfig = json.loads(GEMINI.ASR_CONFIG)
+ genconfig |= {"response_mime_type": "application/json", "response_schema": list[Transcription]}
+ if GEMINI.ASR_THINKING_BUDGET is not None:
+ thinking_budget = min(round(float(GEMINI.ASR_THINKING_BUDGET)), GEMINI.MAX_THINKING_BUDGET)
+ genconfig |= {"thinking_config": ThinkingConfig(include_thoughts=False, thinking_budget=thinking_budget)}
+ contents = [prompt, uploaded_audio] if prompt else [uploaded_audio]
+ params = {"model": model_id, "contents": contents, "config": GenerateContentConfig(**genconfig)}
+ answers = "" # all model responses
+ async for chunk in await app.aio.models.generate_content_stream(**params):
+ text = glom(chunk.model_dump(), "candidates.0.content.parts.0.text", default="") or ""
+ logger.trace(f"{text!r}")
+ answers += text
+ try:
+ transcriptions = json.loads(answers)
+ except json.JSONDecodeError as e:
+ logger.error(f"{e}\n{answers}")
+ continue
+ res["segments"] = [
+ {
+ "start": start_seconds + seg["start"],
+ "end": start_seconds + seg["end"],
+ "text": zhcn(seg["sentence"]),
+ }
+ for seg in transcriptions
+ ]
+ res["raw_texts"] = " ".join(x["text"] for x in res["segments"])
+ res["texts"] = "\n".join(f"[{seconds_to_time(x['start'])}] {x['text'].lstrip()}" for x in res["segments"]) # with timestamp
+ break
+ except Exception as e:
+ logger.error(e)
+ finally:
+ if delete_local_file:
+ path.unlink(missing_ok=True)
+ with contextlib.suppress(Exception):
+ if "uploaded_audio" in locals() and uploaded_audio.name:
+ if delete_gemini_file:
+ await app.aio.files.delete(name=uploaded_audio.name)
+ else:
+ res["gemini_file"] = uploaded_audio
+ return res
+
+
+async def gemini_file_chunks(
+ message: Message,
+ path: str | Path,
+ chunk_seconds: float = ASR.GEMINI_CHUNK_SECONDS,
+ overlap_seconds: float = ASR.GEMINI_OVERLAP_SECONDS,
+ model_id: str = "",
+ prompt: str = "",
+ *,
+ delete_gemini_file: bool = True,
+) -> dict:
+ """Transcribe audio in chunks with overlap.
+
+ Args:
+ path: Path to audio file
+ chunk_seconds: Length of each chunk in seconds
+ overlap_seconds: Overlap between chunks in seconds
+
+ Returns:
+ dict: {"texts": str, "raw_texts": str, "segments": list[dict]}
+ """
+ path = Path(path).expanduser().resolve()
+ with sf.SoundFile(path, "r") as f:
+ sr = f.samplerate
+ audio = f.read(dtype="float32")
+ duration_seconds = len(audio) / sr
+ logger.trace(f"音频时长: {duration_seconds:.2f}s, 采样率: {sr} Hz")
+
+ transcription = {}
+ try:
+ # Calculate # of chunks
+ total_chunks = (duration_seconds // (chunk_seconds - overlap_seconds)) + 1
+ total_chunks = int(total_chunks)
+ tasks = []
+ # Loop through each chunk, extract current chunk from audio, transcribe
+ for i in range(total_chunks):
+ start = int(i * (chunk_seconds - overlap_seconds) * sr)
+ end = int(min(start + chunk_seconds * sr, duration_seconds * sr))
+ logger.trace(f"Processing chunk {i + 1}/{total_chunks}, Time range: {start / sr:.0f}s - {end / sr:.0f}s")
+ chunk = audio[start:end]
+ if chunk.shape[0] == 0: # empty chunk
+ continue
+ with tempfile.NamedTemporaryFile(suffix=".ogg", delete=False) as chunk_file:
+ sf.write(chunk_file.name, chunk, sr, format="ogg", subtype="OPUS")
+ tasks.append(
+ gemini_single_file(
+ message,
+ chunk_file.name,
+ model_id=model_id,
+ prompt=prompt,
+ start_seconds=int(start / sr),
+ delete_local_file=True,
+ delete_gemini_file=delete_gemini_file,
+ )
+ )
+ results = await asyncio.gather(*tasks)
+ transcription = merge_transcripts(sorted(results, key=lambda x: x["segments"][0]["start"]))
+ except Exception as e:
+ logger.error(e)
+ return {"error": str(e)}
+ return transcription
+
+
+async def gemini_stream_asr(
+ client: Client,
+ message: Message,
+ path: str | Path,
+ voice_format: str,
+ model_id: str | None = None,
+ prompt: str = "请转录这段音频",
+ *,
+ silent: bool = False,
+ delete_gemini_file: bool = True,
+ **kwargs,
+) -> dict:
+ """(Deprecated) Gemini stream ASR.
+
+ https://ai.google.dev/gemini-api/docs/audio
+
+ Args:
+        silent (bool, optional): If True, do not update the status, return all results in the end.
+ """
+ system_instruction = """You are a transcription assistant tasked with converting audio files into text.
+
+Your output must follow these requirements:
+- Format each sentence as `[hh:mm:ss] sentence` with punctuation included, where `hh:mm:ss` is the start time of the sentence in the audio.
+- Omit the hour (`hh`) if it is zero, displaying only `mm:ss`.
+- Directly transcribe the audio content without any greetings or content unrelated to the audio itself.
+
+Steps:
+1. Listen to the audio file carefully and identify the start time of each sentence.
+2. Transcribe the audio content word-for-word, including punctuation, according to the specified format.
+3. Ensure that all time codes (hh:mm:ss or mm:ss) are precise.
+
+Output Format:
+- Each sentence should be formatted in a line as `[hh:mm:ss] sentence`.
+- Exclude any hour segment that equals zero, converting `[00:mm:ss]` to `[mm:ss]`.
+- Do not include any additional commentary or greetings.
+
+Example-1:
+- Input: Audio with content starting at 2 seconds.
+- Output: [00:02] 大家好, 我是小明, 欢迎来到我的频道。
+
+Example-2:
+- Input: Audio with content at 8 seconds and 1 hour, 12 minutes, and 32 seconds.
+- Output: [00:08] 今天要和大家聊一个一直以来都很有争议的话题。
+[01:12:32] 谢谢大家收听。
+
+
+Notes:
+- Focus on accuracy in capturing both the timing and the spoken content.
+- Maintain consistent formatting to ensure clarity and readability."""
+ path = Path(path)
+ res = {}
+ sent_messages = []
+ status = None if silent else kwargs.get("progress")
+ api_keys = shuffle_keys(GEMINI.API_KEY)
+ if model_id is None:
+ model_id = GEMINI.ASR_MODEL
+ for api_key in api_keys.split(","):
+ try:
+ logger.debug(f"ASR via {model_id}: {path.as_posix()} , proxy={GEMINI.PROXY}")
+ http_options = HttpOptions(base_url=GEMINI.BASE_URL, async_client_args={"proxy": GEMINI.PROXY})
+ http_options = hook_gemini_httpoptions(http_options, message)
+ app = genai.Client(api_key=api_key, http_options=http_options)
+ uploaded_audio = await app.aio.files.upload(file=path, config=UploadFileConfig(mime_type=f"audio/{voice_format}"))
+ genconfig = {}
+ with contextlib.suppress(Exception):
+ genconfig = json.loads(GEMINI.ASR_CONFIG)
+ genconfig |= {"response_modalities": ["TEXT"]} # force text response
+ genconfig |= {"system_instruction": system_instruction} # pin system instruction
+ if GEMINI.ASR_THINKING_BUDGET is not None:
+ thinking_budget = min(round(float(GEMINI.ASR_THINKING_BUDGET)), GEMINI.MAX_THINKING_BUDGET)
+ genconfig |= {"thinking_config": ThinkingConfig(include_thoughts=False, thinking_budget=thinking_budget)}
+ if GEMINI.ASR_USE_GROUNDING:
+ genconfig |= {"tools": [Tool(url_context=UrlContext()), Tool(google_search=GoogleSearch())]}
+ contents = [prompt, uploaded_audio] if prompt else [uploaded_audio]
+ params = {"model": model_id, "contents": contents, "config": GenerateContentConfig(**genconfig)}
+ res = await gemini_stream(
+ client,
+ message,
+ model_name="ASR",
+ params=params,
+ prefix="",
+ silent=silent,
+ max_retry=0,
+ gemini_api_key=api_key,
+ append_grounding=False,
+ **kwargs,
+ )
+ if res.get("error") or count_subtitles(res.get("texts", "")) == 0:
+ continue
+ sent_messages = res.get("sent_messages", [])
+ break
+ except Exception as e:
+ logger.error(e)
+ with contextlib.suppress(Exception):
+ [await modify_progress(msg, del_status=True) for msg in res.get("sent_messages", [])]
+ finally:
+ with contextlib.suppress(Exception):
+ if "uploaded_audio" in locals() and uploaded_audio.name:
+ if delete_gemini_file:
+ await app.aio.files.delete(name=uploaded_audio.name)
+ else:
+ res["gemini_file"] = uploaded_audio
+ res["sent_messages"] = [status, *sent_messages]
+ return res
src/asr/gemini_asr.py
@@ -1,187 +0,0 @@
-#!/usr/bin/env python
-# -*- coding: utf-8 -*-
-import contextlib
-import json
-import random
-from pathlib import Path
-
-from glom import glom
-from google import genai
-from google.genai.types import GenerateContentConfig, GoogleSearch, HttpOptions, ThinkingConfig, Tool, UploadFileConfig, UrlContext
-from loguru import logger
-from pydantic import BaseModel
-from pyrogram.client import Client
-from pyrogram.types import Message
-
-from config import GEMINI
-from llm.gemini import gemini_stream
-from llm.hooks import hook_gemini_httpoptions
-from llm.utils import shuffle_keys
-from messages.progress import modify_progress
-from utils import count_subtitles
-
-
-async def gemini_stream_asr(
- client: Client,
- message: Message,
- path: str | Path,
- voice_format: str,
- model_id: str | None = None,
- prompt: str = "请转录这段音频",
- *,
- silent: bool = False,
- delete_gemini_file: bool = True,
- **kwargs,
-) -> dict:
- """Gemini stream ASR.
-
- https://ai.google.dev/gemini-api/docs/audio
-
- Args:
- silent (bool, optional): If Ture, do not update the status, return all results in the end.
- """
- system_instruction = """You are a transcription assistant tasked with converting audio files into text.
-
-Your output must follow these requirements:
-- Format each sentence as `[hh:mm:ss] sentence` with punctuation included, where `hh:mm:ss` is the start time of the sentence in the audio.
-- Omit the hour (`hh`) if it is zero, displaying only `mm:ss`.
-- Directly transcribe the audio content without any greetings or content unrelated to the audio itself.
-
-Steps:
-1. Listen to the audio file carefully and identify the start time of each sentence.
-2. Transcribe the audio content word-for-word, including punctuation, according to the specified format.
-3. Ensure that all time codes (hh:mm:ss or mm:ss) are precise.
-
-Output Format:
-- Each sentence should be formatted in a line as `[hh:mm:ss] sentence`.
-- Exclude any hour segment that equals zero, converting `[00:mm:ss]` to `[mm:ss]`.
-- Do not include any additional commentary or greetings.
-
-Example-1:
-- Input: Audio with content starting at 2 seconds.
-- Output: [00:02] 大家好, 我是小明, 欢迎来到我的频道。
-
-Example-2:
-- Input: Audio with content at 8 seconds and 1 hour, 12 minutes, and 32 seconds.
-- Output: [00:08] 今天要和大家聊一个一直以来都很有争议的话题。
-[01:12:32] 谢谢大家收听。
-
-
-Notes:
-- Focus on accuracy in capturing both the timing and the spoken content.
-- Maintain consistent formatting to ensure clarity and readability."""
- path = Path(path)
- res = {}
- sent_messages = []
- status = None if silent else kwargs.get("progress")
- api_keys = shuffle_keys(GEMINI.API_KEY)
- if model_id is None:
- model_id = GEMINI.ASR_MODEL
- for api_key in api_keys.split(","):
- try:
- logger.debug(f"ASR via {model_id}: {path.as_posix()} , proxy={GEMINI.PROXY}")
- http_options = HttpOptions(base_url=GEMINI.BASE_URL, async_client_args={"proxy": GEMINI.PROXY})
- http_options = hook_gemini_httpoptions(http_options, message)
- app = genai.Client(api_key=api_key, http_options=http_options)
- uploaded_audio = await app.aio.files.upload(file=path, config=UploadFileConfig(mime_type=f"audio/{voice_format}"))
- genconfig = {}
- with contextlib.suppress(Exception):
- genconfig = json.loads(GEMINI.ASR_CONFIG)
- genconfig |= {"response_modalities": ["TEXT"]} # force text response
- genconfig |= {"system_instruction": system_instruction} # pin system instruction
- if GEMINI.ASR_THINKING_BUDGET is not None:
- thinking_budget = min(round(float(GEMINI.ASR_THINKING_BUDGET)), GEMINI.MAX_THINKING_BUDGET)
- genconfig |= {"thinking_config": ThinkingConfig(include_thoughts=False, thinking_budget=thinking_budget)}
- if GEMINI.ASR_USE_GROUNDING:
- genconfig |= {"tools": [Tool(url_context=UrlContext()), Tool(google_search=GoogleSearch())]}
- contents = [prompt, uploaded_audio] if prompt else [uploaded_audio]
- params = {"model": model_id, "contents": contents, "config": GenerateContentConfig(**genconfig)}
- res = await gemini_stream(
- client,
- message,
- model_name="ASR",
- params=params,
- prefix="",
- silent=silent,
- max_retry=0,
- gemini_api_key=api_key,
- append_grounding=False,
- **kwargs,
- )
- if res.get("error") or count_subtitles(res.get("texts", "")) == 0:
- continue
- sent_messages = res.get("sent_messages", [])
- break
- except Exception as e:
- logger.error(e)
- with contextlib.suppress(Exception):
- [await modify_progress(msg, del_status=True) for msg in res.get("sent_messages", [])]
- finally:
- with contextlib.suppress(Exception):
- if "uploaded_audio" in locals() and uploaded_audio.name:
- if delete_gemini_file:
- await app.aio.files.delete(name=uploaded_audio.name)
- else:
- res["gemini_file"] = uploaded_audio
- res["sent_messages"] = [status, *sent_messages]
- return res
-
-
-class Transcription(BaseModel):
- start_minute: int
- start_second: int
- sentence_with_punctuation: str
-
-
-def generate_transcription(items: list[dict]) -> str:
- res = ""
- show_timestamp = False
- for idx, item in enumerate(items):
- sentence: str = item["sentence_with_punctuation"]
- if not sentence:
- continue
-
- if idx == 0 or res.endswith((".", "。")):
- show_timestamp = True
- if show_timestamp:
- res += f"\n[{item['start_minute']}:{item['start_second']:02d}] {sentence}"
- else:
- res += sentence
- return res.strip()
-
-
-async def gemini_nonstream_asr(path: str | Path, voice_format: str, *, prompt: str = "请转录这段音频") -> str:
- """(Deprecated) Gemini ASR.
-
- This function is deprecated and will be removed in the future.
- Use `gemini_stream_asr` instead.
-
- https://ai.google.dev/gemini-api/docs/audio
- """
- path = Path(path)
- api_keys = [x.strip() for x in GEMINI.API_KEY.split(",") if x.strip()]
- random.shuffle(api_keys)
- res = ""
- for key in api_keys:
- try:
- logger.debug(f"ASR via {GEMINI.ASR_MODEL}: {path.as_posix()} , proxy={GEMINI.PROXY}")
- client = genai.Client(api_key=key, http_options=HttpOptions(base_url=GEMINI.BASE_URL, async_client_args={"proxy": GEMINI.PROXY}))
- uploaded_audio = await client.aio.files.upload(file=path, config=UploadFileConfig(mime_type=f"audio/{voice_format}"))
- logger.debug(uploaded_audio)
- contents = [prompt, uploaded_audio] if prompt else [uploaded_audio]
- response = await client.aio.models.generate_content(
- model=GEMINI.ASR_MODEL,
- contents=contents, # type: ignore
- config=GenerateContentConfig(
- response_mime_type="application/json",
- response_schema=list[Transcription],
- ),
- )
- if uploaded_audio.name: # delete file once finished
- client.files.delete(name=uploaded_audio.name)
- if parsed := glom(response.model_dump(), "parsed"):
- return generate_transcription(parsed)
- except Exception as e:
- logger.error(e)
- res = str(e)
- return res
src/asr/voice_recognition.py
@@ -13,7 +13,7 @@ from pyrogram.types import Message
from asr.ali_asr import ali_asr
from asr.cloudflare import cloudflare_asr
from asr.deepgram import deepgram_asr
-from asr.gemini_asr import gemini_stream_asr
+from asr.gemini import gemini_asr
from asr.groq import groq_asr
from asr.tecent_asr import tencent_async_asr, tencent_flash_asr, tencent_single_asr
from asr.utils import get_asr_method
@@ -212,8 +212,6 @@ async def asr_file(
if asr_method.startswith("tencent") and voice_format in ogg_names:
voice_format = "ogg-opus"
path = path.rename(path.with_stem(rand_string())) # sanitize filename. (for Tencent Signature v3)
- if asr_method == "gemini" and voice_format in ogg_names:
- voice_format = "ogg"
logger.debug(f"[{asr_method}] Recognizing {voice_format} audio [{duration}s] by {language}: {path.as_posix()}")
try:
@@ -229,7 +227,13 @@ async def asr_file(
elif asr_method == "deepgram":
res = await deepgram_asr(path)
elif asr_method == "gemini":
- res = await gemini_stream_asr(path=path, voice_format=voice_format, delete_gemini_file=delete_gemini_file, **kwargs)
+ res = await gemini_asr(
+ message=kwargs["message"],
+ path=path,
+ model_id=kwargs.get("gemini_asr_model_id", ""),
+ prompt=kwargs.get("gemini_asr_prompt", ""),
+ delete_gemini_file=delete_gemini_file,
+ )
elif asr_method == "cloudflare":
res = await cloudflare_asr(path=path, model=kwargs.get("cf_asr_model", ""), prompt=kwargs.get("cf_asr_prompt", ""))
elif asr_method == "groq":
src/config.py
@@ -258,6 +258,8 @@ class ASR:
DEEPGRAM_API = os.getenv("ASR_DEEPGRAM_API", "") # comma separated keys for load balance. e.g. "key1,key2,key3"
CLOUDFLARE_MODEL = os.getenv("ASR_CLOUDFLARE_MODEL", "@cf/openai/whisper-large-v3-turbo")
CLOUDFLARE_KEYS = os.getenv("ASR_CLOUDFLARE_KEYS", "") # comma separated keys for load balance. e.g. "AccountID:API_TOKEN, AccountID:API_TOKEN, ..."
+ GEMINI_CHUNK_SECONDS = float(os.getenv("ASR_GEMINI_CHUNK_SECONDS", "900")) # split long audio file into chunks
+ GEMINI_OVERLAP_SECONDS = float(os.getenv("ASR_GEMINI_OVERLAP_SECONDS", "10")) # overlap seconds between chunks
GROQ_PROXY = os.getenv("ASR_GROQ_PROXY", None) # Ban CN & HK IP
GROQ_MAX_BYTES = int(os.getenv("ASR_GROQ_MAX_BYTES", "26214400")) # 25MB (max file bytes for single file)
GROQ_CHUNK_SECONDS = float(os.getenv("ASR_GROQ_CHUNK_SECONDS", "1800")) # split long audio file into chunks