main
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3from io import BytesIO
4
5from glom import Coalesce, glom
6from loguru import logger
7from pyrogram.client import Client
8from pyrogram.types import InputMediaDocument, LinkPreviewOptions, Message
9
10from asr.voice_recognition import asr_file
11from config import AI, ASR, DOWNLOAD_DIR, PREFIX, READING_SPEED, TEXT_LENGTH, cache
12from messages.parser import parse_msg
13from messages.progress import modify_progress
14from messages.sender import send2tg
15from messages.utils import blockquote, count_without_entities, equal_prefix
16from networking import match_social_media_link
17from preview.bilibili import get_bilibili_vinfo
18from preview.youtube import get_youtube_vinfo
19from subtitles.base import fetch_subtitle, match_url
20from summarize.summarize import summarize
21from utils import count_subtitles, readable_time
22from ytdlp.download import ytdlp_download
23
# Usage text sent back when the subtitle command arrives with no replied message.
HELP = (
    "📃**提取字幕**\n"
    "使用说明:\n"
    f"1. `{PREFIX.SUBTITLE} URL` 下载该链接的字幕\n"
    f"2. 以 `{PREFIX.SUBTITLE}` 回复消息可下载消息中链接的字幕\n"
    "\n"
    "⚙️站点支持\n"
    "🅱️Bilibili\n"
    "🔴YouTube\n"
    "\n"
    "ℹ️方式\n"  # noqa: RUF001
    "首先尝试下载内嵌字幕, 失败后使用语音转文字获取字幕\n"
)
36
37
async def get_subtitle(
    client: Client,
    message: Message,
    *,
    ai_summary: bool = True,
    summary_subtitle_model: str = AI.SUBTITLE_SUMMARY_MODEL_ALIAS,
    enable_corrector: bool = False,
    **kwargs,
) -> None:
    """Fetch the subtitles of a Bilibili or YouTube video and send them as a document.

    Flow:
      1. A bare ``PREFIX.SUBTITLE`` command without a replied message gets the HELP text.
      2. The URL is resolved with ``match_url``; only the "bilibili" and "youtube"
         platforms are accepted.
      3. Subtitles are requested with ``fetch_subtitle``. If that reports an error,
         the audio is transcribed with ``asr_file`` instead, using either the
         audio/video attached to this message (or the replied one), or audio
         downloaded through ``ytdlp_download``.
      4. The status message is edited into a text document holding the subtitles.
      5. When ``ai_summary`` is true, a summary is generated with ``summarize`` and
         sent as a reply to the document message.

    Args:
        client: Pyrogram client used to send messages and download media.
        message: The triggering message.
        ai_summary: Whether to produce and send a summary after the subtitles.
        summary_subtitle_model: Model identifier forwarded to ``summarize``.
        enable_corrector: Forwarded to ``fetch_subtitle`` and ``asr_file``.
        **kwargs: Forwarded to ``send2tg``, ``modify_progress``, ``ytdlp_download``
            and ``asr_file``. ``asr_engine`` and ``force_r2_page`` are read here;
            ``progress`` and ``ytdlp_download_video`` are set by this function.
    """
    # Bare command with no replied message: answer with the usage text.
    if equal_prefix(message.text, prefix=[PREFIX.SUBTITLE]) and not message.reply_to_message:
        await send2tg(client, message, texts=HELP, **kwargs)
        return
    if not (url := await match_url(client, message)):
        return
    # Handle a media group only once: mark it in the cache for 120 seconds.
    if media_group_id := message.media_group_id:
        cache_key = f"subtitle-{message.chat.id}-{media_group_id}"
        if cache.get(cache_key):
            return
        cache.set(cache_key, "1", ttl=120)
    matched = await match_social_media_link(url)
    platform = matched["platform"]
    if platform not in ["bilibili", "youtube"]:
        await send2tg(client, message, texts="仅支持Bilibili和YouTube视频链接", **kwargs)
        return
    platform_name = platform.title()
    vid = glom(matched, Coalesce("vid", "bvid"), default=url)
    vinfo = await get_youtube_vinfo(vid) if platform == "youtube" else await get_bilibili_vinfo(vid)
    if error := vinfo.get("error_msg"):
        await send2tg(client, message, texts=error, **kwargs)
        return
    # Prefer the URL reported by the video info; fall back to the matched one.
    url = glom(vinfo, Coalesce("url", "link"), default=url)
    description = glom(vinfo, Coalesce("description", "desc"), default="")
    caption = f"{vinfo['emoji']}[{vinfo['author']}]({vinfo['channel']})\n🕒{vinfo['pubdate']}\n📝[{vinfo['title']}]({url})"
    status_text = f"🔍**正在获取字幕:**\n{caption}"[:TEXT_LENGTH]
    status_msg: Message = (await send2tg(client, message, texts=status_text, **kwargs))[0]
    # Later modify_progress calls receive the status message through kwargs.
    kwargs["progress"] = status_msg

    this_info = parse_msg(message, silent=True)
    reply_info = parse_msg(message.reply_to_message, silent=True) if message.reply_to_message else {}
    # Fetch subtitle via API
    reference = f"本次转录稿为{platform_name}平台作者【{vinfo['author']}】的一期节目。\n该期节目标题: [{vinfo['title']}]({url})\n播出日期: {vinfo['pubdate']}\n节目简介: {description}"
    res = await fetch_subtitle(url, reference, enable_corrector=enable_corrector)
    if error := res.get("error", ""):  # API failed: fall back to speech recognition
        asr_engine = ASR.DEFAULT_ENGINE
        if platform == "youtube":  # bypass censorship
            asr_engine = kwargs.get("asr_engine", "uncensored")
        media_types = ("audio", "video")
        if this_info["mtype"] in media_types or reply_info.get("mtype", "") in media_types:
            # The audio/video is already on Telegram (this message or the replied one).
            await modify_progress(text=error + "\n正在通过ASR识别字幕", force_update=True, **kwargs)
            media_msg = message if this_info["mtype"] in media_types else message.reply_to_message
            media_path = f"{DOWNLOAD_DIR}/{this_info['file_name'] or reply_info.get('file_name', '')}"
            audio_path = await client.download_media(media_msg, media_path)
            if not audio_path:
                # download_media yielded no path; report instead of passing it to the ASR.
                await modify_progress(text="❌下载媒体文件失败", force_update=True, **kwargs)
                return
        else:
            await modify_progress(text=error + "\n正在通过下载音频后ASR识别字幕", force_update=True, **kwargs)
            kwargs["ytdlp_download_video"] = False
            downloaded = await ytdlp_download(url, platform, **kwargs)
            audio_path = downloaded["audio_path"]
            if not audio_path.is_file():
                await modify_progress(text="❌下载音频失败", force_update=True, **kwargs)
                return

        prompt = f"请转录{platform_name}视频作者【{vinfo['author']}】的一期节目的音频。\n该期节目标题: {vinfo['title']}\n节目简介: {description}"
        # Both audio sources pass the same corrector reference.
        res = await asr_file(audio_path, engine=asr_engine, prompt=prompt, message=message, corrector_reference=reference, enable_corrector=enable_corrector, silent=True, **kwargs)
        if res.get("error"):
            await modify_progress(text=res["error"], force_update=True, **kwargs)
            return
        num_chars = count_subtitles(res["texts"])
        res |= {"subtitles": res["texts"], "num_chars": num_chars, "reading_minutes": num_chars / READING_SPEED}

    # Send subtitle
    subtitles = res.get("subtitles", "")
    if not subtitles:
        await modify_progress(del_status=True, **kwargs)
        return

    logger.success(subtitles)
    caption += f"\n#️⃣字符数: {res['num_chars']}\n⏳阅读时长: {readable_time(60 * res['reading_minutes'])}"
    full = glom(res, Coalesce("full", "subtitles", "summary"), default="")
    # Send subtitle txt by turning the status message into a document.
    with BytesIO(full.encode("utf-8")) as f:
        status_msg = await status_msg.edit_media(media=InputMediaDocument(f, caption=caption))

    if not (ai_summary and isinstance(status_msg, Message)):
        return

    # Summarize from `subtitles` rather than `full`
    # (original note: use real subtitle, without AI summary by Bilibili).
    prompt = f"该转录稿对应于{platform_name}视频作者【{vinfo['author']}】的一期节目,节目详情:\n标题: {vinfo['title']}\n日期: {vinfo['pubdate']}\n"
    if description.strip():
        prompt += f"节目简介: {description}"
    summary = await summarize(
        sources=[{"type": "system_prompt", "text": prompt}, {"type": "transcripts", "text": subtitles}],
        model=summary_subtitle_model,
        title=vinfo["title"],
        author=vinfo["author"],
        url=url,
        date=vinfo["pubdate"],
        description=description,
        min_text_length=200,
        force_r2_page=kwargs.get("force_r2_page", False),
    )

    if not summary.get("texts"):
        return
    telegraph_url = summary.get("telegraph_url") or ""
    link_preview = LinkPreviewOptions(is_disabled=False, show_above_text=True, url=telegraph_url) if telegraph_url else LinkPreviewOptions(is_disabled=True)
    if await count_without_entities(summary["texts"]) <= TEXT_LENGTH:
        # Short enough for a single message: reply with the quoted summary.
        await status_msg.reply_text(blockquote(summary["texts"]), quote=True, link_preview_options=link_preview)
    elif telegraph_url:
        # Too long for one message: reply with the page link instead.
        await status_msg.reply_text(telegraph_url, link_preview_options=link_preview, quote=True)
    else:
        await send2tg(client, status_msg, texts=summary["texts"])