main
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3import io
4import re
5from io import BytesIO
6
7from glom import Coalesce, glom
8from loguru import logger
9from pyrogram.client import Client
10from pyrogram.types import Message, ReplyParameters
11
12from asr.voice_recognition import asr_file
13from config import DEVICE_NAME, READING_SPEED, cache
14from custom.config import ACCOUNT_NAME, GROUP_DEV
15from messages.parser import parse_msg
16from messages.progress import modify_progress
17from messages.sender import send2tg
18from messages.utils import blockquote
19from networking import match_social_media_link
20from preview.bilibili import get_bilibili_vinfo
21from preview.youtube import get_youtube_vinfo
22from subtitles.base import fetch_subtitle
23from summarize.summarize import summarize
24from utils import count_subtitles, readable_time
25from ytdlp.download import ytdlp_download
26
27
async def summary_videos(client: Client, message: Message, retry: int = 0) -> None:
    """Obtain the subtitles of a Bilibili/YouTube video message and reply with a summary.

    The message is processed only when all of the following hold: it was sent
    in the hard-coded channel or in ``GROUP_DEV``, ``ACCOUNT_NAME`` is ``"bot"``,
    and its text / entity URLs match a Bilibili or YouTube link.

    Subtitles are taken from the first source that works:

    1. the message itself, when it is a ``text/plain`` document;
    2. ``fetch_subtitle`` (platform API);
    3. ASR on the message's own media, when the message type is ``"audio"``;
    4. ASR on audio downloaded with ``ytdlp_download``.

    For sources 2-4 the subtitles are uploaded as a ``.txt`` document replying
    to the original message. The subtitles are then passed to ``summarize`` and
    the result is sent with ``send2tg``.

    NOTE(review): the indentation of this block was reconstructed. The length
    check / caption / ``send_document`` section is placed inside the ``else``
    branch because ``subtitle_msg`` is pre-assigned to ``message`` before the
    branch; confirm against the canonical file.

    Args:
        client: Pyrogram client used to download media and send messages.
        message: Incoming message to inspect.
        retry: Attempt counter. Nothing is done when it exceeds 3; it is also
            part of the de-duplication cache key.
    """
    if retry > 3:
        return
    videogram_channel = -1001627687975
    if message.chat.id not in [videogram_channel, GROUP_DEV]:
        return
    # Accounts running on the VPS hosts do not respond to messages from the dev group.
    if message.chat.id == GROUP_DEV and DEVICE_NAME in ["BennyBot-JP", "BennyBot-US", "BennyBot-CN"]:
        return
    this_cid = message.chat.id
    this_mid = message.id
    if ACCOUNT_NAME != "bot":
        return
    info = parse_msg(message, silent=True, use_cache=False)
    # Search both the entity URLs and the plain text for a supported link.
    text = " ".join(info["entity_urls"]) + " " + info["text"]
    matched = await match_social_media_link(text)
    if matched["platform"] not in ["bilibili", "youtube"]:
        return
    url = matched["url"]
    # Prefer "vid", then "bvid", and fall back to the URL itself as the video identifier.
    vid = matched.get("vid", matched.get("bvid", url))
    # De-duplicate: skip when the same chat/video/retry combination was seen recently.
    if cache.get(f"summary_videos-{message.chat.id}-{vid}-{retry}"):
        return
    cache.set(f"summary_videos-{message.chat.id}-{vid}-{retry}", "1", ttl=120)  # ttl presumably in seconds -- TODO confirm
    vinfo = await get_youtube_vinfo(vid) if matched["platform"] == "youtube" else await get_bilibili_vinfo(vid)
    description = vinfo.get("description", vinfo.get("desc", ""))
    logger.debug(f"收到{info['mtype']}信息, 尝试从API下载字幕: [{vinfo['title']}]({url})")
    # Context string handed to fetch_subtitle and to the ASR corrector.
    reference = f"本次转录稿为{matched['platform'].title()}平台作者【{vinfo['author']}】的一期节目。\n该期节目标题: [{vinfo['title']}]({url})\n播出日期: {vinfo['pubdate']}\n节目简介: {description}"
    # Message the summary will be attached to; replaced below when a subtitle document is uploaded.
    subtitle_msg = message
    if info["mime_type"] == "text/plain":  # The received message is itself the subtitle file
        data: BytesIO = await client.download_media(message, in_memory=True)  # type: ignore
        subtitles = data.getvalue().decode("utf-8")
        res = {"num_chars": count_subtitles(subtitles), "reading_minutes": count_subtitles(subtitles) / READING_SPEED}
        status = await client.send_message(this_cid, f"✅已下载字幕\n📝{vinfo['title']}\n#️⃣字符数: {res['num_chars']}", reply_parameters=ReplyParameters(message_id=this_mid))
    else:  # Not a subtitle message: first try to download the subtitles directly via the API
        res = await fetch_subtitle(url, reference)
        asr_engine = "tencent"
        if matched["platform"] == "youtube":  # bypass censorship
            asr_engine = "uncensored"
        if subtitles := glom(res, Coalesce("full", "subtitles"), default=""):  # The API returned subtitles
            status = await client.send_message(this_cid, "✅通过API成功获取字幕", reply_parameters=ReplyParameters(message_id=this_mid))
        elif info["mtype"] == "audio":  # Download the message's own audio, then run ASR
            logger.warning(f"API下载字幕失败, 直接下载音频后ASR获取字幕: [{vinfo['title']}]({url})")
            status = await client.send_message(this_cid, f"⚠️API下载字幕失败, 直接下载音频后ASR获取字幕\n📝{vinfo['title']}", reply_parameters=ReplyParameters(message_id=this_mid))
            fpath: str = await client.download_media(message)  # type: ignore
            prompt = f"请转录{matched['platform'].title()}视频作者【{vinfo['author']}】的一期节目的音频。\n该期节目标题: {vinfo['title']}\n节目简介: {description}"
            # NOTE(review): this pop has no lasting effect, `res` is rebound on the next line.
            res.pop("error", None)
            res = await asr_file(fpath, prompt=prompt, engine=asr_engine, message=message, corrector_reference=reference, silent=True)
            if res.get("error") or len(res.get("texts", "")) == 0:
                # ASR failed or produced nothing: clear the status message and give up.
                await modify_progress(message=status, del_status=True)
                # await summary_videos(client, message, retry + 1)
                return
            subtitles = res.get("texts", "")
            res |= {"num_chars": count_subtitles(subtitles), "reading_minutes": count_subtitles(subtitles) / READING_SPEED}
        else:  # The API failed and the message carries no audio: fall back to yt-dlp
            logger.warning(f"API下载字幕失败, 通过yt-dlp下载音频后ASR获取字幕: [{vinfo['title']}]({url})")
            status = await client.send_message(this_cid, f"⚠️API下载字幕失败, 通过yt-dlp下载音频后ASR获取字幕\n📝{vinfo['title']}", reply_parameters=ReplyParameters(message_id=this_mid))
            downloaded = await ytdlp_download(url, matched["platform"], ytdlp_download_video=False)
            if not downloaded["audio_path"].is_file():
                await modify_progress(message=status, text="❌下载音频失败", force_update=True)
                return
            prompt = f"请转录{matched['platform'].title()}视频作者【{vinfo['author']}】的一期节目的音频。\n该期节目标题: {vinfo['title']}\n节目简介: {description}"
            # NOTE(review): this pop has no lasting effect, `res` is rebound on the next line.
            res.pop("error", None)
            res = await asr_file(downloaded["audio_path"], prompt=prompt, engine=asr_engine, message=message, corrector_reference=reference, silent=True)
            if res.get("error") or len(res.get("texts", "")) == 0:
                # ASR failed or produced nothing: clear the status message and give up.
                await modify_progress(message=status, del_status=True)
                # await summary_videos(client, message, retry + 1)
                return
            subtitles = res.get("texts", "")
            res |= {"num_chars": count_subtitles(res["texts"]), "reading_minutes": count_subtitles(subtitles) / READING_SPEED}

        # Subtitles counting fewer than 30 units are not uploaded or summarized.
        if count_subtitles(subtitles) < 30:
            await modify_progress(message=status, del_status=True)
            return

        caption = f"{vinfo['emoji']}[{vinfo['author']}]({vinfo['channel']})\n🕒{vinfo['pubdate']}\n"
        # reading_minutes is converted to seconds for readable_time (60 * minutes).
        caption += f"📝[{vinfo['title']}]({url})\n#️⃣字符数: {res['num_chars']}\n⏳阅读时长: {readable_time(60 * res['reading_minutes'])}"
        # html = "\n".join([f"<p>{s}</p>" for s in subtitles.split("\n")])
        # if telegraph_url := await publish_telegraph(title=vinfo["title"], html=html, author=vinfo["author"], url=url, ttl="30d"):
        # caption += f"\n⚡️[即时预览]({telegraph_url})"
        # Upload the subtitles as an in-memory .txt document replying to the original message.
        with io.BytesIO(subtitles.encode("utf-8")) as f:
            subtitle_msg = await client.send_document(info["cid"], f, file_name=f"{vinfo['title']}.txt", caption=caption, reply_parameters=ReplyParameters(message_id=info["mid"]))
        # Fall back to the original message when send_document did not return a Message.
        if not isinstance(subtitle_msg, Message):
            subtitle_msg = message

    # Remove all text up to and including each "AI总结(B站版):" marker (the parentheses
    # and colon in the pattern are full-width literals, not regex groups).
    subtitles = re.sub(r"(.*?)AI总结(B站版):", "", subtitles, flags=re.DOTALL).strip()  # noqa: RUF001
    prompt = f"该转录稿对应于{matched['platform'].title()}视频作者【{vinfo['author']}】的一期节目,节目详情如下:\n"
    prompt += f"节目标题: {vinfo['title']}\n发布日期: {vinfo['pubdate']}\n"
    if description.strip():
        prompt += f"节目简介: {description}"
    await modify_progress(message=status, text="🤖字幕总结中...", force_update=True)
    response = await summarize(sources=[{"type": "system_prompt", "text": prompt}, {"type": "transcripts", "text": subtitles}], model="general", min_text_length=200)
    # Only send a reply when the summarizer produced text.
    if texts := response.get("texts"):
        await send2tg(client, subtitle_msg, texts=f"{response['prefix']}{blockquote(texts)}")
    await modify_progress(message=status, del_status=True)
122
123
async def get_last_messages(client: Client, chat_id: int, start_id: int = 1, window_size: int = 10) -> list[Message]:
    """Find the last messages of a chat by probing message ids (gallop + binary search).

    The search starts from the cached ``lastmid-{chat_id}`` value (or
    ``start_id`` when nothing is cached, whichever is larger) and works in
    windows of ``window_size`` consecutive ids:

    1. Probe forward, doubling the probe position each time a window contains
       an existing message, until a window is completely empty.
    2. Binary-search between the last id known to exist and the first empty
       probe position.
    3. Fetch the final window ending at the located id.

    A window without any existing message is treated as lying past the end of
    the chat, so a gap of deleted ids wider than ``window_size`` can end the
    search early.

    Args:
        client: Pyrogram client used to fetch messages.
        chat_id: Chat whose tail is looked up.
        start_id: Smallest message id to start probing from.
        window_size: Number of consecutive ids fetched per request, capped at 200.

    Returns:
        The non-empty messages of the final window in the order returned by
        ``get_messages``, or an empty list when no existing message was found.
        The id of the last returned message is stored in the cache.
    """
    last_mid = cache.get(f"lastmid-{chat_id}", start_id)
    # `low` stays None until at least one existing message has been seen.
    low: int | None = None
    high = max(start_id, int(last_mid))
    window_size = min(window_size, 200)

    # Phase 1: gallop forward until a whole window is empty.
    while True:
        logger.trace(f"Retrieval message of {chat_id=}, mid=[{high}, {high + window_size}]")
        messages: list[Message] = await client.get_messages(chat_id, message_ids=range(high, high + window_size))  # type: ignore
        existing_ids = [m.id for m in messages if not m.empty]
        if not existing_ids:
            break
        low = existing_ids[-1]
        high = low * 2

    if low is None:
        # Nothing exists at or after the starting position; look backwards from it.
        last_id = high
    else:
        # Phase 2: binary search between the last known id and the empty probe.
        while low <= high:
            mid = (low + high) // 2
            logger.trace(f"二分查找, chat: {chat_id}: {low=}, {mid=}, {high=}")
            messages = await client.get_messages(chat_id, message_ids=range(mid, mid + window_size))  # type: ignore
            existing_ids = [m.id for m in messages if not m.empty]
            if existing_ids:
                low = existing_ids[-1] + 1
            else:
                high = mid - 1
        last_id = min(low, high)

    # Message ids are positive, so never request ids below 1.
    first_id = max(1, last_id - window_size)
    if first_id > last_id:
        return []
    messages = await client.get_messages(chat_id, message_ids=range(first_id, last_id + 1))  # type: ignore
    valid_msgs = [m for m in messages if not m.empty]
    if not valid_msgs:
        # Empty chat or every message in the window deleted: nothing to cache or return.
        return []
    cache.set(f"lastmid-{chat_id}", valid_msgs[-1].id, ttl=0)  # ttl=0 presumably means "no expiry" -- TODO confirm
    return valid_msgs