main
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3import contextlib
4import re
5from io import BytesIO
6from typing import Literal
7
8from glom import Coalesce, glom
9from loguru import logger
10from pyrogram.client import Client
11from pyrogram.types import Message
12from pyrogram.types.messages_and_media.message import Str
13
14from ai.main import ai_text_generation
15from asr.voice_recognition import asr_file
16from config import AI, ASR, DOWNLOAD_DIR, PREFIX, READING_SPEED, TEXT_LENGTH, cache
17from messages.parser import parse_msg
18from messages.progress import modify_progress
19from messages.sender import send2tg
20from messages.utils import blockquote, delete_message, equal_prefix
21from networking import match_social_media_link
22from preview.bilibili import get_bilibili_vinfo
23from preview.youtube import get_youtube_vinfo
24from publish import publish_telegraph
25from subtitles.base import fetch_subtitle, match_url
26from utils import count_subtitles, rand_number, readable_time, to_int
27from ytdlp.download import ytdlp_download
28
# Usage text for the subtitle command. get_subtitle sends it when the message text matches
# the subtitle prefix (per equal_prefix) and the message is not a reply.
# Content summary (Chinese): "Extract subtitles. Usage: 1) `<prefix> URL` downloads that link's
# subtitles; 2) reply to a message with `<prefix>` to download subtitles for the link in it.
# Supported sites: Bilibili, YouTube. Method: try embedded subtitles first, fall back to
# speech-to-text."
HELP = f"""📃**提取字幕**
使用说明:
1. `{PREFIX.SUBTITLE} URL` 下载该链接的字幕
2. 以 `{PREFIX.SUBTITLE}` 回复消息可下载消息中链接的字幕

⚙️站点支持
🅱️Bilibili
🔴YouTube

ℹ️方式
首先尝试下载内嵌字幕, 失败后使用语音转文字获取字幕
"""  # noqa: RUF001
41
42
43async def get_subtitle(
44 client: Client,
45 message: Message,
46 *,
47 to_telegraph: bool = True,
48 ai_summary: bool = True,
49 summary_model_id: str = AI.SUBTITLE_SUMMARY_MODEL_ALIAS,
50 send_subtitle_as: Literal["file", "str", "none"] = "file",
51 **kwargs,
52):
53 """Get YouTube Subtitle."""
54 target_chat = kwargs["target_chat"] if kwargs.get("target_chat") else message.chat.id
55 # send docs if message == "/subtitle", without reply
56 if equal_prefix(message.text, prefix=[PREFIX.SUBTITLE]) and not message.reply_to_message:
57 await send2tg(client, message, texts=HELP, **kwargs)
58 return
59 if not (url := await match_url(client, message)):
60 return
61 # cache media_group message
62 if media_group_id := message.media_group_id:
63 if cache.get(f"subtitle-{message.chat.id}-{media_group_id}"):
64 return
65 cache.set(f"subtitle-{message.chat.id}-{media_group_id}", "1", ttl=120)
66 matched = await match_social_media_link(url)
67 platform = matched["platform"]
68 if platform not in ["bilibili", "youtube"]:
69 await send2tg(client, message, texts="仅支持Bilibili和YouTube视频链接", **kwargs)
70 return
71 vid = glom(matched, Coalesce("vid", "bvid"), default=url)
72 vinfo = await get_youtube_vinfo(vid) if platform == "youtube" else await get_bilibili_vinfo(vid)
73 if error := vinfo.get("error_msg"):
74 await send2tg(client, message, texts=error, **kwargs)
75 return
76 url = glom(vinfo, Coalesce("url", "link"), default=url)
77 description = glom(vinfo, Coalesce("description", "desc"), default="")
78 caption = f"{vinfo['emoji']}[{vinfo['author']}]({vinfo['channel']})\n🕒{vinfo['pubdate']}\n📝[{vinfo['title']}]({url})"
79 msg = f"🔍**正在获取字幕:**\n{caption}"[:TEXT_LENGTH]
80 if kwargs.get("show_progress"):
81 status_msg = (await send2tg(client, message, texts=msg, **kwargs))[0]
82 kwargs["progress"] = status_msg
83
84 this_info = parse_msg(message, silent=True)
85 reply_info = parse_msg(message.reply_to_message, silent=True) if message.reply_to_message else {}
86 # Fetch subtitle via API
87 reference = f"本次转录稿为{matched['platform'].title()}平台作者【{vinfo['author']}】的一期节目。\n该期节目标题: [{vinfo['title']}]({url})\n播出日期: {vinfo['pubdate']}\n节目简介: {description}"
88 res = await fetch_subtitle(url, reference)
89 if error := res.get("error", ""): # API failed
90 asr_engine = ASR.DEFAULT_ENGINE
91 if platform == "youtube": # bypass censorship
92 asr_engine = kwargs.get("asr_engine", "uncensored")
93 if this_info["mtype"] in ["audio", "video"] or reply_info.get("mtype", "") in ["audio", "video"]:
94 await modify_progress(text=error + "\n正在通过ASR识别字幕", force_update=True, **kwargs)
95 msg = message if this_info["mtype"] in ["audio", "video"] else message.reply_to_message
96 media_path = f"{DOWNLOAD_DIR}/{this_info['file_name'] or reply_info.get('file_name', '')}"
97 fpath: str = await client.download_media(msg, media_path) # type: ignore
98 prompt = f"请转录{matched['platform'].title()}视频作者【{vinfo['author']}】的一期节目的音频。\n该期节目标题: {vinfo['title']}\n节目简介: {description}"
99 res = await asr_file(fpath, engine=asr_engine, prompt=prompt, message=message, corrector_reference=reference, silent=True, **kwargs)
100 if res.get("error"):
101 await modify_progress(text=res["error"], force_update=True, **kwargs)
102 return
103 res |= {"subtitles": res["texts"], "num_chars": count_subtitles(res["texts"]), "reading_minutes": count_subtitles(res["texts"]) / READING_SPEED}
104 else:
105 await modify_progress(text=error + "\n正在通过下载音频后ASR识别字幕", force_update=True, **kwargs)
106 downloaded = await ytdlp_download(url, platform, ytdlp_download_video=False)
107 if not downloaded["audio_path"].is_file():
108 await modify_progress(text="❌下载音频失败", force_update=True, **kwargs)
109 return
110 prompt = f"请转录{matched['platform'].title()}视频作者【{vinfo['author']}】的一期节目的音频。\n该期节目标题: {vinfo['title']}\n节目简介: {description}"
111 res = await asr_file(downloaded["audio_path"], engine=asr_engine, prompt=prompt, message=message, silent=True, **kwargs)
112 if res.get("error"):
113 await modify_progress(text=res["error"], force_update=True, **kwargs)
114 return
115 res |= {"subtitles": res["texts"], "num_chars": count_subtitles(res["texts"]), "reading_minutes": count_subtitles(res["texts"]) / READING_SPEED}
116
117 # Send subtitle
118 subtitles = glom(res, Coalesce("full", "subtitles", "summary"), default="")
119 if not subtitles:
120 await modify_progress(del_status=True, **kwargs)
121 return
122 logger.success(subtitles)
123 caption = f"{vinfo['emoji']}[{vinfo['author']}]({vinfo['channel']})\n🕒{vinfo['pubdate']}\n"
124 caption += f"📝[{vinfo['title']}]({url})\n#️⃣字符数: {res['num_chars']}\n⏳阅读时长: {readable_time(60 * res['reading_minutes'])}"
125 if send_subtitle_as == "file":
126 if to_telegraph:
127 html = "\n".join([f"<p>{s}</p>" for s in subtitles.split("\n")])
128 if telegraph_url := await publish_telegraph(title=vinfo["title"], html=html, author=vinfo["author"], url=url):
129 caption += f"\n⚡️[即时预览]({telegraph_url})"
130 with BytesIO(subtitles.encode("utf-8")) as f:
131 subtitle_msg = await client.send_document(to_int(target_chat), f, file_name=f"{vinfo['title']}.txt", caption=caption)
132 elif send_subtitle_as == "str":
133 subtitle_msg = (await send2tg(client, message, texts=f"{caption}\n{subtitles}", **kwargs))[0]
134 else:
135 subtitle_msg = message
136 if ai_summary and isinstance(subtitle_msg, Message):
137 # use real subtitle (without AI summary by Bilibili)
138 subtitles = re.sub(r"(.*?)AI总结(B站版):", "", subtitles, flags=re.DOTALL).strip() # noqa: RUF001
139 prompt = f"以上是{matched['platform'].title()}视频作者【{vinfo['author']}】的一期节目的文字稿。该期节目详情如下:\n"
140 prompt += f"节目标题: {vinfo['title']}\n发布日期: {vinfo['pubdate']}\n"
141 if description.strip():
142 prompt += f"节目简介: {description}\n"
143 prompt += "\n请解读本期节目内容。要求: 直接输出节目内容解读, 以“该节目讲述了”开头"
144 ai_msg = Message( # Construct a message for AI
145 id=subtitle_msg.id,
146 chat=subtitle_msg.chat,
147 text=Str(f"{PREFIX.AI_TEXT_GENERATION} @{summary_model_id} {prompt}"),
148 reply_to_message=Message(id=rand_number(), chat=subtitle_msg.chat, text=Str(subtitles)),
149 )
150 res = await ai_text_generation(client, ai_msg, silent=True)
151 if res.get("texts"):
152 await send2tg(client, ai_msg, texts=res["prefix"] + blockquote(res["texts"]), **kwargs)
153 with contextlib.suppress(Exception):
154 [await delete_message(msg) for msg in res.get("sent_messages", [])]
155 await delete_message(kwargs.get("progress"))