Commit a7e450b

benny-dou <60535774+benny-dou@users.noreply.github.com>
2025-05-15 13:35:21
feat(dl): fall back to `wget` if preview fails
1 parent d118a05
src/messages/database.py
@@ -83,7 +83,7 @@ async def save_messages(messages: list[Message | None], key: str, metadata: dict
     return False
 
 
-async def copy_messages_from_db(client: Client, message: Message, key: str, kv: dict | None = None, **kwargs) -> bool:
+async def copy_messages_from_db(client: Client, message: Message, key: str, kv: dict | None = None, **kwargs) -> list[Message]:
     """Copy messages from database.
 
     data format:
@@ -125,13 +125,13 @@ async def copy_messages_from_db(client: Client, message: Message, key: str, kv:
         kv = await get_db(key)
     if not kv.get("data"):
         logger.error(f"Wrong {DB.ENGINE} data for key={key}: {kv}")
-        return False
+        return []
     data: list[dict] = kv.get("data", [])
     if isinstance(data, str):
         data = json.loads(data)
     logger.debug(f"Sending {len(data)} messages from {DB.ENGINE}: {data}")
     await modify_progress(text=f"💾在{DB.ENGINE}中查到缓存, 正在转发{len(data)}条消息...", **kwargs)
-    results = []
+    results: list[Message] = []
     try:
         for idx, item in enumerate(sorted(data, key=custom_sort)):
             cid = to_int(item["cid"])
@@ -156,10 +156,10 @@ async def copy_messages_from_db(client: Client, message: Message, key: str, kv:
     except Exception as e:
         logger.error(f"Failed to copy messages for key={key} from {DB.ENGINE}: {e}")
         await del_db(key)
-        return False
+        return []
     if all(isinstance(x, Message) for x in results):
         logger.success(f"Successfully copied {len(results)} messages for key={key} from {DB.ENGINE}")
         await modify_progress(del_status=True, **kwargs)
-        return True
+        return results
     await del_db(key)
-    return False
+    return []
src/others/download_external.py
@@ -1,8 +1,5 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
-
-
-import re
 from pathlib import Path
 
 from loguru import logger
@@ -11,13 +8,14 @@ from pyrogram.types import Message
 
 from config import MAX_FILE_BYTES, PREFIX
 from database import guess_mime
+from llm.utils import convert_md
 from messages.parser import parse_msg
 from messages.progress import modify_progress
 from messages.sender import send2tg
 from messages.utils import equal_prefix, get_reply_to, startswith_prefix
 from multimedia import is_valid_video_or_audio, validate_img
 from networking import download_file
-from utils import https_url, readable_size, to_int
+from utils import find_url, publish_telegraph, readable_size, to_int
 
 HELP = f"""
 ⏬**下载文件**
@@ -27,16 +25,16 @@ HELP = f"""
 """
 
 
-async def download_url_in_message(client: Client, message: Message, **kwargs):
+async def download_url_in_message(client: Client, message: Message, extra_prefix: list[str] | None = None, **kwargs):
     """Download the url from the message."""
     info = parse_msg(message)
-    if not startswith_prefix(info["text"], prefix=[PREFIX.WGET]):
+    extra_prefix = extra_prefix or []
+    if not startswith_prefix(info["text"], prefix=[PREFIX.WGET, *extra_prefix]):
         return
     # send docs if message == "/wget", without reply
     if equal_prefix(message.text, prefix=[PREFIX.WGET]) and not message.reply_to_message:
         await send2tg(client, message, texts=HELP, **kwargs)
         return
-
     # reply a message with /wget
     if message.reply_to_message:
         message = message.reply_to_message
@@ -47,36 +45,46 @@ async def download_url_in_message(client: Client, message: Message, **kwargs):
     reply_msg_id = kwargs.get("reply_msg_id", 0)
     reply_parameters = get_reply_to(message.id, reply_msg_id)
 
-    regex = r"(?i)\b((?:https?://|www\d{0,3}[.]|[a-z0-9.\-]+[.][a-z]{2,4}/)(?:[^\s()<>]+|\(([^\s()<>]+|(\([^\s()<>]+\)))*\))+(?:\(([^\s()<>]+|(\([^\s()<>]+\)))*\)|[^\s`!()\[\]{};:'\".,<>?«»“”‘’]))"  # noqa: RUF001
-    if matched := re.findall(regex, info["text"]):
-        url = https_url(matched[0][0])
-        logger.debug(f"URL found from message text: {url}")
-
+    if not (url := find_url(info["text"])):
+        await message.reply_text("❌未找到URL")
+        return
+    caption = f"🔗[原始链接]({url})"
     msg = f"⏬开始下载:\n{url}"
-    if kwargs.get("show_progress"):
+    if kwargs.get("show_progress") and "progress" not in kwargs:
         res = await send2tg(client, message, texts=msg, **kwargs)
         kwargs["progress"] = res[0]
+    else:
+        await modify_progress(text=msg, force_update=True, **kwargs)
     success = False
     try:
         path = await download_file(url, workers_proxy=True, **kwargs)
-        suffix = Path(path).suffix
-        if mime := guess_mime(path):
+        path = Path(path)
+        suffix = path.suffix
+        if (mime := guess_mime(path)) and mime != "text/plain":
             suffix = "." + mime.split("/")[-1]
-        if Path(path).suffix != suffix:
-            Path(path).rename(Path(path).with_suffix(suffix))
-            path = Path(path).with_suffix(suffix)
+        if path.suffix != suffix:
+            path.rename(path.with_suffix(suffix))
+            path = path.with_suffix(suffix)
         if img := validate_img(path, force_jpg=False, delete=False):
             await modify_progress(text=f"🏞图片下载成功: {readable_size(path=img)}", force_update=True, **kwargs)
-            success = await send2tg(client, message, target_chat, reply_msg_id, texts=url, media=[{"photo": img}])
-        elif Path(path).suffix in [".m4a", ".mp3", ".wav", ".ogg", ".opus", ".flac", ".aac"]:
+            success = await send2tg(client, message, target_chat, reply_msg_id, texts=caption, media=[{"photo": img}])
+        elif path.suffix in [".m4a", ".mp3", ".wav", ".ogg", ".opus", ".flac", ".aac"]:
             await modify_progress(text=f"🎧音频下载成功: {readable_size(path=path)}", force_update=True, **kwargs)
-            success = await client.send_audio(target_chat, Path(path).as_posix(), caption=url, reply_parameters=reply_parameters)
+            success = await client.send_audio(target_chat, path.as_posix(), caption=caption, reply_parameters=reply_parameters)
         elif is_valid_video_or_audio(path, delete=False):
             await modify_progress(text=f"🎬视频下载成功: {readable_size(path=path)}", force_update=True, **kwargs)
-            success = await send2tg(client, message, target_chat, reply_msg_id, texts=url, media=[{"video": path}])
-        elif Path(path).stat().st_size < MAX_FILE_BYTES:
+            success = await send2tg(client, message, target_chat, reply_msg_id, texts=caption, media=[{"video": path}])
+        elif path.stat().st_size < MAX_FILE_BYTES:
             await modify_progress(text=f"💾文件下载成功: {readable_size(path=path)}", force_update=True, **kwargs)
-            success = await client.send_document(target_chat, Path(path).as_posix(), caption=url, reply_parameters=reply_parameters)
+            if suffix == ".html":
+                markdown = convert_md(path)
+                markdown_path = path.with_suffix(".md")
+                markdown_path.write_text(markdown)
+                if telegraph_url := await publish_telegraph(title="全文内容", texts=markdown, author=info["full_name"], url=url):
+                    caption += f"\n⚡️[Telegraph即时预览]({telegraph_url})"
+                success = await client.send_document(target_chat, markdown_path.as_posix(), caption=caption, reply_parameters=reply_parameters)
+            else:
+                success = await client.send_document(target_chat, path.as_posix(), caption=caption, reply_parameters=reply_parameters)
         else:
             await modify_progress(text=f"❌文件大小: {readable_size(path=path)} 超出限制\nTelegram只允许上传小于{round(MAX_FILE_BYTES / 1024 / 1024)}MB的文件", force_update=True, **kwargs)
     except Exception as e:
src/preview/ytdlp.py
@@ -70,7 +70,7 @@ async def preview_ytdlp(
     transcription_force_file: bool = False,
     to_telegraph: bool = True,
     **kwargs,
-):
+) -> list[Message]:
     """Preview ytdlp link in the message.
 
     Args:
@@ -98,8 +98,8 @@ async def preview_ytdlp(
     db_key = url
     if use_db and (kv := await get_db(db_key)):
         logger.debug(f"YT-DLP preview {DB.ENGINE} cache hit for key={db_key}")
-        if await copy_messages_from_db(client, message, key=db_key, kv=kv, **kwargs):
-            return
+        if db_msgs := await copy_messages_from_db(client, message, key=db_key, kv=kv, **kwargs):
+            return db_msgs
         await modify_progress(text=f"❌从{DB.ENGINE}缓存中转发失败, 尝试重新解析...", **kwargs)
 
     # set download & upload options
@@ -143,12 +143,12 @@ async def preview_ytdlp(
             await modify_progress(del_status=True, **kwargs)
             raise ProxyError(ytdlp_error)
         await modify_progress(text=ytdlp_error, force_update=True, **kwargs)
-        return
+        return []
     await modify_progress(text=f"⏬正在下载:\n{info['summary']}", force_update=True, **kwargs)
     ytdlp_error = await download_video_async(json_file, ydl_opts)
     if ytdlp_error:
         await modify_progress(text=ytdlp_error, force_update=True, **kwargs)
-        return
+        return []
     video_path = info.get("video_path", Path(""))
     audio_path = info.get("audio_path", Path(""))
     # only save messages when both video and audio are uploaded
@@ -202,7 +202,7 @@ async def preview_ytdlp(
         if await count_without_entities(f"{texts}{comment}") < CAPTION_LENGTH:
             texts += comment
     texts = texts.strip()
-    sent_messages: list[Message | None] = []  # 把发送的消息都记录下来
+    sent_messages = []  # 把发送的消息都记录下来
     target_chat = kwargs["target_chat"] if kwargs.get("target_chat") else message.chat.id
     target_chat = to_int(target_chat)
     reply_msg_id = kwargs.get("reply_msg_id", 0)
@@ -280,6 +280,7 @@ async def preview_ytdlp(
 
     Path(json_file).unlink(missing_ok=True)
     cleanup_ytdlp(info["id"])
+    return sent_messages
 
 
 def get_ytdlp_proxy(url: str = "", platform: str = "") -> str | None:
src/handler.py
@@ -181,11 +181,12 @@ async def handle_social_media(
         ]
     )
     info = parse_msg(message)
+    this_msg = message
     this_texts = info["text"]  # texts of the trigger message
     if startswith_prefix(this_texts, prefix=ignore_prefix):
-        return
+        return None
     if need_prefix and not startswith_prefix(this_texts, prefix=[*cmd_prefix, "/retry"]):
-        return
+        return None
     kwargs |= params_from_msg_text(this_texts)  # merge the parameters from the message text
     if true(kwargs.get("target_chat")):
         kwargs["target_chat"] = to_int(kwargs["target_chat"])
@@ -194,16 +195,14 @@ async def handle_social_media(
         # without reply, send docs if message only contains prefix command
         if not message.reply_to_message:
             help_msg = get_social_media_help(info["cid"], info["ctype"], cmd_prefix)
-            await send2tg(client, message, texts=help_msg, **kwargs)
-            return
+            return await send2tg(client, message, texts=help_msg, **kwargs)
         # with reply, treat the reply_msg as the trigger to preview social media link
         message = message.reply_to_message
         info = parse_msg(message, silent=True)  # parse again
 
     warn_msg = None
     if not need_prefix and startswith_prefix(this_texts, prefix=cmd_prefix, ignore_prefix=ignore_prefix):
-        warn_msg = await send2tg(client, message, texts="⚠️本会话中可直接发送链接, 无需添加命令前缀\n⚠️No need to add command prefix in this chat.", **kwargs)
-        warn_msg = warn_msg[0]
+        warn_msg = await client.send_message(info["cid"], text="⚠️本会话中可直接发送链接, 无需添加命令前缀\n⚠️No need to add command prefix in this chat.")
 
     # add send_from_user.
     if prepend_sender_user:
@@ -217,33 +216,38 @@ async def handle_social_media(
         if startswith_prefix(this_texts, prefix=["/retry"], ignore_prefix=ignore_prefix):
             await del_db(matched["db_key"])
         if douyin and matched["platform"] == "douyin":
-            await preview_douyin(client, message, **kwargs)
+            return await preview_douyin(client, message, **kwargs)
         if tiktok and matched["platform"] == "tiktok":
-            await preview_douyin(client, message, **kwargs)
+            return await preview_douyin(client, message, **kwargs)
         if instagram and matched["platform"] == "instagram":
-            await preview_instagram(client, message, **kwargs)
+            return await preview_instagram(client, message, **kwargs)
         if twitter and matched["platform"] in ["x", "twitter", "fxtwitter", "fixupx"]:
-            await preview_twitter(client, message, **kwargs)
+            return await preview_twitter(client, message, **kwargs)
         if weibo and matched["platform"] == "weibo":
-            await preview_weibo(client, message, **kwargs)
+            return await preview_weibo(client, message, **kwargs)
         if xhs and matched["platform"] == "xiaohongshu":
-            await preview_xhs(client, message, **kwargs)
+            return await preview_xhs(client, message, **kwargs)
         if xhs and matched["platform"] == "wechat":
-            await preview_wechat(client, message, **kwargs)
+            return await preview_wechat(client, message, **kwargs)
         if reddit and matched["platform"] == "reddit":
-            await preview_reddit(client, message, **kwargs)
+            return await preview_reddit(client, message, **kwargs)
         if matched["platform"].startswith("bilibili-"):  # this is not bilibili video, for videos, use yt-dlp
-            await preview_bilibili(client, message, **kwargs)
+            return await preview_bilibili(client, message, **kwargs)
+        sent_messages = []
         try:
             if ytdlp and any(matched["platform"] == x for x in ["bilibili", "youtube", "ytdlp"]):
-                await preview_ytdlp(client, message, **kwargs)
+                sent_messages = await preview_ytdlp(client, message, **kwargs)
         except ProxyError:
             logger.error(f"🚫{matched['platform']}代理错误")
             if PROXY.YTDLP_FALLBACK:
                 logger.warning(f"🔄使用备用代理{PROXY.YTDLP_FALLBACK}")
-                await preview_ytdlp(client, message, proxy=PROXY.YTDLP_FALLBACK, **kwargs)
+                sent_messages = await preview_ytdlp(client, message, proxy=PROXY.YTDLP_FALLBACK, **kwargs)
         if warn_msg:
             await warn_msg.delete()
+        if not sent_messages and startswith_prefix(this_texts, prefix=cmd_prefix):
+            if kwargs.get("show_progress"):
+                kwargs["progress"] = await client.send_message(info["cid"], text="⚠️暂时不支持解析链接, 尝试直接下载该网页")
+            await download_url_in_message(client, this_msg, extra_prefix=cmd_prefix, **kwargs)
 
     except Exception as e:
         logger.exception(e)
src/networking.py
@@ -384,7 +384,7 @@ async def match_social_media_link(text: str, *, flatten_first: bool = True) -> d
     # if all above pre-defined patterns failed, try to match ytdlp link
     if urls := match_urls(text):
         for url in urls:
-            if any(x in url.lower() for x in ["bilibili", "douyin", "instagram", "tiktok", "twitter", "weibo", "xiaohongshu", "youtube"]):
+            if any(x in url.lower() for x in ["bilibili", "douyin", "instagram", "tiktok", "twitter", "weibo", "xiaohongshu", "reddit", "youtube"]):
                 # handled above
                 continue
             if is_supported_by_ytdlp(url):
src/utils.py
@@ -201,6 +201,17 @@ def readable_size(num_bytes: str | float = 0, path: str | Path | None = None) ->
     return f"{num_bytes:.1f} MB"
 
 
+def find_url(text: str) -> str:
+    if not isinstance(text, str):
+        return ""
+    regex = r"(?i)\b((?:https?://|www\d{0,3}[.]|[a-z0-9.\-]+[.][a-z]{2,4}/)(?:[^\s()<>]+|\(([^\s()<>]+|(\([^\s()<>]+\)))*\))+(?:\(([^\s()<>]+|(\([^\s()<>]+\)))*\)|[^\s`!()\[\]{};:'\".,<>?«»“”‘’]))"
+    if matched := re.findall(regex, text):
+        url = matched[0][0]
+        logger.debug(f"URL found from message text: {url}")
+        return url
+    return ""
+
+
 def https_url(url: str) -> str:
     return "https://" + str(url).removeprefix("https://").removeprefix("http://").lstrip("/").rstrip("/")
 
@@ -389,6 +400,9 @@ async def publish_telegraph(title: str, texts: str | None = None, html: str | No
         return ""
     if texts:
         html = markdown.markdown(texts)
+        # Revise Telegraph Tags
+        html = html.replace("<h1>", "<h3>").replace("</h1>", "</h3>")
+        html = html.replace("<h2>", "<h3>").replace("</h2>", "</h3>")
     telegraph = Telegraph(access_token=TOKEN.TELEGRAPH)
     account_info = {}
     if not (author and url):
@@ -453,6 +467,8 @@ if __name__ == "__main__":
     print(is_supported_by_ytdlp("https://www.bilibili.com/video/BV15n61YtEmk"))
     print(is_supported_by_ytdlp("https://t.me/c/1744444199/2475260"))
     print(is_supported_by_ytdlp("https://test.com/"))
+    print(find_url("https://test.com/"))
+    print(find_url("test.com/"))
 
     # assert av2bv("av113503016851915") == "BV1Y4UHYyE2z"
     # assert av2bv("113503016851915") == "BV1Y4UHYyE2z"