Commit 012fa1d

benny-dou <60535774+benny-dou@users.noreply.github.com>
2025-01-30 12:50:56
feat(ytdlp): support all ytdlp links
1 parent 8da8a97
src/preview/ytdlp.py
@@ -2,12 +2,13 @@
 # -*- coding: utf-8 -*-
 import asyncio
 import json
+import os
 import re
 import threading
 import time
 import warnings
 from pathlib import Path
-from urllib.parse import quote_plus, unquote_plus
+from urllib.parse import quote_plus, unquote_plus, urlparse
 
 from bs4 import BeautifulSoup, MarkupResemblesLocatorWarning
 from loguru import logger
@@ -61,7 +62,7 @@ async def preview_ytdlp(
         youtube_comments_provider (str, optional): The youtube comments extractor: "free" or "false".
         proxy (str, optional): Proxy to use. Defaults to None.
     """
-    logger.trace(f"{url=} {proxy=} {kwargs=}")
+    logger.trace(f"{url=} {kwargs=}")
     if kwargs.get("show_progress") and "progress" not in kwargs:
         res = await send2tg(client, message, texts=f"🔗正在解析{platform}链接\n{url}", **kwargs)
         kwargs["progress"] = res[0]
@@ -77,7 +78,8 @@ async def preview_ytdlp(
         ytdlp_send_video = False
     if not ytdlp_send_video:
         ytdlp_send_audio = True
-
+    if proxy is None:
+        proxy = get_ytdlp_proxy(url)
     ydl_opts = {
         "paths": {"home": DOWNLOAD_DIR},
         "cachedir": DOWNLOAD_DIR,
@@ -221,12 +223,17 @@ async def preview_ytdlp(
     cleanup_ytdlp(info["id"])
 
 
-def get_ytdlp_proxy(platform: str) -> str | None:
-    if platform == "bilibili":
-        proxy = PROXY.BILIBILI
-    elif platform == "youtube":
-        proxy = PROXY.YOUTUBE
+def get_ytdlp_proxy(url: str = "", platform: str = "") -> str | None:
+    if platform:
+        proxy = os.getenv(f"YTDLP_PROXY_{platform}".upper())
     else:
+        parsed = urlparse(url)
+        host = parsed.netloc  # www.youtube.com
+        platform = host.split(".")[-2]  # youtube
+        proxy = os.getenv(f"YTDLP_PROXY_{platform}".upper())
+    if proxy is None:  # fallback to default proxy is unset
+        proxy = PROXY.YTDLP
+    if proxy == "":  # empty string means no proxy
         proxy = None
     logger.debug(f"YTDLP Proxy of {platform}: {proxy}")
     return proxy
@@ -504,7 +511,7 @@ async def get_youtube_comments(vid: str | None, provider: str = PROVIDER.YOUTUBE
     params = {"key": TOKEN.YOUTUBE_API_KEY, "maxResults": 100, "textFormat": "plainText", "part": "snippet", "videoId": vid}
     comments = []
     try:
-        resp = await hx_req(api, proxy=PROXY.YOUTUBE, params=params, check_has_kv=["items"])
+        resp = await hx_req(api, proxy=get_ytdlp_proxy(platform="youtube"), params=params, check_has_kv=["items"])
         if resp.status_code != 200:
             logger.warning(f"YouTube Comments API failed: {resp}")
             return []
src/config.py
@@ -27,7 +27,6 @@ class ENABLE:
     AI_SUMMARY = os.getenv("ENABLE_AI_SUMMARY", "1").lower() in ["1", "y", "yes", "t", "true", "on"]
     ASR = os.getenv("ENABLE_ASR", "1").lower() in ["1", "y", "yes", "t", "true", "on"]
     AUDIO = os.getenv("ENABLE_AUDIO", "1").lower() in ["1", "y", "yes", "t", "true", "on"]
-    BILIBILI = os.getenv("ENABLE_BILIBILI", "1").lower() in ["1", "y", "yes", "t", "true", "on"]
     COMBINATION = os.getenv("ENABLE_COMBINATION", "1").lower() in ["1", "y", "yes", "t", "true", "on"]
     CRONTAB = os.getenv("ENABLE_CRONTAB", "1").lower() in ["1", "y", "yes", "t", "true", "on"]
     DOUYIN = os.getenv("ENABLE_DOUYIN", "1").lower() in ["1", "y", "yes", "t", "true", "on"]
@@ -40,7 +39,7 @@ class ENABLE:
     WEIBO = os.getenv("ENABLE_WEIBO", "1").lower() in ["1", "y", "yes", "t", "true", "on"]
     WGET = os.getenv("ENABLE_WGET", "1").lower() in ["1", "y", "yes", "t", "true", "on"]
     XHS = os.getenv("ENABLE_XHS", "1").lower() in ["1", "y", "yes", "t", "true", "on"]
-    YOUTUBE = os.getenv("ENABLE_YOUTUBE", "1").lower() in ["1", "y", "yes", "t", "true", "on"]
+    YTDLP = os.getenv("ENABLE_YTDLP", "1").lower() in ["1", "y", "yes", "t", "true", "on"]
     RAW_IMG_CONVERT = os.getenv("ENABLE_RAW_IMG_CONVERT", "1").lower() in ["1", "y", "yes", "t", "true", "on"]
     GROUPS = os.getenv("ENABLE_GROUPS", "1").lower() in ["1", "y", "yes", "t", "true", "on"]
     CHANNELS = os.getenv("ENABLE_CHANNELS", "1").lower() in ["1", "y", "yes", "t", "true", "on"]
@@ -109,9 +108,8 @@ class PROXY:  # format: socks5://127.0.0.1:7890
     DOWNLOAD = os.getenv("DOWNLOAD_PROXY", None)
     WEIBO_COOKIE = os.getenv("WEIBO_COOKIE_PROXY", None)  # Weibo visitor cookie
     YTDLP = os.getenv("YTDLP_PROXY", None)  # general proxy for ytdlp
-    YTDLP_FALLBACK = os.getenv("YTDLP_FALLBACK_PROXY", None)
-    BILIBILI = os.getenv("BILIBILI_PROXY", None)
-    YOUTUBE = os.getenv("YOUTUBE_PROXY", None)
+    YTDLP_FALLBACK = os.getenv("YTDLP_PROXY_FALLBACK", None)  # fallback proxy for ytdlp
+    # for ytdlp proxy of specific sites (Like Bilibili), use this format: YTDLP_PROXY_BILIBILI
 
 
 class COOKIE:  # See: https://github.com/easychen/CookieCloud
src/handler.py
@@ -16,7 +16,7 @@ from llm.summary import ai_summary
 from messages.parser import parse_msg
 from messages.sender import send2tg
 from messages.utils import equal_prefix, startswith_prefix
-from networking import flatten_rediercts, match_social_media_link
+from networking import match_social_media_link
 from others.combine_history import combine_history
 from others.download_external import download_url_in_message
 from others.extract_audio import extract_audio_file
@@ -27,7 +27,7 @@ from preview.instagram import preview_instagram
 from preview.twitter import preview_twitter
 from preview.weibo import preview_weibo
 from preview.xiaohongshu import preview_xhs
-from preview.ytdlp import ProxyError, get_ytdlp_proxy, preview_ytdlp
+from preview.ytdlp import ProxyError, preview_ytdlp
 from utils import to_int, true
 
 
@@ -110,8 +110,7 @@ async def handle_social_media(
     twitter: bool = True,
     weibo: bool = True,
     xhs: bool = True,
-    bilibili: bool = True,
-    youtube: bool = True,
+    ytdlp: bool = True,
     show_progress: bool = True,
     detail_progress: bool = False,
     **kwargs,
@@ -168,8 +167,7 @@ async def handle_social_media(
         # Caution: this format should be consistent with `save_messages` function in `message.database.py`
         kwargs["send_from_user"] = f"👤[@{info['full_name']}](tg://user?id={info['uid']})//"
     try:
-        texts = await flatten_rediercts(info["text"])
-        matched = await match_social_media_link(texts)  # match "platform" and "url" (and other info)
+        matched = await match_social_media_link(info["text"], flatten_first=True)  # match "platform" and "url" (and other info)
         kwargs |= matched
         if startswith_prefix(this_texts, prefix=["/retry"], ignore_prefix=ignore_prefix):
             await del_db(matched["db_key"])
@@ -186,10 +184,8 @@ async def handle_social_media(
         if xhs and matched["platform"] == "xiaohongshu" and ENABLE.XHS:
             await preview_xhs(client, message, **kwargs)
         try:
-            if bilibili and matched["platform"] == "bilibili" and ENABLE.BILIBILI:
-                await preview_ytdlp(client, message, proxy=get_ytdlp_proxy("bilibili"), **kwargs)
-            if youtube and matched["platform"] == "youtube" and ENABLE.YOUTUBE:
-                await preview_ytdlp(client, message, proxy=get_ytdlp_proxy("youtube"), **kwargs)
+            if ytdlp and matched["platform"] == "ytdlp" and ENABLE.YTDLP:
+                await preview_ytdlp(client, message, **kwargs)
         except ProxyError:
             logger.error(f"🚫{matched['platform']}代理错误")
             if PROXY.YTDLP_FALLBACK:
@@ -250,10 +246,6 @@ def get_social_media_help(cmd_prefix: list[str] | None = None, ignore_prefix: li
     if prefixes:
         msg += f" 前缀: {', '.join(prefixes)}"
         msg += "\n🔄使用 `/retry` 回复消息强制重试"
-    if ENABLE.YOUTUBE:
-        msg += "\n🔴油管"
-    if ENABLE.BILIBILI:
-        msg += "\n🅱️哔哩哔哩"
     if ENABLE.TWITTER:
         msg += "\n🕊推特"
     if ENABLE.WEIBO:
@@ -266,6 +258,10 @@ def get_social_media_help(cmd_prefix: list[str] | None = None, ignore_prefix: li
         msg += "\n🎶TikTok"
     if ENABLE.INSTAGRAM:
         msg += "\n🏞Instagram"
+    if ENABLE.YTDLP:
+        msg += "\n🔴油管"
+        msg += "\n🅱️哔哩哔哩"
+        msg += "\n🆕和所有yt-dlp支持的链接\n"
     if ENABLE.ASR:
         msg += f"\n🗣**语音转文字**: `{PREFIX.ASR}` 回复语音消息"
     if ENABLE.AUDIO:
src/networking.py
@@ -17,7 +17,7 @@ from loguru import logger
 from config import DOWNLOAD_DIR, PROXY, UA, cache, semaphore
 from messages.progress import modify_progress
 from messages.utils import summay_media
-from utils import bare_url, https_url, readable_size
+from utils import bare_url, https_url, is_supported_by_ytdlp, match_urls, readable_size
 
 # ruff: noqa: RUF001
 MOBILE_HEADERS = {
@@ -313,6 +313,7 @@ async def match_social_media_link(text: str, *, flatten_first: bool = False) ->
                 db_key: The key to store in the cache.
     #! TODO: Handle multiple links in one message.
     """
+    text = str(text)
     if flatten_first:
         text = await flatten_rediercts(text)
     matched_info = {"platform": ""}
@@ -380,24 +381,32 @@ async def match_social_media_link(text: str, *, flatten_first: bool = False) ->
         matched_info = {"url": f"https://www.xiaohongshu.com/explore/{post_id}?xsec_token={xsec}", "db_key": f"www.xiaohongshu.com/explore/{post_id}", "xsec": xsec, "platform": "xiaohongshu"}
 
     # https://www.bilibili.com/video/BV1TC411J7PK
-    if matched := re.search(r"(https?://)?(:?m\.|www\.)?bilibili\.com/video/([^,,.。\s]+)", str(text)):
+    if matched := re.search(r"(https?://)?(:?m\.|www\.)?bilibili\.com/video/([^,,.。\s]+)", text):
         base_url = matched.group(0).split("?")[0]
         bvid = Path(base_url).stem
         queries = parse_qs(urlparse(matched.group(0)).query)
         pid = queries.get("p", ["1"])[0]
         url = f"https://www.bilibili.com/video/{bvid}?p={pid}".removesuffix("?p=1")
-        matched_info = {"url": url, "db_key": bare_url(url), "bvid": bvid, "pid": pid, "platform": "bilibili"}
+        matched_info = {"url": url, "db_key": bare_url(url), "bvid": bvid, "pid": pid, "platform": "ytdlp"}
 
     # https://www.youtube.com/watch?v=D6aE2E0RHTc
-    if matched := re.search(r"(https?://)?(:?m\.|www\.)?youtube\.com/watch([^,,.。\s]+)", str(text)):
+    if matched := re.search(r"(https?://)?(:?m\.|www\.)?youtube\.com/watch([^,,.。\s]+)", text):
         queries = parse_qs(urlparse(matched.group(0)).query)
         if vid := queries.get("v", [""])[0]:
-            matched_info = {"url": f"https://www.youtube.com/watch?v={vid}", "db_key": f"www.youtube.com/watch?v={vid}", "vid": vid, "platform": "youtube"}
+            matched_info = {"url": f"https://www.youtube.com/watch?v={vid}", "db_key": f"www.youtube.com/watch?v={vid}", "vid": vid, "platform": "ytdlp"}
     # https://youtube.com/shorts/lFKHbluAlJw
-    if matched := re.search(r"(https?://)?(:?m\.|www\.)?youtube\.com/shorts/([^,,.。?\s]+)", str(text)):
+    if matched := re.search(r"(https?://)?(:?m\.|www\.)?youtube\.com/shorts/([^,,.。?\s]+)", text):
         vid = matched.group(3)
-        matched_info = {"url": f"https://www.youtube.com/watch?v={vid}", "db_key": f"www.youtube.com/watch?v={vid}", "vid": vid, "platform": "youtube"}
-
+        matched_info = {"url": f"https://www.youtube.com/watch?v={vid}", "db_key": f"www.youtube.com/watch?v={vid}", "vid": vid, "platform": "ytdlp"}
+
+    # if all above pre-defined patterns failed, try to match ytdlp link
+    if not matched_info["platform"] and (urls := match_urls(text)):
+        for url in urls:
+            if any(x in url.lower() for x in ["bilibili", "youtube"]):  # handled above
+                continue
+            if is_supported_by_ytdlp(url):
+                matched_info = {"url": url, "db_key": bare_url(url), "platform": "ytdlp"}
+                break
     if matched_info["platform"]:
         logger.success(f"Matched: {matched_info}")
     return matched_info
src/utils.py
@@ -3,6 +3,7 @@
 
 import json
 import random
+import re
 import string
 from datetime import UTC, datetime
 from pathlib import Path
@@ -12,9 +13,12 @@ from zoneinfo import ZoneInfo
 from bs4 import PageElement
 from loguru import logger
 from pyrogram.client import Client
+from yt_dlp.extractor import gen_extractors
 
 from config import DOWNLOAD_DIR, TEXT_LENGTH, TZ, cache
 
+# ruff: noqa: RUF001
+
 
 def nowdt(tz: str = "UTC") -> datetime:
     return datetime.now(ZoneInfo(tz))
@@ -218,6 +222,20 @@ async def i_am_bot(client: Client) -> bool:
     return me.is_bot
 
 
+def match_urls(text: str) -> list[str]:
+    """Match all urls in a text."""
+    res = re.findall(
+        r'(?i)\b((?:[a-z][\w-]+:(?:/{1,3}|[a-z0-9%])|www\d{0,3}[.]|[a-z0-9.\-]+[.][a-z]{2,4}/)(?:[^\s()<>]+|\(([^\s()<>]+|(\([^\s()<>]+\)))*\))+(?:\(([^\s()<>]+|(\([^\s()<>]+\)))*\)|[^\s`!()\[\]{};:\'".,<>?«»“”‘’]))',
+        str(text),
+    )
+    return [https_url(x[0]) for x in res]
+
+
+def is_supported_by_ytdlp(url: str) -> bool:
+    extractors = gen_extractors()
+    return any(extractor.suitable(url) for extractor in extractors)
+
+
 def unicode_to_ascii(text: str | float) -> str:
     if not text:
         return ""
@@ -257,3 +275,4 @@ if __name__ == "__main__":
     print(unicode_to_ascii("test"))
     print(ascii_to_unicode("1.1"))
     print(ascii_to_unicode("test"))
+    print(match_urls("http://a.com/BmT8gZ 匹配不到就删除了https://b.com/MxRdMO"))