main
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3from datetime import timedelta
4
5from loguru import logger
6from pyrogram.client import Client
7from pyrogram.types import Message
8from youtube_transcript_api import IpBlocked, RequestBlocked, YouTubeTranscriptApi
9from youtube_transcript_api.proxies import GenericProxyConfig
10
11from asr.corrector import asr_corrector
12from config import PREFIX, PROXY, READING_SPEED, cache
13from messages.parser import parse_msg
14from messages.utils import startswith_prefix
15from networking import match_social_media_link
16from preview.bilibili import bilibili_subtitle_and_summary
17from utils import seconds_to_time
18
19
async def _first_supported_url(info: dict) -> str | None:
    """Return the first YouTube/Bilibili URL found in a parsed message, or None."""
    # The message text is tried first, then every URL carried by its entities.
    matched = await match_social_media_link(info["text"])
    if matched["platform"] in ["youtube", "bilibili"]:
        return matched["url"]
    for entity_url in info["entity_urls"]:
        matched = await match_social_media_link(entity_url)
        if matched["platform"] in ["youtube", "bilibili"]:
            return matched["url"]
    return None


async def match_url(client: Client, message: Message) -> str:
    """Find a YouTube or Bilibili url for a subtitle command.

    The command message itself is searched first (text, then entity urls).
    If nothing is found there and the command replies to another message,
    the replied-to message is searched the same way; when that message is
    part of a media group, every message of the group is searched.

    Args:
        client: Client used to fetch the media group of the replied-to message.
        message: The incoming message; it must start with `PREFIX.SUBTITLE`.

    Returns:
        str: The matched url, or "" when the message is not a subtitle
        command or no supported link is found.
    """
    info = parse_msg(message, silent=True)
    if not startswith_prefix(info["text"], prefix=[PREFIX.SUBTITLE]):
        return ""

    # /subtitle "link"
    url = await _first_supported_url(info)
    if url is not None:
        return url

    # is replying to message?
    reply_message = message.reply_to_message
    if not reply_message:
        return ""

    # If the reply target belongs to a media group, fetch all messages in the group.
    # The check and the lookup must use the replied-to message, not the command
    # message: the command itself is not the media group being replied to.
    if reply_message.media_group_id:
        reply_messages = await client.get_media_group(message.chat.id, reply_message.id)
    else:
        reply_messages = [reply_message]

    for msg in reply_messages:
        url = await _first_supported_url(parse_msg(msg, silent=True))
        if url is not None:
            return url
    return ""
50
51
@cache.memoize(ttl=120)
async def fetch_subtitle(url: str, reference: str = "") -> dict:
    """Fetch subtitles from Bilibili or YouTube.

    Bilibili links are delegated to `bilibili_subtitle_and_summary`. Any other
    link is fetched through `YouTubeTranscriptApi` using the matched video id,
    optionally via the proxy configured in `PROXY.SUBTITLE`. Non-empty
    subtitles are passed through `asr_corrector` before being returned.

    Args:
        url: Video link to fetch subtitles for.
        reference: Forwarded unchanged as the second argument of `asr_corrector`.

    Returns:
        dict: On success {
            "subtitles": "[timestamp] texts",
            "num_chars": len(texts),
            "reading_minutes": 2,
        }
        When no embedded subtitles could be downloaded from YouTube,
        {"error": "..."} is returned instead.
    """
    import asyncio  # local import: only this function needs it

    matched = await match_social_media_link(url)
    if matched["platform"] == "bilibili":
        resp = await bilibili_subtitle_and_summary(url)
        if resp.get("subtitles"):
            resp["subtitles"] = await asr_corrector(resp["subtitles"], reference)
        return resp

    video_id = matched["vid"]
    subtitles: list[dict] = []

    def download_raw_subtitles() -> list[dict]:
        """Blocking download of the transcript as a list of raw dict entries."""
        proxy = GenericProxyConfig(http_url=PROXY.SUBTITLE, https_url=PROXY.SUBTITLE) if PROXY.SUBTITLE else None
        ytt_api = YouTubeTranscriptApi(proxy_config=proxy)
        fetched = ytt_api.fetch(video_id, languages=["zh-CN", "zh-Hans", "zh", "zh-HK", "zh-TW", "zh-Hant", "en"])
        return fetched.to_raw_data()

    try:
        logger.info(f"Fetch Subtitle via YouTubeTranscriptApi for {video_id=}, proxy={PROXY.SUBTITLE}")
        # The transcript API is synchronous; run it in a worker thread so the
        # HTTP round-trip does not stall the event loop.
        subtitles = await asyncio.to_thread(download_raw_subtitles)
    except (IpBlocked, RequestBlocked):
        logger.warning(f"Subtitle API IP blocked: {video_id=}")
    except Exception as e:
        # Best effort: any other failure falls through to the error result below.
        logger.error(f"Failed to get subtitle: {e}")

    if not subtitles:
        return {"error": "❌下载内嵌字幕失败\n🔄尝试使用语音转文字获取字幕"}

    resp = to_transcription(subtitles)
    if resp.get("subtitles"):
        resp["subtitles"] = await asr_corrector(resp["subtitles"], reference)
    return resp
88
89
def to_transcription(subtitles: list[dict]) -> dict:
    """Render subtitle entries as one "[timestamp] text" line per entry.

    Each entry is a dict with at least "text" and "start" (seconds), e.g.
        [
            {'text': 'hello', 'start': 0.056, 'duration': 2.88},
            {'text': 'world!', 'start': 2.983, 'duration': 3.244},
        ]
    The timestamp is whatever `seconds_to_time` produces for "start".

    Returns:
        dict: {} for empty input, otherwise {
            "subtitles": the lines joined with newlines,
            "num_chars": total length of all "text" values,
            "reading_minutes": num_chars / READING_SPEED,
        }
    """
    if not subtitles:
        return {}

    # Single pass: pair every formatted start time with its text.
    stamped = [(seconds_to_time(entry["start"]), entry["text"]) for entry in subtitles]
    lines = [f"[{stamp}] {text}" for stamp, text in stamped]
    total_chars = sum(len(text) for _, text in stamped)

    return {
        "subtitles": "\n".join(lines),
        "num_chars": total_chars,
        "reading_minutes": total_chars / READING_SPEED,
    }
120
121
def to_webvtt(subtitles: list[dict]) -> dict:
    """(Deprecated, use `to_transcription`) Converts subtitles to WebVTT format.

    sample subtitles = [
        {'text': 'hello', 'start': 0.056, 'duration': 2.88},
        {'text': 'world!', 'start': 2.983, 'duration': 3.244},
    ]

    "start" and "duration" may be numbers or numeric strings; a missing
    "text" is treated as an empty cue.

    Returns:
        dict: {} for empty input, otherwise {
            "subtitles": "strings of subtitles in WebVTT format",
            "num_chars": 11,
            "reading_minutes": num_chars / READING_SPEED,
        }
        If conversion fails, {"error": "<exception message>"} is returned.
    """
    if not subtitles:
        return {}

    def format_timestamp(seconds: str | float) -> str:
        """Converts seconds to WebVTT timestamp format (hh:mm:ss.mmm)."""
        # Work in whole milliseconds. Taking `x % 1 * 1000` and truncating
        # loses a millisecond on values such as 1.001 because of binary
        # floating-point error; rounding the total avoids that.
        total_ms = round(float(seconds) * 1000)
        total_seconds, ms = divmod(total_ms, 1000)
        hours, remainder = divmod(total_seconds, 3600)
        minutes, secs = divmod(remainder, 60)
        return f"{hours:02}:{minutes:02}:{secs:02}.{ms:03}"

    try:
        vtt_output = ["WEBVTT", ""]  # WebVTT header
        num_chars = 0
        for subtitle in subtitles:
            # Convert once so numeric strings can be added as well.
            start = float(subtitle["start"])
            end = start + float(subtitle["duration"])
            text = subtitle.get("text", "")
            num_chars += len(text)
            vtt_output.append(f"{format_timestamp(start)} --> {format_timestamp(end)}")
            vtt_output.append(text)
            vtt_output.append("")  # Add blank line between subtitles
        reading_minutes = num_chars / READING_SPEED  # minutes
        return {"subtitles": "\n".join(vtt_output), "num_chars": num_chars, "reading_minutes": reading_minutes}
    except Exception as e:
        logger.error(f"Failed to convert subtitles to WebVTT: {e}")
        return {"error": str(e)}