Commit 29b0048

benny-dou <60535774+benny-dou@users.noreply.github.com>
2026-05-22 09:56:01
chore(corrector): add enable_corrector flag to subtitle fetching and ASR functions
1 parent a3fedeb
Changed files (3)
src
src/subtitles/subtitle.py
@@ -47,6 +47,7 @@ async def get_subtitle(
     ai_summary: bool = True,
     summary_model_id: str = AI.SUBTITLE_SUMMARY_MODEL_ALIAS,
     send_subtitle_as: Literal["file", "str", "none"] = "file",
+    enable_corrector: bool = True,
     **kwargs,
 ):
     """Get YouTube Subtitle."""
@@ -84,7 +85,7 @@ async def get_subtitle(
     reply_info = parse_msg(message.reply_to_message, silent=True) if message.reply_to_message else {}
     # Fetch subtitle via API
     reference = f"本次转录稿为{matched['platform'].title()}平台作者【{vinfo['author']}】的一期节目。\n该期节目标题: [{vinfo['title']}]({url})\n播出日期: {vinfo['pubdate']}\n节目简介: {description}"
-    res = await fetch_subtitle(url, reference)
+    res = await fetch_subtitle(url, reference, enable_corrector=enable_corrector)
     if error := res.get("error", ""):  # API failed
         asr_engine = ASR.DEFAULT_ENGINE
         if platform == "youtube":  # bypass censorship
@@ -95,7 +96,7 @@ async def get_subtitle(
             media_path = f"{DOWNLOAD_DIR}/{this_info['file_name'] or reply_info.get('file_name', '')}"
             fpath: str = await client.download_media(msg, media_path)  # type: ignore
             prompt = f"请转录{matched['platform'].title()}视频作者【{vinfo['author']}】的一期节目的音频。\n该期节目标题: {vinfo['title']}\n节目简介: {description}"
-            res = await asr_file(fpath, engine=asr_engine, prompt=prompt, message=message, corrector_reference=reference, silent=True, **kwargs)
+            res = await asr_file(fpath, engine=asr_engine, prompt=prompt, message=message, corrector_reference=reference, enable_corrector=enable_corrector, silent=True, **kwargs)
             if res.get("error"):
                 await modify_progress(text=res["error"], force_update=True, **kwargs)
                 return
@@ -107,7 +108,7 @@ async def get_subtitle(
                 await modify_progress(text="❌下载音频失败", force_update=True, **kwargs)
                 return
             prompt = f"请转录{matched['platform'].title()}视频作者【{vinfo['author']}】的一期节目的音频。\n该期节目标题: {vinfo['title']}\n节目简介: {description}"
-            res = await asr_file(downloaded["audio_path"], engine=asr_engine, prompt=prompt, message=message, silent=True, **kwargs)
+            res = await asr_file(downloaded["audio_path"], engine=asr_engine, prompt=prompt, message=message, enable_corrector=enable_corrector, silent=True, **kwargs)
             if res.get("error"):
                 await modify_progress(text=res["error"], force_update=True, **kwargs)
                 return
src/ytdlp/main.py
@@ -50,6 +50,7 @@ async def preview_ytdlp(
     ytdlp_send_summary: bool = False,
     summary_model_id: str = AI.SUBTITLE_SUMMARY_MODEL_ALIAS,
     to_telegraph: bool = True,
+    enable_corrector: bool = True,
     show_author: bool = True,
     show_title: bool = True,
     show_pubdate: bool = True,
@@ -133,7 +134,7 @@ async def preview_ytdlp(
     if true(ytdlp_send_subtitle) or true(ytdlp_send_summary):
         fpath = info["audio_path"] if info["audio_path"].is_file() else info["video_path"]
         asr_engine = kwargs.get("asr_engine", "uncensored") if platform == "youtube" else ASR.DEFAULT_ENGINE
-        if sub := await get_subtitles(fpath, url, asr_engine, info):
+        if sub := await get_subtitles(fpath, url, asr_engine, info, enable_corrector=enable_corrector):
             subtitles = f"🔤<b>字幕:</b>\n{sub}"
 
     # get ai summary
src/ytdlp/utils.py
@@ -190,16 +190,16 @@ def find_thumbnail(video_path: str | Path, audio_path: str | Path) -> str | None
     return None
 
 
-async def get_subtitles(audio_path: str | Path, url: str, asr_engine: str, vinfo: dict) -> str:
+async def get_subtitles(audio_path: str | Path, url: str, asr_engine: str, vinfo: dict, *, enable_corrector: bool = True) -> str:
     # send subtitles
     subtitles = ""
     matched = await match_social_media_link(url)
     reference = generate_prompt(vinfo)
     if matched["platform"] in ["bilibili", "youtube"]:  # get subtitle from API first
-        res = await fetch_subtitle(url=url, reference=reference)
+        res = await fetch_subtitle(url=url, reference=reference, enable_corrector=enable_corrector)
         subtitles = res.get("subtitles", "")  # only subtitles, no Bilibili's AI summary
     if not subtitles:
-        res = await asr_file(audio_path, asr_engine, corrector_reference=reference, silent=True)
+        res = await asr_file(audio_path, asr_engine, corrector_reference=reference, enable_corrector=enable_corrector, silent=True)
         subtitles = res.get("texts", "")
         if count_subtitles(subtitles) < 20:
             subtitles = ""  # ignore too  short transcription