Commit d9b0f48

benny-dou <60535774+benny-dou@users.noreply.github.com>
2025-04-29 19:14:02
feat(subtitle): use ASR if YouTube subtitles are disabled
1 parent 51c7615
Changed files (4)
src/asr/gemini_asr.py
@@ -1,6 +1,5 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
-import io
 import random
 from pathlib import Path
 
@@ -10,7 +9,7 @@ from google.genai.types import GenerateContentConfig, HttpOptions, UploadFileCon
 from loguru import logger
 from pydantic import BaseModel
 from pyrogram.client import Client
-from pyrogram.types import Message, ReplyParameters
+from pyrogram.types import Message
 
 from config import ASR, TEXT_LENGTH
 from llm.gemini import parse_response
@@ -120,16 +119,12 @@ async def gemini_stream_asr(client: Client, message: Message, path: str | Path,
                 parts = await smart_split(runtime_texts)
                 await modify_progress(message=status, text=blockquote(parts[0]), force_update=True)  # force send the first part
                 runtime_texts = parts[-1]  # keep the last part
-                if not slient:
+                if not status:
                     status = await client.send_message(message.chat.id, runtime_texts)  # the new message
                     sent_messages.append(status)
 
         # all chunks are processed
         await modify_progress(message=status, text=blockquote(beautify_llm_response(runtime_texts)), force_update=True)
-        if len(sent_messages) > 1:
-            with io.BytesIO(transcriptions.encode("utf-8")) as f:
-                await client.send_document(message.chat.id, f, file_name="语音识别结果.txt", reply_parameters=ReplyParameters(message_id=message.id))
-            [await modify_progress(msg, del_status=True) for msg in sent_messages]
     except Exception as e:
         logger.error(e)
-    return {"texts": transcriptions} if slient else {}
+    return {"texts": transcriptions, "sent_messages": sent_messages}
src/asr/voice_recognition.py
@@ -9,7 +9,6 @@ from pathlib import Path
 from glom import glom
 from loguru import logger
 from pyrogram.client import Client
-from pyrogram.parser.markdown import BLOCKQUOTE_EXPANDABLE_DELIM, BLOCKQUOTE_EXPANDABLE_END_DELIM
 from pyrogram.types import Message
 
 from asr.gemini_asr import gemini_asr, gemini_stream_asr
@@ -19,7 +18,7 @@ from config import CAPTION_LENGTH, FILE_SERVER, PREFIX, TEXT_LENGTH
 from messages.parser import parse_msg
 from messages.progress import modify_progress
 from messages.sender import send2tg
-from messages.utils import count_without_entities, equal_prefix, get_reply_to, startswith_prefix
+from messages.utils import blockquote, count_without_entities, equal_prefix, get_reply_to, startswith_prefix
 from multimedia import convert_to_audio, parse_media_info
 from utils import rand_string, to_int
 
@@ -136,7 +135,7 @@ async def voice_to_text(
         await modify_progress(text=error, force_update=True, **kwargs)
         return
     if texts := res.get("texts"):
-        final = f"{BEGINNING}\n{BLOCKQUOTE_EXPANDABLE_DELIM}{texts}{BLOCKQUOTE_EXPANDABLE_END_DELIM}"
+        final = f"{BEGINNING}\n{blockquote(texts)}"
         logger.success(f"{final!r}")
         # send results
         target_chat = kwargs["target_chat"] if kwargs.get("target_chat") else this_info["cid"]
@@ -152,6 +151,7 @@ async def voice_to_text(
         await modify_progress(del_status=True, **kwargs)
 
     with contextlib.suppress(Exception):
+        [await modify_progress(msg, del_status=True) for msg in res.get("sent_messages", [])]
         if this_info["mtype"] == "text":
             await message.delete()
 
src/others/subtitle.py
@@ -10,6 +10,7 @@ from pyrogram.client import Client
 from pyrogram.types import Message
 from youtube_transcript_api import YouTubeTranscriptApi  # type: ignore
 
+from asr.voice_recognition import asr_file
 from config import API, PREFIX, PROVIDER, PROXY, READING_SPEED, TOKEN, TZ
 from database import cache
 from messages.parser import parse_msg
@@ -43,19 +44,36 @@ async def get_subtitle(client: Client, message: Message, youtube_subtitle_provid
     if kwargs.get("show_progress"):
         res = await send2tg(client, message, texts=msg, **kwargs)
         kwargs["progress"] = res[0]
+
     # cache media_group message
     if media_group_id := message.media_group_id:
         if cache.get(f"subtitle-{message.chat.id}-{media_group_id}"):
             return
         cache.set(f"subtitle-{message.chat.id}-{media_group_id}", "1", ttl=120)
 
+    this_info = parse_msg(message, silent=True)
+    reply_info = parse_msg(message.reply_to_message, silent=True) if message.reply_to_message else {}
+
     res = await fetch_subtitle(vid, youtube_subtitle_provider)
     if error := res.get("error", ""):
-        await modify_progress(text=error, force_update=True, **kwargs)
-        return
-    if not res.get("subtitle", ""):
-        return
+        if "Subtitles are disabled for this video" in error:
+            error = "❌该视频没有提供字幕选项"
+        if this_info["mtype"] in ["audio", "video"] or reply_info.get("mtype", "") in ["audio", "video"]:
+            error += "\n🔄尝试使用语音转文字获取字幕"
+            await modify_progress(text=error, force_update=True, **kwargs)
+            msg = message if this_info["mtype"] in ["audio", "video"] else message.reply_to_message
+            fpath: str = await msg.download()  # type: ignore
+            asr_res = await asr_file(fpath, engine="gemini", client=client, message=message, **kwargs)
+            if asr_res.get("error"):
+                await modify_progress(text=asr_res["error"], force_update=True, **kwargs)
+                return
+            res = {"subtitle": asr_res["texts"], "num_chars": len(asr_res["texts"]), "reading_minutes": len(asr_res["texts"]) / READING_SPEED}
+        else:
+            await modify_progress(text=error, force_update=True, **kwargs)
+            return
     subtitles = res.get("subtitle", "")
+    if not subtitles:
+        return
     logger.success(subtitles)
     if vinfo := await fetch_youtube_video_info(vid):
         caption = f"🔴[{vinfo['author']}]({vinfo['channel']})\n🕒{vinfo['date']:%Y-%m-%d %H:%M:%S}\n"
@@ -67,6 +85,7 @@ async def get_subtitle(client: Client, message: Message, youtube_subtitle_provid
         with io.BytesIO(subtitles.encode("utf-8")) as f:
             await client.send_document(to_int(target_chat), f, file_name=f"{vid}字幕.txt", caption=caption)
 
+    [await modify_progress(msg, del_status=True) for msg in res.get("sent_messages", [])]
     await modify_progress(del_status=True, **kwargs)
 
 
src/preview/ytdlp.py
@@ -236,7 +236,7 @@ async def preview_ytdlp(
         if subtitles := res.get("subtitle"):
             caption = f"{emoji}[{info['author']}]({info['author_url']})\n🕒{create_time}\n📝[{info['title']}]({url})\n字符数: {res['num_chars']}\n阅读时长: {res['reading_minutes']:.1f}分钟"
             with io.BytesIO(subtitles.encode("utf-8")) as f:
-                await client.send_document(to_int(target_chat), f, file_name="字幕文件.txt", caption=caption)
+                await client.send_document(to_int(target_chat), f, file_name=f"{info['title']}.txt", caption=caption)
                 append_transcription = False  # disable asr transcription
 
     if any(x in info["extractor"] for x in ["youtube", "bilibili"]) and append_transcription and audio_path.is_file():
@@ -244,7 +244,8 @@ async def preview_ytdlp(
         if texts := asr_res.get("texts"):
             caption = f"{emoji}[{info['author']}]({info['author_url']})\n🕒{create_time}\n📝[{info['title']}]({url})\n字符数: {len(texts)}\n阅读时长: {len(texts) / READING_SPEED:.1f}分钟"
             with io.BytesIO(texts.encode("utf-8")) as f:
-                await client.send_document(to_int(target_chat), f, file_name="字幕文件.txt", caption=caption)
+                await client.send_document(to_int(target_chat), f, file_name=f"{info['title']}.txt", caption=caption)
+        [await modify_progress(msg, del_status=True) for msg in asr_res.get("sent_messages", [])]
 
     Path(json_file).unlink(missing_ok=True)
     cleanup_ytdlp(info["id"])