main
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3import io
4import re
5from io import BytesIO
6
7from glom import Coalesce, glom
8from loguru import logger
9from pyrogram.client import Client
10from pyrogram.types import Message, ReplyParameters
11from pyrogram.types.messages_and_media.message import Str
12
13from ai.main import ai_text_generation
14from asr.voice_recognition import asr_file
15from config import DEVICE_NAME, PREFIX, READING_SPEED, cache
16from custom.config import ACCOUNT_NAME, GROUP_DEV
17from database.memory import del_memory_cache
18from messages.parser import parse_msg
19from messages.progress import modify_progress
20from messages.sender import send2tg
21from messages.utils import blockquote
22from networking import match_social_media_link
23from preview.bilibili import get_bilibili_vinfo
24from preview.youtube import get_youtube_vinfo
25from publish import publish_telegraph
26from subtitles.base import fetch_subtitle
27from utils import count_subtitles, rand_number, readable_time
28from ytdlp.download import ytdlp_download
29
30
async def summary_videos(client: Client, message: Message, retry: int = 0):
    """Obtain the subtitles of a Bilibili/YouTube video linked in ``message`` and post an AI summary.

    The handler only proceeds for messages in the videogram channel or the dev
    group, and only when ``ACCOUNT_NAME`` is ``"bot"``. Subtitles come from, in
    order of preference: the message itself (when it is a ``text/plain``
    document), ``fetch_subtitle``, ASR on the audio attached to the message, or
    ASR on audio downloaded with yt-dlp. The subtitles are then uploaded as a
    ``.txt`` document (with a Telegraph link when publishing succeeds) and passed
    to ``ai_text_generation``; any generated text is sent with ``send2tg``.

    Args:
        client: Pyrogram client used for downloading and sending messages.
        message: Incoming message expected to contain a Bilibili/YouTube link.
        retry: Retry depth; the function returns immediately when it exceeds 3.
            The recursive retry calls below are currently commented out.

    Returns:
        None. Every early exit is a bare ``return``.
    """
    if retry > 3:
        return
    videogram_channel = -1001627687975
    if message.chat.id not in [videogram_channel, GROUP_DEV]:
        return
    # Accounts running on the VPS hosts do not respond to messages in the dev group.
    if message.chat.id == GROUP_DEV and DEVICE_NAME in ["BennyBot-JP", "BennyBot-US", "BennyBot-CN"]:
        return
    this_cid = message.chat.id
    this_mid = message.id
    if ACCOUNT_NAME != "bot":
        return
    info = parse_msg(message, silent=True)
    # Search both the entity URLs and the plain text for a supported video link.
    text = " ".join(info["entity_urls"]) + " " + info["text"]
    matched = await match_social_media_link(text)
    if matched["platform"] not in ["bilibili", "youtube"]:
        return
    del_memory_cache(f"parse_msg-{message.chat.id}-{message.id}")
    url = matched["url"]
    # Prefer "vid", then "bvid", and fall back to the URL itself as the identifier.
    vid = matched.get("vid", matched.get("bvid", url))
    # De-duplicate: skip when the same chat/video/retry combination was seen recently
    # (ttl=120 -- presumably seconds; confirm against the cache implementation).
    if cache.get(f"summary_videos-{message.chat.id}-{vid}-{retry}"):
        return
    cache.set(f"summary_videos-{message.chat.id}-{vid}-{retry}", "1", ttl=120)
    vinfo = await get_youtube_vinfo(vid) if matched["platform"] == "youtube" else await get_bilibili_vinfo(vid)
    description = vinfo.get("description", vinfo.get("desc", ""))
    logger.debug(f"收到{info['mtype']}信息, 尝试从API下载字幕: [{vinfo['title']}]({url})")
    # Context string handed to fetch_subtitle and to asr_file as corrector_reference.
    reference = f"本次转录稿为{matched['platform'].title()}平台作者【{vinfo['author']}】的一期节目。\n该期节目标题: [{vinfo['title']}]({url})\n播出日期: {vinfo['pubdate']}\n节目简介: {description}"
    subtitle_msg = message
    if info["mime_type"] == "text/plain":  # the received message is itself a subtitle file
        data: BytesIO = await client.download_media(message, in_memory=True)  # type: ignore
        subtitles = data.getvalue().decode("utf-8")
        res = {"num_chars": count_subtitles(subtitles), "reading_minutes": count_subtitles(subtitles) / READING_SPEED}
        status = await client.send_message(this_cid, f"✅已下载字幕\n📝{vinfo['title']}\n#️⃣字符数: {res['num_chars']}", reply_parameters=ReplyParameters(message_id=this_mid))
    else:  # not a subtitle message: first try to fetch the subtitles directly through the API
        res = await fetch_subtitle(url, reference)
        asr_engine = "tencent"
        if matched["platform"] == "youtube":  # bypass censorship
            asr_engine = "uncensored"
        if subtitles := glom(res, Coalesce("full", "subtitles"), default=""):  # subtitles fetched successfully via the API
            status = await client.send_message(this_cid, "✅通过API成功获取字幕", reply_parameters=ReplyParameters(message_id=this_mid))
        elif info["mtype"] == "audio":  # download the audio attached to the message, then run ASR
            logger.warning(f"API下载字幕失败, 直接下载音频后ASR获取字幕: [{vinfo['title']}]({url})")
            status = await client.send_message(this_cid, f"⚠️API下载字幕失败, 直接下载音频后ASR获取字幕\n📝{vinfo['title']}", reply_parameters=ReplyParameters(message_id=this_mid))
            fpath: str = await client.download_media(message)  # type: ignore
            prompt = f"请转录{matched['platform'].title()}视频作者【{vinfo['author']}】的一期节目的音频。\n该期节目标题: {vinfo['title']}\n节目简介: {description}"
            # NOTE(review): `res` is reassigned on the next line, so this pop only
            # affects the dict returned by fetch_subtitle.
            res.pop("error", None)
            res = await asr_file(fpath, prompt=prompt, engine=asr_engine, message=message, corrector_reference=reference, silent=True)
            if res.get("error") or len(res.get("texts", "")) == 0:
                # ASR failed or produced nothing: remove the status message and give up.
                await modify_progress(message=status, del_status=True)
                # await summary_videos(client, message, retry + 1)
                return
            subtitles = res.get("texts", "")
            res |= {"num_chars": count_subtitles(subtitles), "reading_minutes": count_subtitles(subtitles) / READING_SPEED}
        else:  # API failed and there is no audio in the message: fall back to yt-dlp
            logger.warning(f"API下载字幕失败, 通过yt-dlp下载音频后ASR获取字幕: [{vinfo['title']}]({url})")
            status = await client.send_message(this_cid, f"⚠️API下载字幕失败, 通过yt-dlp下载音频后ASR获取字幕\n📝{vinfo['title']}", reply_parameters=ReplyParameters(message_id=this_mid))
            downloaded = await ytdlp_download(url, matched["platform"], ytdlp_download_video=False)
            if not downloaded["audio_path"].is_file():
                await modify_progress(message=status, text="❌下载音频失败", force_update=True)
                return
            prompt = f"请转录{matched['platform'].title()}视频作者【{vinfo['author']}】的一期节目的音频。\n该期节目标题: {vinfo['title']}\n节目简介: {description}"
            # NOTE(review): same as above -- `res` is replaced right after this pop.
            res.pop("error", None)
            res = await asr_file(downloaded["audio_path"], prompt=prompt, engine=asr_engine, message=message, corrector_reference=reference, silent=True)
            if res.get("error") or len(res.get("texts", "")) == 0:
                # ASR failed or produced nothing: remove the status message and give up.
                await modify_progress(message=status, del_status=True)
                # await summary_videos(client, message, retry + 1)
                return
            subtitles = res.get("texts", "")
            res |= {"num_chars": count_subtitles(res["texts"]), "reading_minutes": count_subtitles(subtitles) / READING_SPEED}

    # Subtitles shorter than 30 counted units are not worth summarizing.
    # NOTE(review): this check and the upload step below are assumed to run for
    # both branches above -- confirm the intended nesting.
    if count_subtitles(subtitles) < 30:
        await modify_progress(message=status, del_status=True)
        return

    # Build the caption for the uploaded subtitle document.
    caption = f"{vinfo['emoji']}[{vinfo['author']}]({vinfo['channel']})\n🕒{vinfo['pubdate']}\n"
    caption += f"📝[{vinfo['title']}]({url})\n#️⃣字符数: {res['num_chars']}\n⏳阅读时长: {readable_time(60 * res['reading_minutes'])}"
    # One <p> element per subtitle line; the text is inserted without HTML escaping.
    html = "\n".join([f"<p>{s}</p>" for s in subtitles.split("\n")])
    if telegraph_url := await publish_telegraph(title=vinfo["title"], html=html, author=vinfo["author"], url=url):
        caption += f"\n⚡️[即时预览]({telegraph_url})"
    with io.BytesIO(subtitles.encode("utf-8")) as f:
        subtitle_msg = await client.send_document(info["cid"], f, file_name=f"{vinfo['title']}.txt", caption=caption, reply_parameters=ReplyParameters(message_id=info["mid"]))
    # Fall back to the triggering message when send_document did not return a Message.
    if not isinstance(subtitle_msg, Message):
        subtitle_msg = message

    # Drop everything up to and including the "AI总结...:" marker before summarizing.
    # NOTE(review): as written, the parentheses around B站版 form a regex group rather
    # than matching literal parentheses -- confirm against the real marker text.
    subtitles = re.sub(r"(.*?)AI总结(B站版):", "", subtitles, flags=re.DOTALL).strip()  # noqa: RUF001
    prompt = f"以上是{matched['platform'].title()}视频作者【{vinfo['author']}】的一期节目的文字稿。该期节目详情如下:\n"
    prompt += f"节目标题: {vinfo['title']}\n发布日期: {vinfo['pubdate']}\n"
    if description.strip():
        prompt += f"节目简介: {description}\n"
    prompt += "\n请解读本期节目内容。要求: 直接输出节目内容解读, 以“该节目讲述了”开头"
    # Construct a message to call GPT: the prompt is the command text and the
    # subtitles are attached as the replied-to message.
    ai_msg = Message(
        id=subtitle_msg.id,
        chat=subtitle_msg.chat,
        text=Str(f"{PREFIX.AI_TEXT_GENERATION} {prompt}"),
        reply_to_message=Message(id=rand_number(), chat=subtitle_msg.chat, text=Str(subtitles)),
    )
    await modify_progress(message=status, text="🤖字幕总结中...", force_update=True)
    response = await ai_text_generation(client, ai_msg, silent=True)
    if texts := response.get("texts"):
        await send2tg(client, subtitle_msg, texts=f"{response['prefix']}{blockquote(texts)}")
    # Remove the temporary status message.
    # NOTE(review): assumed to run whether or not a summary was produced -- confirm nesting.
    await modify_progress(message=status, del_status=True)
134
135
async def get_last_messages(client: Client, chat_id: int, start_id: int = 1, window_size: int = 10) -> list[Message]:
    """Find the most recent messages of a chat by probing message ids.

    The search starts at the cached ``lastmid-{chat_id}`` value (or ``start_id``)
    and proceeds in two phases: it first jumps forward, doubling the probe
    position while a window of ids still contains messages, and then runs a
    binary search between the last id known to exist and the first probe that
    came back empty.

    Args:
        client: Pyrogram client used to fetch messages.
        chat_id: Chat whose latest messages are wanted.
        start_id: Lowest message id to start probing from.
        window_size: Number of consecutive ids fetched per probe (capped at 200).

    Returns:
        The non-empty messages among the ids ``last_id - window_size`` to
        ``last_id``, in ascending id order, or an empty list when no message
        could be found. The id of the newest message found is stored in the
        cache under ``lastmid-{chat_id}``.
    """
    last_mid = cache.get(f"lastmid-{chat_id}", start_id)
    high = max(start_id, int(last_mid))
    window_size = min(window_size, 200)
    # Highest id confirmed to exist so far; None until a probe finds a message.
    low: int | None = None

    # Phase 1: jump forward until a whole window of ids comes back empty.
    while True:
        logger.trace(f"Retrieval message of {chat_id=}, mid=[{high}, {high + window_size}]")
        messages: list[Message] = await client.get_messages(chat_id, message_ids=range(high, high + window_size))  # type: ignore
        found_ids = [m.id for m in messages if not m.empty]
        if not found_ids:
            break
        low = found_ids[-1]
        high = low * 2

    # Phase 2: binary search between the last known message and the empty probe.
    # Skipped entirely when phase 1 never saw a message (previously `low` was
    # float("inf"), which made the loop condition false in the same way).
    if low is not None:
        while low <= high:
            mid = (low + high) // 2
            logger.trace(f"二分查找, chat: {chat_id}: {low=}, {mid=}, {high=}")
            messages = await client.get_messages(chat_id, message_ids=range(mid, mid + window_size))  # type: ignore
            found_ids = [m.id for m in messages if not m.empty]
            if found_ids:
                low = found_ids[-1] + 1
            else:
                high = mid - 1
    # After the search `low > high`, so `high` is what min(low, high) used to yield.
    last_id = high

    # Never request ids below 1; ids at or below zero cannot match a message.
    wanted_ids = range(max(1, last_id - window_size), last_id + 1)
    if not wanted_ids:
        return []
    messages = await client.get_messages(chat_id, message_ids=wanted_ids)  # type: ignore
    valid_msgs = [m for m in messages if not m.empty]
    if not valid_msgs:
        # Nothing found (e.g. empty chat): avoid IndexError and leave the cache untouched.
        return []
    cache.set(f"lastmid-{chat_id}", valid_msgs[-1].id, ttl=0)
    return valid_msgs
163 return valid_msgs