main
  1#!/usr/bin/env python
  2# -*- coding: utf-8 -*-
import asyncio
from datetime import timedelta

from loguru import logger
from pyrogram.client import Client
from pyrogram.types import Message
from youtube_transcript_api import IpBlocked, RequestBlocked, YouTubeTranscriptApi
from youtube_transcript_api.proxies import GenericProxyConfig

from asr.corrector import asr_corrector
from config import PREFIX, PROXY, READING_SPEED, cache
from messages.parser import parse_msg
from messages.utils import startswith_prefix
from networking import match_social_media_link
from preview.bilibili import bilibili_subtitle_and_summary
from utils import seconds_to_time
 18
 19
 20async def match_url(client: Client, message: Message) -> str:
 21    """Find valid url from message."""
 22    info = parse_msg(message, silent=True)
 23    if not startswith_prefix(info["text"], prefix=[PREFIX.SUBTITLE]):
 24        return ""
 25    # /subtitle "link"
 26    matched = await match_social_media_link(info["text"])
 27    if matched["platform"] in ["youtube", "bilibili"]:
 28        return matched["url"]
 29    for entity_url in info["entity_urls"]:
 30        matched = await match_social_media_link(entity_url)
 31        if matched["platform"] in ["youtube", "bilibili"]:
 32            return matched["url"]
 33
 34    # is replying to message?
 35    if not message.reply_to_message:
 36        return ""
 37    reply_message = message.reply_to_message
 38    # if reply to a media_group, fetch all messages in the group
 39    reply_messages = await client.get_media_group(message.chat.id, message.id) if message.media_group_id else [reply_message]
 40    for msg in reply_messages:
 41        info = parse_msg(msg, silent=True)
 42        matched = await match_social_media_link(info["text"])
 43        if matched["platform"] in ["youtube", "bilibili"]:
 44            return matched["url"]
 45        for entity_url in info["entity_urls"]:
 46            matched = await match_social_media_link(entity_url)
 47            if matched["platform"] in ["youtube", "bilibili"]:
 48                return matched["url"]
 49    return ""
 50
 51
 52@cache.memoize(ttl=120)
 53async def fetch_subtitle(url: str, reference: str = "") -> dict:
 54    """Fetch subtitles from Bilibili or YouTube.
 55
 56    Returns:
 57        dict: {
 58            "subtitles": "[minute:second] texts",
 59            "num_chars": len(texts),
 60            "reading_minutes": 2,
 61            }
 62    """
 63    subtitles = []
 64    matched = await match_social_media_link(url)
 65    if matched["platform"] == "bilibili":
 66        resp = await bilibili_subtitle_and_summary(url)
 67        if resp.get("subtitles"):
 68            resp["subtitles"] = await asr_corrector(resp["subtitles"], reference)
 69        return resp
 70
 71    video_id = matched["vid"]
 72    try:
 73        proxy = GenericProxyConfig(http_url=PROXY.SUBTITLE, https_url=PROXY.SUBTITLE) if PROXY.SUBTITLE else None
 74        logger.info(f"Fetch Subtitle via YouTubeTranscriptApi for {video_id=}, proxy={PROXY.SUBTITLE}")
 75        ytt_api = YouTubeTranscriptApi(proxy_config=proxy)
 76        resp = ytt_api.fetch(video_id, languages=["zh-CN", "zh-Hans", "zh", "zh-HK", "zh-TW", "zh-Hant", "en"])
 77        subtitles: list[dict] = resp.to_raw_data()
 78    except (IpBlocked, RequestBlocked):
 79        logger.warning(f"Subtitle API IP blocked: {video_id=}")
 80    except Exception as e:
 81        logger.error(f"Failed to get subtitle: {e}")
 82    if not subtitles:
 83        return {"error": "❌下载内嵌字幕失败\n🔄尝试使用语音转文字获取字幕"}
 84    resp = to_transcription(subtitles)
 85    if resp.get("subtitles"):
 86        resp["subtitles"] = await asr_corrector(resp["subtitles"], reference)
 87    return resp
 88
 89
 90def to_transcription(subtitles: list[dict]) -> dict:
 91    """Converts subtitles to "[hh:mm:ss] transcription" format.
 92
 93    sample subtitles = [
 94        {'text': 'hello', 'start': 0.056, 'duration': 2.88},
 95        {'text': 'world!', 'start': 2.983, 'duration': 3.244},
 96    ]
 97
 98    Returns:
 99        dict: {
100            "subtitles": "[hh:mm:ss] texts",
101            "num_chars": len(texts),
102            "reading_minutes": 2,
103            }
104    """
105    if not subtitles:
106        return {}
107
108    sentences = []
109    num_chars = 0
110
111    for subtitle in subtitles:
112        seconds = subtitle["start"]
113        sentences.append(f"[{seconds_to_time(seconds)}] {subtitle['text']}")
114        num_chars += len(subtitle["text"])
115    return {
116        "subtitles": "\n".join(sentences),
117        "num_chars": num_chars,
118        "reading_minutes": num_chars / READING_SPEED,
119    }
120
121
def to_webvtt(subtitles: list[dict]) -> dict:
    """(Deprecated, use `to_transcription`) Converts subtitles to WebVTT format.

    sample subtitles = [
        {'text': 'hello', 'start': 0.056, 'duration': 2.88},
        {'text': 'world!', 'start': 2.983, 'duration': 3.244},
    ]

    "start" and "duration" may be numbers or numeric strings. A missing "text"
    is treated as an empty cue.

    Returns:
        dict: {
            "subtitles": "strings of subtitles in WebVTT format",
            "num_chars": 11,
            "reading_minutes": num_chars / READING_SPEED,
            }
        An empty dict is returned for empty input, and {"error": "<message>"}
        when the conversion fails.
    """
    if not subtitles:
        return {}

    def format_timestamp(seconds: str | float) -> str:
        """Converts seconds to WebVTT timestamp format (hh:mm:ss.mmm)."""
        # Work in whole milliseconds. Taking the fractional part of a float and
        # truncating it loses a millisecond to representation error
        # (e.g. 1.001 % 1 == 0.000999...), and int("2.983") raises for strings.
        total_ms = round(float(seconds) * 1000)
        total_seconds, ms = divmod(total_ms, 1000)
        hours, remainder = divmod(total_seconds, 3600)
        minutes, secs = divmod(remainder, 60)
        return f"{hours:02}:{minutes:02}:{secs:02}.{ms:03}"

    try:
        num_chars = 0
        vtt_output = ["WEBVTT", ""]  # WebVTT header
        for subtitle in subtitles:
            begin = float(subtitle["start"])
            finish = begin + float(subtitle["duration"])
            # Same lookup for the cue and the character count, so an entry
            # without "text" is handled the same way in both places.
            text = subtitle.get("text", "")
            num_chars += len(text)
            vtt_output.append(f"{format_timestamp(begin)} --> {format_timestamp(finish)}")
            vtt_output.append(text)
            vtt_output.append("")  # Add blank line between subtitles
        reading_minutes = num_chars / READING_SPEED  # minutes
        return {"subtitles": "\n".join(vtt_output), "num_chars": num_chars, "reading_minutes": reading_minutes}
    except Exception as e:
        # Conversion failures are reported to the caller as an error dict.
        logger.error(f"Failed to convert subtitles to WebVTT: {e}")
        return {"error": str(e)}