Commit 2f91297

benny-dou <60535774+benny-dou@users.noreply.github.com>
2025-09-01 06:30:50
fix(database): add options to control which types of messages to copy
1 parent c6d9b72
Changed files (2)
src
messages
ytdlp
src/messages/database.py
@@ -13,7 +13,7 @@ from database.database import del_db, get_db, set_db
 from messages.parser import parse_msg
 from messages.progress import modify_progress
 from messages.utils import sender_markdown_to_html
-from utils import to_int
+from utils import to_int, true
 
 
 async def save_messages(messages: list[Message | None], key: str, metadata: dict | None = None) -> bool:
@@ -83,7 +83,20 @@ async def save_messages(messages: list[Message | None], key: str, metadata: dict
     return False
 
 
-async def copy_messages_from_db(client: Client, message: Message, key: str, kv: dict | None = None, **kwargs) -> list[Message]:
+async def copy_messages_from_db(
+    client: Client,
+    message: Message,
+    key: str,
+    kv: dict | None = None,
+    *,
+    copy_video_msg: bool = True,
+    copy_photo_msg: bool = True,
+    copy_audio_msg: bool = True,
+    copy_document_msg: bool = True,
+    copy_text_msg: bool = True,
+    copy_media_group_msg: bool = True,
+    **kwargs,
+) -> list[Message]:
     """Copy messages from database.
 
     data format:
@@ -141,15 +154,20 @@ async def copy_messages_from_db(client: Client, message: Message, key: str, kv:
             text = item.get("text")  # str or None
             if text and kwargs.get("send_from_user"):
                 text = f"{sender_markdown_to_html(kwargs['send_from_user'])}{text}"
-            if item["type"] == "text":
+            if true(copy_text_msg) and item["type"] == "text":
                 if text:
                     results.append(await client.send_message(chat_id=target_chat, text=text, reply_parameters=reply_parameters))
                 else:
                     db_msg: Message = await client.get_messages(chat_id=cid, message_ids=int(item["mid"]), replies=0)  # type: ignore
                     results.append(await client.send_message(chat_id=target_chat, text=db_msg.text, reply_parameters=reply_parameters))
-            elif item["type"] in ["photo", "audio", "video"]:
+            elif (
+                (true(copy_video_msg) and item["type"] == "video")
+                or (true(copy_photo_msg) and item["type"] == "photo")
+                or (true(copy_audio_msg) and item["type"] == "audio")
+                or (true(copy_document_msg) and item["type"] == "document")
+            ):
                 results.append(await client.copy_message(chat_id=target_chat, caption=text, from_chat_id=cid, message_id=int(item["mid"]), reply_parameters=reply_parameters))  # type: ignore
-            elif item["type"] == "media_group":
+            elif true(copy_media_group_msg) and item["type"] == "media_group":
                 results.extend(await client.copy_media_group(chat_id=target_chat, captions=text, from_chat_id=cid, message_id=int(item["mid"]), reply_parameters=reply_parameters))  # type: ignore
             else:
                 logger.warning(f"Unknown message type: {item}")
src/ytdlp/main.py
@@ -82,6 +82,7 @@ async def preview_ytdlp(
     db_key = url
     if true(use_db) and (kv := await get_db(db_key)):
         logger.debug(f"YT-DLP preview {DB.ENGINE} cache hit for key={db_key}")
+        kwargs |= {"copy_video_msg": kwargs.get("copy_video_msg", ytdlp_send_video), "copy_audio_msg": kwargs.get("copy_audio_msg", ytdlp_send_audio)}
         if db_msgs := await copy_messages_from_db(client, message, key=db_key, kv=kv, **kwargs):
             return db_msgs
         await modify_progress(text=f"❌从{DB.ENGINE}缓存中转发失败, 尝试重新解析...", **kwargs)