main
  1#!/usr/bin/env python
  2# -*- coding: utf-8 -*-
  3import io
  4import re
  5from io import BytesIO
  6
  7from glom import Coalesce, glom
  8from loguru import logger
  9from pyrogram.client import Client
 10from pyrogram.types import Message, ReplyParameters
 11
 12from asr.voice_recognition import asr_file
 13from config import DEVICE_NAME, READING_SPEED, cache
 14from custom.config import ACCOUNT_NAME, GROUP_DEV
 15from messages.parser import parse_msg
 16from messages.progress import modify_progress
 17from messages.sender import send2tg
 18from messages.utils import blockquote
 19from networking import match_social_media_link
 20from preview.bilibili import get_bilibili_vinfo
 21from preview.youtube import get_youtube_vinfo
 22from subtitles.base import fetch_subtitle
 23from summarize.summarize import summarize
 24from utils import count_subtitles, readable_time
 25from ytdlp.download import ytdlp_download
 26
 27
 28async def summary_videos(client: Client, message: Message, retry: int = 0):
 29    if retry > 3:
 30        return
 31    videogram_channel = -1001627687975
 32    if message.chat.id not in [videogram_channel, GROUP_DEV]:
 33        return
 34    # VPS上的账号不响应 开发Group 的消息
 35    if message.chat.id == GROUP_DEV and DEVICE_NAME in ["BennyBot-JP", "BennyBot-US", "BennyBot-CN"]:
 36        return
 37    this_cid = message.chat.id
 38    this_mid = message.id
 39    if ACCOUNT_NAME != "bot":
 40        return
 41    info = parse_msg(message, silent=True, use_cache=False)
 42    text = " ".join(info["entity_urls"]) + " " + info["text"]
 43    matched = await match_social_media_link(text)
 44    if matched["platform"] not in ["bilibili", "youtube"]:
 45        return
 46    url = matched["url"]
 47    vid = matched.get("vid", matched.get("bvid", url))
 48    if cache.get(f"summary_videos-{message.chat.id}-{vid}-{retry}"):
 49        return
 50    cache.set(f"summary_videos-{message.chat.id}-{vid}-{retry}", "1", ttl=120)
 51    vinfo = await get_youtube_vinfo(vid) if matched["platform"] == "youtube" else await get_bilibili_vinfo(vid)
 52    description = vinfo.get("description", vinfo.get("desc", ""))
 53    logger.debug(f"收到{info['mtype']}信息, 尝试从API下载字幕: [{vinfo['title']}]({url})")
 54    reference = f"本次转录稿为{matched['platform'].title()}平台作者【{vinfo['author']}】的一期节目。\n该期节目标题: [{vinfo['title']}]({url})\n播出日期: {vinfo['pubdate']}\n节目简介: {description}"
 55    subtitle_msg = message
 56    if info["mime_type"] == "text/plain":  # 收到的本身就是字幕文件了
 57        data: BytesIO = await client.download_media(message, in_memory=True)  # type: ignore
 58        subtitles = data.getvalue().decode("utf-8")
 59        res = {"num_chars": count_subtitles(subtitles), "reading_minutes": count_subtitles(subtitles) / READING_SPEED}
 60        status = await client.send_message(this_cid, f"✅已下载字幕\n📝{vinfo['title']}\n#️⃣字符数: {res['num_chars']}", reply_parameters=ReplyParameters(message_id=this_mid))
 61    else:  # 该消息并非字幕消息, 首先尝试直接通过API下载
 62        res = await fetch_subtitle(url, reference)
 63        asr_engine = "tencent"
 64        if matched["platform"] == "youtube":  # bypass censorship
 65            asr_engine = "uncensored"
 66        if subtitles := glom(res, Coalesce("full", "subtitles"), default=""):  # API成功获取字幕
 67            status = await client.send_message(this_cid, "✅通过API成功获取字幕", reply_parameters=ReplyParameters(message_id=this_mid))
 68        elif info["mtype"] == "audio":  # 直接下载音频后ASR
 69            logger.warning(f"API下载字幕失败, 直接下载音频后ASR获取字幕: [{vinfo['title']}]({url})")
 70            status = await client.send_message(this_cid, f"⚠️API下载字幕失败, 直接下载音频后ASR获取字幕\n📝{vinfo['title']}", reply_parameters=ReplyParameters(message_id=this_mid))
 71            fpath: str = await client.download_media(message)  # type: ignore
 72            prompt = f"请转录{matched['platform'].title()}视频作者【{vinfo['author']}】的一期节目的音频。\n该期节目标题: {vinfo['title']}\n节目简介: {description}"
 73            res.pop("error", None)
 74            res = await asr_file(fpath, prompt=prompt, engine=asr_engine, message=message, corrector_reference=reference, silent=True)
 75            if res.get("error") or len(res.get("texts", "")) == 0:
 76                await modify_progress(message=status, del_status=True)
 77                # await summary_videos(client, message, retry + 1)
 78                return
 79            subtitles = res.get("texts", "")
 80            res |= {"num_chars": count_subtitles(subtitles), "reading_minutes": count_subtitles(subtitles) / READING_SPEED}
 81        else:  # 失败了, 使用ytdlp
 82            logger.warning(f"API下载字幕失败, 通过yt-dlp下载音频后ASR获取字幕: [{vinfo['title']}]({url})")
 83            status = await client.send_message(this_cid, f"⚠️API下载字幕失败, 通过yt-dlp下载音频后ASR获取字幕\n📝{vinfo['title']}", reply_parameters=ReplyParameters(message_id=this_mid))
 84            downloaded = await ytdlp_download(url, matched["platform"], ytdlp_download_video=False)
 85            if not downloaded["audio_path"].is_file():
 86                await modify_progress(message=status, text="❌下载音频失败", force_update=True)
 87                return
 88            prompt = f"请转录{matched['platform'].title()}视频作者【{vinfo['author']}】的一期节目的音频。\n该期节目标题: {vinfo['title']}\n节目简介: {description}"
 89            res.pop("error", None)
 90            res = await asr_file(downloaded["audio_path"], prompt=prompt, engine=asr_engine, message=message, corrector_reference=reference, silent=True)
 91            if res.get("error") or len(res.get("texts", "")) == 0:
 92                await modify_progress(message=status, del_status=True)
 93                # await summary_videos(client, message, retry + 1)
 94                return
 95            subtitles = res.get("texts", "")
 96            res |= {"num_chars": count_subtitles(res["texts"]), "reading_minutes": count_subtitles(subtitles) / READING_SPEED}
 97
 98        if count_subtitles(subtitles) < 30:
 99            await modify_progress(message=status, del_status=True)
100            return
101
102        caption = f"{vinfo['emoji']}[{vinfo['author']}]({vinfo['channel']})\n🕒{vinfo['pubdate']}\n"
103        caption += f"📝[{vinfo['title']}]({url})\n#️⃣字符数: {res['num_chars']}\n⏳阅读时长: {readable_time(60 * res['reading_minutes'])}"
104        # html = "\n".join([f"<p>{s}</p>" for s in subtitles.split("\n")])
105        # if telegraph_url := await publish_telegraph(title=vinfo["title"], html=html, author=vinfo["author"], url=url, ttl="30d"):
106        #     caption += f"\n⚡️[即时预览]({telegraph_url})"
107        with io.BytesIO(subtitles.encode("utf-8")) as f:
108            subtitle_msg = await client.send_document(info["cid"], f, file_name=f"{vinfo['title']}.txt", caption=caption, reply_parameters=ReplyParameters(message_id=info["mid"]))
109        if not isinstance(subtitle_msg, Message):
110            subtitle_msg = message
111
112    subtitles = re.sub(r"(.*?)AI总结(B站版):", "", subtitles, flags=re.DOTALL).strip()  # noqa: RUF001
113    prompt = f"该转录稿对应于{matched['platform'].title()}视频作者【{vinfo['author']}】的一期节目,节目详情如下:\n"
114    prompt += f"节目标题: {vinfo['title']}\n发布日期: {vinfo['pubdate']}\n"
115    if description.strip():
116        prompt += f"节目简介: {description}"
117    await modify_progress(message=status, text="🤖字幕总结中...", force_update=True)
118    response = await summarize(sources=[{"type": "system_prompt", "text": prompt}, {"type": "transcripts", "text": subtitles}], model="general", min_text_length=200)
119    if texts := response.get("texts"):
120        await send2tg(client, subtitle_msg, texts=f"{response['prefix']}{blockquote(texts)}")
121    await modify_progress(message=status, del_status=True)
122
123
async def get_last_messages(client: Client, chat_id: int, start_id: int = 1, window_size: int = 10) -> list[Message]:
    """Find the most recent messages of a chat by probing message ids.

    The search starts from the cached last-known id (or ``start_id``) and
    proceeds in two phases: the upper bound is doubled until a window of ids
    contains no messages, then a binary search narrows down the last
    non-empty id. The id found is stored in the cache under
    ``lastmid-{chat_id}`` to speed up the next call.

    NOTE(review): a run of missing ids longer than ``window_size`` looks like
    the end of the chat to this search — confirm this is acceptable.

    Args:
        client: Pyrogram client used to fetch messages.
        chat_id: Chat to inspect.
        start_id: Lowest message id to start probing from.
        window_size: Number of consecutive ids fetched per probe (capped at 200).

    Returns:
        The non-empty messages among the last ``window_size + 1`` ids ending
        at the detected last id (at most 200 ids are requested), or an empty
        list when none of them exists.
    """
    max_batch = 200  # upper bound on the number of ids sent in one get_messages call
    last_mid = cache.get(f"lastmid-{chat_id}", start_id)
    low = float("inf")  # stays infinite until a non-empty message has been seen
    high = max(start_id, int(last_mid))
    window_size = min(window_size, max_batch)

    # Phase 1: grow `high` exponentially until the window starting there is empty.
    while True:
        logger.trace(f"Retrieval message of {chat_id=}, mid=[{high}, {high + window_size}]")
        messages: list[Message] = await client.get_messages(chat_id, message_ids=range(high, high + window_size))  # type: ignore
        existing_ids = [m.id for m in messages if not m.empty]
        if not existing_ids:
            break
        low = existing_ids[-1]
        high = low * 2

    # Phase 2: binary search between the last seen id and the empty upper bound.
    # Skipped entirely when phase 1 found nothing (low is still infinite).
    while low <= high:
        mid = (low + high) // 2
        logger.trace(f"二分查找, chat: {chat_id}: {low=}, {mid=}, {high=}")
        messages = await client.get_messages(chat_id, message_ids=range(mid, mid + window_size))  # type: ignore
        existing_ids = [m.id for m in messages if not m.empty]
        if existing_ids:
            low = existing_ids[-1] + 1
        else:
            high = mid - 1
    last_id = min(low, high)

    # Never request ids below 1, and keep the request within the same 200-id
    # cap that is applied to window_size above.
    first_id = max(1, last_id - window_size, last_id - max_batch + 1)
    messages = await client.get_messages(chat_id, message_ids=range(first_id, last_id + 1))  # type: ignore
    valid_msgs = [m for m in messages if not m.empty]
    if not valid_msgs:
        # Nothing found (e.g. empty chat or deleted messages): leave the cache untouched.
        return []
    cache.set(f"lastmid-{chat_id}", valid_msgs[-1].id, ttl=0)
    return valid_msgs