main
  1#!/usr/bin/env python
  2# -*- coding: utf-8 -*-
  3import io
  4import re
  5from io import BytesIO
  6
  7from glom import Coalesce, glom
  8from loguru import logger
  9from pyrogram.client import Client
 10from pyrogram.types import Message, ReplyParameters
 11from pyrogram.types.messages_and_media.message import Str
 12
 13from ai.main import ai_text_generation
 14from asr.voice_recognition import asr_file
 15from config import DEVICE_NAME, PREFIX, READING_SPEED, cache
 16from custom.config import ACCOUNT_NAME, GROUP_DEV
 17from database.memory import del_memory_cache
 18from messages.parser import parse_msg
 19from messages.progress import modify_progress
 20from messages.sender import send2tg
 21from messages.utils import blockquote
 22from networking import match_social_media_link
 23from preview.bilibili import get_bilibili_vinfo
 24from preview.youtube import get_youtube_vinfo
 25from publish import publish_telegraph
 26from subtitles.base import fetch_subtitle
 27from utils import count_subtitles, rand_number, readable_time
 28from ytdlp.download import ytdlp_download
 29
 30
 31async def summary_videos(client: Client, message: Message, retry: int = 0):
 32    if retry > 3:
 33        return
 34    videogram_channel = -1001627687975
 35    if message.chat.id not in [videogram_channel, GROUP_DEV]:
 36        return
 37    # VPS上的账号不响应 开发Group 的消息
 38    if message.chat.id == GROUP_DEV and DEVICE_NAME in ["BennyBot-JP", "BennyBot-US", "BennyBot-CN"]:
 39        return
 40    this_cid = message.chat.id
 41    this_mid = message.id
 42    if ACCOUNT_NAME != "bot":
 43        return
 44    info = parse_msg(message, silent=True)
 45    text = " ".join(info["entity_urls"]) + " " + info["text"]
 46    matched = await match_social_media_link(text)
 47    if matched["platform"] not in ["bilibili", "youtube"]:
 48        return
 49    del_memory_cache(f"parse_msg-{message.chat.id}-{message.id}")
 50    url = matched["url"]
 51    vid = matched.get("vid", matched.get("bvid", url))
 52    if cache.get(f"summary_videos-{message.chat.id}-{vid}-{retry}"):
 53        return
 54    cache.set(f"summary_videos-{message.chat.id}-{vid}-{retry}", "1", ttl=120)
 55    vinfo = await get_youtube_vinfo(vid) if matched["platform"] == "youtube" else await get_bilibili_vinfo(vid)
 56    description = vinfo.get("description", vinfo.get("desc", ""))
 57    logger.debug(f"收到{info['mtype']}信息, 尝试从API下载字幕: [{vinfo['title']}]({url})")
 58    reference = f"本次转录稿为{matched['platform'].title()}平台作者【{vinfo['author']}】的一期节目。\n该期节目标题: [{vinfo['title']}]({url})\n播出日期: {vinfo['pubdate']}\n节目简介: {description}"
 59    subtitle_msg = message
 60    if info["mime_type"] == "text/plain":  # 收到的本身就是字幕文件了
 61        data: BytesIO = await client.download_media(message, in_memory=True)  # type: ignore
 62        subtitles = data.getvalue().decode("utf-8")
 63        res = {"num_chars": count_subtitles(subtitles), "reading_minutes": count_subtitles(subtitles) / READING_SPEED}
 64        status = await client.send_message(this_cid, f"✅已下载字幕\n📝{vinfo['title']}\n#️⃣字符数: {res['num_chars']}", reply_parameters=ReplyParameters(message_id=this_mid))
 65    else:  # 该消息并非字幕消息, 首先尝试直接通过API下载
 66        res = await fetch_subtitle(url, reference)
 67        asr_engine = "tencent"
 68        if matched["platform"] == "youtube":  # bypass censorship
 69            asr_engine = "uncensored"
 70        if subtitles := glom(res, Coalesce("full", "subtitles"), default=""):  # API成功获取字幕
 71            status = await client.send_message(this_cid, "✅通过API成功获取字幕", reply_parameters=ReplyParameters(message_id=this_mid))
 72        elif info["mtype"] == "audio":  # 直接下载音频后ASR
 73            logger.warning(f"API下载字幕失败, 直接下载音频后ASR获取字幕: [{vinfo['title']}]({url})")
 74            status = await client.send_message(this_cid, f"⚠️API下载字幕失败, 直接下载音频后ASR获取字幕\n📝{vinfo['title']}", reply_parameters=ReplyParameters(message_id=this_mid))
 75            fpath: str = await client.download_media(message)  # type: ignore
 76            prompt = f"请转录{matched['platform'].title()}视频作者【{vinfo['author']}】的一期节目的音频。\n该期节目标题: {vinfo['title']}\n节目简介: {description}"
 77            res.pop("error", None)
 78            res = await asr_file(fpath, prompt=prompt, engine=asr_engine, message=message, corrector_reference=reference, silent=True)
 79            if res.get("error") or len(res.get("texts", "")) == 0:
 80                await modify_progress(message=status, del_status=True)
 81                # await summary_videos(client, message, retry + 1)
 82                return
 83            subtitles = res.get("texts", "")
 84            res |= {"num_chars": count_subtitles(subtitles), "reading_minutes": count_subtitles(subtitles) / READING_SPEED}
 85        else:  # 失败了, 使用ytdlp
 86            logger.warning(f"API下载字幕失败, 通过yt-dlp下载音频后ASR获取字幕: [{vinfo['title']}]({url})")
 87            status = await client.send_message(this_cid, f"⚠️API下载字幕失败, 通过yt-dlp下载音频后ASR获取字幕\n📝{vinfo['title']}", reply_parameters=ReplyParameters(message_id=this_mid))
 88            downloaded = await ytdlp_download(url, matched["platform"], ytdlp_download_video=False)
 89            if not downloaded["audio_path"].is_file():
 90                await modify_progress(message=status, text="❌下载音频失败", force_update=True)
 91                return
 92            prompt = f"请转录{matched['platform'].title()}视频作者【{vinfo['author']}】的一期节目的音频。\n该期节目标题: {vinfo['title']}\n节目简介: {description}"
 93            res.pop("error", None)
 94            res = await asr_file(downloaded["audio_path"], prompt=prompt, engine=asr_engine, message=message, corrector_reference=reference, silent=True)
 95            if res.get("error") or len(res.get("texts", "")) == 0:
 96                await modify_progress(message=status, del_status=True)
 97                # await summary_videos(client, message, retry + 1)
 98                return
 99            subtitles = res.get("texts", "")
100            res |= {"num_chars": count_subtitles(res["texts"]), "reading_minutes": count_subtitles(subtitles) / READING_SPEED}
101
102        if count_subtitles(subtitles) < 30:
103            await modify_progress(message=status, del_status=True)
104            return
105
106        caption = f"{vinfo['emoji']}[{vinfo['author']}]({vinfo['channel']})\n🕒{vinfo['pubdate']}\n"
107        caption += f"📝[{vinfo['title']}]({url})\n#️⃣字符数: {res['num_chars']}\n⏳阅读时长: {readable_time(60 * res['reading_minutes'])}"
108        html = "\n".join([f"<p>{s}</p>" for s in subtitles.split("\n")])
109        if telegraph_url := await publish_telegraph(title=vinfo["title"], html=html, author=vinfo["author"], url=url):
110            caption += f"\n⚡️[即时预览]({telegraph_url})"
111        with io.BytesIO(subtitles.encode("utf-8")) as f:
112            subtitle_msg = await client.send_document(info["cid"], f, file_name=f"{vinfo['title']}.txt", caption=caption, reply_parameters=ReplyParameters(message_id=info["mid"]))
113        if not isinstance(subtitle_msg, Message):
114            subtitle_msg = message
115
116    subtitles = re.sub(r"(.*?)AI总结(B站版):", "", subtitles, flags=re.DOTALL).strip()  # noqa: RUF001
117    prompt = f"以上是{matched['platform'].title()}视频作者【{vinfo['author']}】的一期节目的文字稿。该期节目详情如下:\n"
118    prompt += f"节目标题: {vinfo['title']}\n发布日期: {vinfo['pubdate']}\n"
119    if description.strip():
120        prompt += f"节目简介: {description}\n"
121    prompt += "\n请解读本期节目内容。要求: 直接输出节目内容解读, 以“该节目讲述了”开头"
122    # Construct a message to call GPT
123    ai_msg = Message(
124        id=subtitle_msg.id,
125        chat=subtitle_msg.chat,
126        text=Str(f"{PREFIX.AI_TEXT_GENERATION} {prompt}"),
127        reply_to_message=Message(id=rand_number(), chat=subtitle_msg.chat, text=Str(subtitles)),
128    )
129    await modify_progress(message=status, text="🤖字幕总结中...", force_update=True)
130    response = await ai_text_generation(client, ai_msg, silent=True)
131    if texts := response.get("texts"):
132        await send2tg(client, subtitle_msg, texts=f"{response['prefix']}{blockquote(texts)}")
133    await modify_progress(message=status, del_status=True)
134
135
async def get_last_messages(client: Client, chat_id: int, start_id: int = 1, window_size: int = 10) -> list[Message]:
    """Find the newest messages of a chat using exponential probing followed by binary search.

    Ids are probed in windows of ``window_size`` consecutive ids (clamped to 1..200; 200 is
    presumably the per-call ``get_messages`` limit). Starting at ``max(start_id, cached last id)``,
    the probe position doubles while a window still contains messages. A binary search then
    narrows down the last existing id. The search assumes there is no run of ``window_size``
    or more consecutive missing ids below the newest message.

    Args:
        client: Pyrogram client used for ``get_messages``.
        chat_id: Chat to inspect.
        start_id: Lowest message id to start probing from.
        window_size: Number of consecutive ids fetched per probe.

    Returns:
        The non-empty messages among the ids ending at the last found id (at most
        ``window_size + 1`` ids and never more than 200, never below id 1). Returns an empty
        list if none exist; the ``lastmid-{chat_id}`` cache entry is only updated when at
        least one message was found.
    """
    last_mid = cache.get(f"lastmid-{chat_id}", start_id)
    low = float("inf")  # stays inf when the very first probe window is empty
    high = max(start_id, int(last_mid))
    # A window of 0 ids would always look "empty" and stop the probing immediately.
    window_size = max(1, min(window_size, 200))
    # Phase 1: double the probe position until a window without messages is found.
    while True:
        logger.trace(f"Retrieval message of {chat_id=}, mid=[{high}, {high + window_size}]")
        messages: list[Message] = await client.get_messages(chat_id, message_ids=range(high, high + window_size))  # type: ignore
        if all(m.empty for m in messages):
            break
        low = [m.id for m in messages if not m.empty][-1]
        high = low * 2

    # Phase 2: binary search; on exit low == last id + 1 and high == last id.
    while low <= high:
        mid = (low + high) // 2
        logger.trace(f"二分查找, chat: {chat_id}: {low=}, {mid=}, {high=}")
        messages = await client.get_messages(chat_id, message_ids=range(mid, mid + window_size))  # type: ignore
        if all(m.empty for m in messages):
            high = mid - 1
        else:
            low = [m.id for m in messages if not m.empty][-1] + 1
    last_id = min(low, high)

    # Fetch the tail ending at last_id: never below id 1 and never more than 200 ids per call.
    first_id = max(1, last_id - window_size, last_id - 199)
    messages = await client.get_messages(chat_id, message_ids=range(first_id, last_id + 1))  # type: ignore
    valid_msgs = [m for m in messages if not m.empty]
    if valid_msgs:  # avoid IndexError (and a bogus cache write) when nothing was found
        cache.set(f"lastmid-{chat_id}", valid_msgs[-1].id, ttl=0)
    return valid_msgs