Commit 29b0048
Changed files (3)
src
subtitles
src/subtitles/subtitle.py
@@ -47,6 +47,7 @@ async def get_subtitle(
ai_summary: bool = True,
summary_model_id: str = AI.SUBTITLE_SUMMARY_MODEL_ALIAS,
send_subtitle_as: Literal["file", "str", "none"] = "file",
+ enable_corrector: bool = True,
**kwargs,
):
"""Get YouTube Subtitle."""
@@ -84,7 +85,7 @@ async def get_subtitle(
reply_info = parse_msg(message.reply_to_message, silent=True) if message.reply_to_message else {}
# Fetch subtitle via API
reference = f"本次转录稿为{matched['platform'].title()}平台作者【{vinfo['author']}】的一期节目。\n该期节目标题: [{vinfo['title']}]({url})\n播出日期: {vinfo['pubdate']}\n节目简介: {description}"
- res = await fetch_subtitle(url, reference)
+ res = await fetch_subtitle(url, reference, enable_corrector=enable_corrector)
if error := res.get("error", ""): # API failed
asr_engine = ASR.DEFAULT_ENGINE
if platform == "youtube": # bypass censorship
@@ -95,7 +96,7 @@ async def get_subtitle(
media_path = f"{DOWNLOAD_DIR}/{this_info['file_name'] or reply_info.get('file_name', '')}"
fpath: str = await client.download_media(msg, media_path) # type: ignore
prompt = f"请转录{matched['platform'].title()}视频作者【{vinfo['author']}】的一期节目的音频。\n该期节目标题: {vinfo['title']}\n节目简介: {description}"
- res = await asr_file(fpath, engine=asr_engine, prompt=prompt, message=message, corrector_reference=reference, silent=True, **kwargs)
+ res = await asr_file(fpath, engine=asr_engine, prompt=prompt, message=message, corrector_reference=reference, enable_corrector=enable_corrector, silent=True, **kwargs)
if res.get("error"):
await modify_progress(text=res["error"], force_update=True, **kwargs)
return
@@ -107,7 +108,7 @@ async def get_subtitle(
await modify_progress(text="❌下载音频失败", force_update=True, **kwargs)
return
prompt = f"请转录{matched['platform'].title()}视频作者【{vinfo['author']}】的一期节目的音频。\n该期节目标题: {vinfo['title']}\n节目简介: {description}"
- res = await asr_file(downloaded["audio_path"], engine=asr_engine, prompt=prompt, message=message, silent=True, **kwargs)
+ res = await asr_file(downloaded["audio_path"], engine=asr_engine, prompt=prompt, message=message, enable_corrector=enable_corrector, silent=True, **kwargs)
if res.get("error"):
await modify_progress(text=res["error"], force_update=True, **kwargs)
return
src/ytdlp/main.py
@@ -50,6 +50,7 @@ async def preview_ytdlp(
ytdlp_send_summary: bool = False,
summary_model_id: str = AI.SUBTITLE_SUMMARY_MODEL_ALIAS,
to_telegraph: bool = True,
+ enable_corrector: bool = True,
show_author: bool = True,
show_title: bool = True,
show_pubdate: bool = True,
@@ -133,7 +134,7 @@ async def preview_ytdlp(
if true(ytdlp_send_subtitle) or true(ytdlp_send_summary):
fpath = info["audio_path"] if info["audio_path"].is_file() else info["video_path"]
asr_engine = kwargs.get("asr_engine", "uncensored") if platform == "youtube" else ASR.DEFAULT_ENGINE
- if sub := await get_subtitles(fpath, url, asr_engine, info):
+ if sub := await get_subtitles(fpath, url, asr_engine, info, enable_corrector=enable_corrector):
subtitles = f"🔤<b>字幕:</b>\n{sub}"
# get ai summary
src/ytdlp/utils.py
@@ -190,16 +190,16 @@ def find_thumbnail(video_path: str | Path, audio_path: str | Path) -> str | None
return None
-async def get_subtitles(audio_path: str | Path, url: str, asr_engine: str, vinfo: dict) -> str:
+async def get_subtitles(audio_path: str | Path, url: str, asr_engine: str, vinfo: dict, *, enable_corrector: bool = True) -> str:
# send subtitles
subtitles = ""
matched = await match_social_media_link(url)
reference = generate_prompt(vinfo)
if matched["platform"] in ["bilibili", "youtube"]: # get subtitle from API first
- res = await fetch_subtitle(url=url, reference=reference)
+ res = await fetch_subtitle(url=url, reference=reference, enable_corrector=enable_corrector)
subtitles = res.get("subtitles", "") # only subtitles, no Bilibili's AI summary
if not subtitles:
- res = await asr_file(audio_path, asr_engine, corrector_reference=reference, silent=True)
+ res = await asr_file(audio_path, asr_engine, corrector_reference=reference, enable_corrector=enable_corrector, silent=True)
subtitles = res.get("texts", "")
if count_subtitles(subtitles) < 20:
subtitles = "" # ignore too short transcription