Commit a46a7bc

benny-dou <60535774+benny-dou@users.noreply.github.com>
2025-04-27 14:06:38
feat(ytdlp): append transcription via ASR
1 parent e753921
Changed files (4)
src/asr/gemini_asr.py
@@ -24,15 +24,12 @@ async def gemini_asr(path: str | Path, voice_format: str) -> str:
     https://ai.google.dev/gemini-api/docs/audio
     """
     path = Path(path)
-    if voice_format == "ogg-opus":
-        voice_format = "ogg"
-
     api_keys = [x.strip() for x in ASR.GEMINI_API_KEY.split(",") if x.strip()]
     random.shuffle(api_keys)
     res = ""
     for key in api_keys:
         try:
-            logger.debug(f"ASR {path.as_posix()} via {ASR.GEMINI_BASR_URL}, proxy={ASR.GEMINI_PROXY}")
+            logger.debug(f"ASR via {ASR.GEMINI_MODEL}: {path.as_posix()} , proxy={ASR.GEMINI_PROXY}")
             client = genai.Client(api_key=key, http_options=HttpOptions(base_url=ASR.GEMINI_BASR_URL, async_client_args={"proxy": ASR.GEMINI_PROXY}))
             uploaded_audio = await client.aio.files.upload(file=path, config=UploadFileConfig(mime_type=f"audio/{voice_format}"))
             logger.debug(uploaded_audio)
src/asr/utils.py
@@ -5,7 +5,7 @@
 from config import ASR, FILE_SERVER
 
 
-def get_asr_method(duration: float, file_size: int) -> tuple[str, list[str]]:
+def get_asr_method(duration: float, file_size: int, force_engine: str = "") -> tuple[str, list[str]]:
     """Get ASR method and supported file types."""
     if duration < 60:
         asr_engine = ASR.SHORT_ENGINE
@@ -14,9 +14,9 @@ def get_asr_method(duration: float, file_size: int) -> tuple[str, list[str]]:
     else:
         asr_engine = ASR.LONG_ENGINE
 
-    if asr_engine.lower() == "tencent":
+    if asr_engine == "tencent" or force_engine == "tencent":
         return get_tencent_asr_method(duration, file_size)
-    if asr_engine.lower() == "gemini":
+    if asr_engine.lower() == "gemini" or force_engine == "gemini":
         return get_gemini_asr_method(duration)
     return f"ASR Engine: {asr_engine} is not support for duration: {duration}, filesize: {file_size}", []
 
src/asr/voice_recognition.py
@@ -9,17 +9,18 @@ from pathlib import Path
 from glom import glom
 from loguru import logger
 from pyrogram.client import Client
+from pyrogram.parser.markdown import BLOCKQUOTE_EXPANDABLE_DELIM, BLOCKQUOTE_EXPANDABLE_END_DELIM
 from pyrogram.types import Message
 
 from asr.gemini_asr import gemini_asr
 from asr.tecent_asr import create_async_asr, flash_asr, query_async_asr, single_sentence_asr
 from asr.utils import get_asr_method
-from config import CAPTION_LENGTH, FILE_SERVER, PREFIX
+from config import CAPTION_LENGTH, FILE_SERVER, PREFIX, TEXT_LENGTH
 from messages.parser import parse_msg
 from messages.progress import modify_progress
 from messages.sender import send2tg
 from messages.utils import count_without_entities, equal_prefix, get_reply_to, startswith_prefix
-from multimedia import convert_to_audio
+from multimedia import convert_to_audio, parse_media_info
 from utils import rand_string, to_int
 
 # ruff: noqa: RUF001
@@ -104,25 +105,23 @@ async def voice_to_text(
     this_info = parse_msg(message, silent=True)
     trigger_info = parse_msg(trigger_message, silent=True)
 
-    asr_engine = "16k_zh-PY"  # default: 中英粤
+    asr_language = "16k_zh-PY"  # default: 中英粤
+    force_engine = ""  # gemini or tencent
     if matched := re.match(r"/asr\s+([^.。,,/\s]+)", this_info["text"]):  # /asr yue
-        asr_engine = f"16k_{matched.group(1)}"
-    asr_engine = asr_engine.replace("16k_fy", "16k_zh_dialect")  # fix dialect engine code
+        custom_code = matched.group(1)
+        if custom_code == "fy":  # re-map dialect
+            custom_code = "zh_dialect"
+        custom_code = custom_code.replace("fy", "zh_dialect")
+        if f"16k_{custom_code}" in ENGINE_MAP:
+            asr_language = f"16k_{custom_code}"
+        elif custom_code in ["gemini", "tencent"]:
+            force_engine = custom_code
 
-    duration = trigger_info["duration"]
-    asr_method, supported_ext = get_asr_method(duration, trigger_info["file_size"])
-    if asr_method not in ["single_sentence_asr", "flash_asr", "async_asr", "gemini"]:
-        await modify_progress(text=asr_method, force_update=True, **kwargs)
-        return
-
-    msg = f"Recieved {trigger_info['mtype']} message, start recognizing by {ENGINE_MAP.get(asr_engine, 'Unknown')}..."
+    msg = f"[ASR] 收到消息: {trigger_info['mtype']}, 开始识别..."
     logger.info(msg)
     if kwargs.get("show_progress"):
         res = await send2tg(client, message, texts=msg, **kwargs)
         kwargs["progress"] = res[0]
-    if asr_method != "gemini" and asr_engine not in ENGINE_MAP:
-        await modify_progress(text=f"Unsupported ASR engine: {asr_engine}", force_update=True, **kwargs)
-        return
 
     path: str | Path = await trigger_message.download()  # type: ignore
     path = Path(path).expanduser().resolve()
@@ -131,23 +130,13 @@ async def voice_to_text(
         logger.error(msg)
         await modify_progress(text=msg, force_update=True, **kwargs)
         return
-    voice_format = path.suffix.lstrip(".")
-    if voice_format not in supported_ext:
-        path = convert_to_audio(path, ext="aac", codec="aac")
-        voice_format = "aac"
-    asr_method, supported_ext = get_asr_method(duration, file_size=Path(path).stat().st_size)  # match again based on converted file
-    path = path.rename(path.with_stem(rand_string()))  # sanitize filename. (for Tencent Signature v3)
-
-    if voice_format in ["oga", "ogg", "opus"]:  # rename format
-        voice_format = "ogg-opus"
 
-    logger.debug(f"Recognizing {voice_format} audio [{duration}s] by {asr_engine}: {path.as_posix()}")
-    res = await asr_file(path, asr_method, voice_format, asr_engine)
+    res = await asr_file(path, engine=force_engine, duration=trigger_info["duration"], language=asr_language)
     if error := res.get("error"):
         await modify_progress(text=error, force_update=True, **kwargs)
         return
     if texts := res.get("texts"):
-        final = f"{BEGINNING}\n{texts}"
+        final = f"{BEGINNING}\n{BLOCKQUOTE_EXPANDABLE_DELIM}{texts}{BLOCKQUOTE_EXPANDABLE_END_DELIM}"
         logger.success(f"{final!r}")
         # send results
         target_chat = kwargs["target_chat"] if kwargs.get("target_chat") else this_info["cid"]
@@ -155,9 +144,10 @@ async def voice_to_text(
         length = await count_without_entities(final)
         if length < CAPTION_LENGTH:  # short
             await client.copy_message(chat_id=to_int(target_chat), from_chat_id=trigger_info["cid"], message_id=trigger_info["mid"], caption=final, reply_parameters=reply_parameters)
+        elif length < TEXT_LENGTH:  # middle
+            await client.send_message(to_int(target_chat), final, reply_parameters=reply_parameters)
         else:  # long
-            final = final.removeprefix(f"{BEGINNING}\n")
-            with io.BytesIO(final.encode("utf-8")) as f:
+            with io.BytesIO(texts.encode("utf-8")) as f:
                 await client.send_document(to_int(target_chat), f, file_name="语音识别结果.txt", reply_parameters=reply_parameters)
         await modify_progress(del_status=True, **kwargs)
 
@@ -166,21 +156,51 @@ async def voice_to_text(
             await message.delete()
 
 
-async def asr_file(path: str | Path, method: str, voice_format: str, asr_engine: str = "16k_zh-PY") -> dict:
-    """Get ASR texts of an audio file."""
+async def asr_file(
+    path: str | Path,
+    engine: str = "",
+    duration: int = 0,
+    language: str = "16k_zh-PY",
+) -> dict:
+    """Get ASR results of an audio file."""
     res = {}
     path = Path(path).expanduser().resolve()
     if not path.is_file():
         return {"error": f"{path} is not exist"}
+    info = parse_media_info(path)
+    if duration == 0:
+        duration = info["duration"]
+    asr_method, supported_ext = get_asr_method(duration, file_size=Path(path).stat().st_size, force_engine=engine)
+    if asr_method not in ["single_sentence_asr", "flash_asr", "async_asr", "gemini"]:
+        return {"error": asr_method}
+
+    voice_format = path.suffix.lstrip(".")
+    if voice_format not in supported_ext:
+        if info["audio_codec"].split("/")[-1] in supported_ext and not info["video_codec"]:
+            voice_format = info["audio_codec"].split("/")[-1]
+        else:
+            path = convert_to_audio(path, ext="aac", codec="aac")
+            voice_format = "aac"
+    # match again based on converted file
+    asr_method, supported_ext = get_asr_method(duration, file_size=Path(path).stat().st_size, force_engine=engine)
+
+    ogg_names = ["oga", "ogg-opus", "ogg", "opus"]  # unify format name
+    if asr_method in ["single_sentence_asr", "flash_asr", "async_asr"] and voice_format in ogg_names:
+        voice_format = "ogg-opus"
+        path = path.rename(path.with_stem(rand_string()))  # sanitize filename. (for Tencent Signature v3)
+    if asr_method == "gemini" and voice_format in ogg_names:
+        voice_format = "ogg"
+
+    logger.debug(f"Recognizing {voice_format} audio [{duration}s] by {language}: {path.as_posix()}")
     try:
-        if method == "single_sentence_asr":
-            resp = await single_sentence_asr(path, asr_engine, voice_format)
+        if asr_method == "single_sentence_asr":
+            resp = await single_sentence_asr(path, language, voice_format)
             texts = glom(resp, "Response.Result").replace("。", "。\n")
-        elif method == "flash_asr":
-            resp = await flash_asr(path, asr_engine, voice_format)
+        elif asr_method == "flash_asr":
+            resp = await flash_asr(path, language, voice_format)
             texts = glom(resp, "flash_result.0.text").replace("。", "。\n")
-        elif method == "async_asr":
-            resp = await create_async_asr(f"{FILE_SERVER}/{path.name}", asr_engine)
+        elif asr_method == "async_asr":
+            resp = await create_async_asr(f"{FILE_SERVER}/{path.name}", language)
             task_id = resp["Response"]["Data"]["TaskId"]
             logger.success(f"ASR任务提交成功, TaskID: {task_id}")
             result = await query_async_asr(task_id)
@@ -196,7 +216,7 @@ async def asr_file(path: str | Path, method: str, voice_format: str, asr_engine:
             else:
                 texts = glom(result, "Response.Data.ErrorMsg")
                 res["error"] = texts
-        elif method == "gemini":
+        elif asr_method == "gemini":
             texts = await gemini_asr(path, voice_format)
         res["texts"] = texts
         logger.success(f"{texts!r}")
src/preview/ytdlp.py
@@ -19,7 +19,8 @@ from pyrogram.types import Message
 from yt_dlp import YoutubeDL
 from yt_dlp.utils import DownloadError, ExtractorError, YoutubeDLError
 
-from config import API, CAPTION_LENGTH, DB, DOWNLOAD_DIR, MAX_FILE_BYTES, PROVIDER, PROXY, TID, TOKEN, YTDLP_DOWNLOAD_MAX_FILE_BYTES, YTDLP_RE_ENCODING_MAX_FILE_BYTES, cache
+from asr.voice_recognition import asr_file
+from config import API, CAPTION_LENGTH, DB, DOWNLOAD_DIR, MAX_FILE_BYTES, PROVIDER, PROXY, READING_SPEED, TID, TOKEN, YTDLP_DOWNLOAD_MAX_FILE_BYTES, YTDLP_RE_ENCODING_MAX_FILE_BYTES, cache
 from cookies import cookie_cloud_bilibili
 from database import get_db
 from messages.database import copy_messages_from_db, save_messages
@@ -51,6 +52,8 @@ async def preview_ytdlp(
     youtube_comments_provider: str = PROVIDER.YOUTUBE_COMMENTS,
     proxy: str | None = None,
     append_youtube_subtitle: bool = True,
+    append_transcription: bool = True,
+    ytdlp_transcription_engine: str = "gemini",
     **kwargs,
 ):
     """Preview ytdlp link in the message.
@@ -66,6 +69,8 @@ async def preview_ytdlp(
         youtube_comments_provider (str, optional): The youtube comments extractor: "free" or "false".
         proxy (str, optional): Proxy to use. Defaults to None.
         append_youtube_subtitle (bool, optional): Also send youtube subtitle.
+        append_transcription (bool, optional): Also append transcription.
+        ytdlp_transcription_engine (str, optional): ASR engine used for the transcription ("gemini" or "tencent").
     """
     logger.trace(f"{url=} {kwargs=}")
     if kwargs.get("show_progress") and "progress" not in kwargs:
@@ -232,6 +237,15 @@ async def preview_ytdlp(
             caption = f"{emoji}[{info['author']}]({info['author_url']})\n🕒{create_time}\n📝[{info['title']}]({url})\n字符数: {res['num_chars']}\n阅读时长: {res['reading_minutes']:.1f}分钟"
             with io.BytesIO(subtitles.encode("utf-8")) as f:
                 await client.send_document(to_int(target_chat), f, file_name="字幕文件.txt", caption=caption)
+                append_transcription = False  # disable asr transcription
+
+    if any(x in info["extractor"] for x in ["youtube", "bilibili"]) and append_transcription and audio_path.is_file():
+        asr_res = await asr_file(audio_path, ytdlp_transcription_engine, duration)
+        if texts := asr_res.get("texts"):
+            caption = f"{emoji}[{info['author']}]({info['author_url']})\n🕒{create_time}\n📝[{info['title']}]({url})\n字符数: {len(texts)}\n阅读时长: {len(texts) / READING_SPEED:.1f}分钟"
+            with io.BytesIO(texts.encode("utf-8")) as f:
+                await client.send_document(to_int(target_chat), f, file_name="字幕文件.txt", caption=caption)
+
     Path(json_file).unlink(missing_ok=True)
     cleanup_ytdlp(info["id"])