Commit d4feaa6
Changed files (3)
src
preview
src/asr/gemini_asr.py
@@ -1,5 +1,6 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
+import io
import random
from pathlib import Path
@@ -8,8 +9,15 @@ from google import genai
from google.genai.types import GenerateContentConfig, HttpOptions, UploadFileConfig
from loguru import logger
from pydantic import BaseModel
+from pyrogram.client import Client
+from pyrogram.parser.markdown import BLOCKQUOTE_EXPANDABLE_DELIM, BLOCKQUOTE_EXPANDABLE_END_DELIM
+from pyrogram.types import Message, ReplyParameters
-from config import ASR
+from config import ASR, TEXT_LENGTH
+from llm.gemini import parse_response
+from llm.utils import beautify_llm_response
+from messages.progress import modify_progress
+from messages.utils import count_without_entities, smart_split
class Transcription(BaseModel):
@@ -39,7 +47,6 @@ async def gemini_asr(path: str | Path, voice_format: str) -> str:
config=GenerateContentConfig(
response_mime_type="application/json",
response_schema=list[Transcription],
- temperature=0,
),
)
if parsed := glom(response.model_dump(), "parsed"):
@@ -65,3 +72,69 @@ def generate_transcription(items: list[dict]) -> str:
else:
res += sentence
return res.strip()
+
+
async def gemini_stream_asr(client: Client, message: Message, path: str | Path, voice_format: str, *, slient: bool = False, **kwargs) -> dict:
    """Transcribe an audio file with Gemini and stream partial results into Telegram messages.

    https://ai.google.dev/gemini-api/docs/audio

    Args:
        client: Pyrogram client used to send/update Telegram messages.
        message: Triggering message; its chat receives the streamed transcript.
        path: Path to the audio file to transcribe.
        voice_format: Audio format, used to build the upload MIME type (``audio/<format>``).
        slient (bool, optional): If True, do not update the status messages and
            return all results in the end. NOTE(review): name kept misspelled
            ("slient") for backward compatibility with existing callers.
        **kwargs: May carry ``progress`` — an existing status Message to update.

    Returns:
        ``{"texts": <full transcription>}`` when ``slient`` is True, otherwise ``{}``
        (the transcript has already been streamed into Telegram messages).
    """
    prompt = """请转录这段音频, 要求:
    1. 以 `[mm:ss] sentence` 格式输出句子内容, 包括标点符号。其中mm:ss为此句话开始时间的分钟和秒
    2. 请使用简体中文输出
    3. 直接输出音频转录内容, 不要输出任何与音频内容无关的寒暄问候

    输出实例:
    [00:02] 大家好, 我是小明, 欢迎来到我的频道。
    [00:08] 今天要和大家聊一个一直以来都很有争议的话题。
    """

    def wrap(s: str) -> str:
        # Render text as a Telegram expandable blockquote.
        return BLOCKQUOTE_EXPANDABLE_DELIM + s + BLOCKQUOTE_EXPANDABLE_END_DELIM

    path = Path(path)
    api_keys = [x.strip() for x in ASR.GEMINI_API_KEY.split(",") if x.strip()]
    transcriptions = ""  # full transcript accumulated across all chunks
    runtime_texts = ""  # text of the Telegram message currently being updated
    sent_messages = []  # every status message touched, for final cleanup
    if status := kwargs.get("progress"):
        sent_messages.append(status)
    if slient:
        status = None  # suppress live updates; results are returned at the end
    try:
        logger.debug(f"ASR via {ASR.GEMINI_MODEL}: {path.as_posix()} , proxy={ASR.GEMINI_PROXY}")
        app = genai.Client(api_key=random.choice(api_keys), http_options=HttpOptions(base_url=ASR.GEMINI_BASR_URL, async_client_args={"proxy": ASR.GEMINI_PROXY}))
        uploaded_audio = await app.aio.files.upload(file=path, config=UploadFileConfig(mime_type=f"audio/{voice_format}"))
        logger.debug(uploaded_audio)
        async for chunk in await app.aio.models.generate_content_stream(
            model=ASR.GEMINI_MODEL,
            contents=[prompt, uploaded_audio],
        ):
            resp = parse_response(chunk.model_dump())
            sentence = resp.get("texts", "")
            transcriptions += sentence
            runtime_texts += sentence
            runtime_texts = beautify_llm_response(runtime_texts)
            if await count_without_entities(runtime_texts) <= TEXT_LENGTH:
                if len(runtime_texts) > 5:  # start response if sentence is not empty
                    await modify_progress(message=status, text=runtime_texts, detail_progress=True)
            else:  # transcriptions is too long, split it into multiple messages
                parts = await smart_split(runtime_texts)
                await modify_progress(message=status, text=wrap(parts[0]), force_update=True)  # force send the first part
                runtime_texts = parts[-1]  # keep the last part
                if not slient:
                    status = await client.send_message(message.chat.id, runtime_texts)  # the new message
                    sent_messages.append(status)

        # all chunks are processed: flush whatever text remains
        await modify_progress(message=status, text=wrap(beautify_llm_response(runtime_texts)), force_update=True)
        if len(sent_messages) > 1:
            # transcript spans multiple messages: also deliver it as one document
            with io.BytesIO(transcriptions.encode("utf-8")) as f:
                await client.send_document(message.chat.id, f, file_name="语音识别结果.txt", reply_parameters=ReplyParameters(message_id=message.id))
            for msg in sent_messages:  # plain loop: modify_progress is called for its side effect
                await modify_progress(msg, del_status=True)
    except Exception as e:  # best-effort: return whatever was transcribed so far
        logger.exception(e)  # keep the traceback, not just the message
    return {"texts": transcriptions} if slient else {}
src/asr/voice_recognition.py
@@ -12,7 +12,7 @@ from pyrogram.client import Client
from pyrogram.parser.markdown import BLOCKQUOTE_EXPANDABLE_DELIM, BLOCKQUOTE_EXPANDABLE_END_DELIM
from pyrogram.types import Message
-from asr.gemini_asr import gemini_asr
+from asr.gemini_asr import gemini_asr, gemini_stream_asr
from asr.tecent_asr import create_async_asr, flash_asr, query_async_asr, single_sentence_asr
from asr.utils import get_asr_method
from config import CAPTION_LENGTH, FILE_SERVER, PREFIX, TEXT_LENGTH
@@ -120,7 +120,7 @@ async def voice_to_text(
msg = f"[ASR] 收到消息: {trigger_info['mtype']}, 开始识别..."
logger.info(msg)
if kwargs.get("show_progress"):
- res = await send2tg(client, message, texts=msg, **kwargs)
+ res = await send2tg(client, trigger_message, texts=msg, **kwargs)
kwargs["progress"] = res[0]
path: str | Path = await trigger_message.download() # type: ignore
@@ -131,7 +131,7 @@ async def voice_to_text(
await modify_progress(text=msg, force_update=True, **kwargs)
return
- res = await asr_file(path, engine=force_engine, duration=trigger_info["duration"], language=asr_language)
+ res = await asr_file(path, engine=force_engine, duration=trigger_info["duration"], language=asr_language, client=client, message=trigger_message, **kwargs)
if error := res.get("error"):
await modify_progress(text=error, force_update=True, **kwargs)
return
@@ -161,6 +161,9 @@ async def asr_file(
engine: str = "",
duration: int = 0,
language: str = "16k_zh-PY",
+ *,
+ gemini_stream_mode: bool = True,
+ **kwargs,
) -> dict:
"""Get ASR results of an audio file."""
res = {}
@@ -216,6 +219,8 @@ async def asr_file(
else:
texts = glom(result, "Response.Data.ErrorMsg")
res["error"] = texts
+ elif asr_method == "gemini" and gemini_stream_mode:
+ return await gemini_stream_asr(path=path, voice_format=voice_format, **kwargs)
elif asr_method == "gemini":
texts = await gemini_asr(path, voice_format)
res["texts"] = texts
src/preview/ytdlp.py
@@ -240,7 +240,7 @@ async def preview_ytdlp(
append_transcription = False # disable asr transcription
if any(x in info["extractor"] for x in ["youtube", "bilibili"]) and append_transcription and audio_path.is_file():
- asr_res = await asr_file(audio_path, ytdlp_transcription_engine, duration)
+ asr_res = await asr_file(audio_path, ytdlp_transcription_engine, duration, client=client, message=message, slient=True)
if texts := asr_res.get("texts"):
caption = f"{emoji}[{info['author']}]({info['author_url']})\n🕒{create_time}\n📝[{info['title']}]({url})\n字符数: {len(texts)}\n阅读时长: {len(texts) / READING_SPEED:.1f}分钟"
with io.BytesIO(texts.encode("utf-8")) as f: