Commit 47a00cc

benny-dou <60535774+benny-dou@users.noreply.github.com>
2025-06-10 09:09:31
feat(history): add fine-grained control for saving messages
1 parent eb8812e
src/history/d1.py
@@ -14,8 +14,8 @@ from pyrogram.types import Chat, Message
 
 from config import DOWNLOAD_DIR, HISTORY, TZ, cache, cutter
 from database.d1 import create_d1_database, create_d1_table, query_d1
+from history.utils import check_save_history
 from messages.parser import parse_msg
-from permission import check_save_history
 from utils import i_am_bot, nowdt
 
 # ruff: noqa: S608
src/history/query.py
@@ -12,13 +12,13 @@ from pyrogram.types import Message
 from config import PREFIX, cutter
 from database.turso import turso_exec
 from history.turso import get_table_name
-from history.utils import TURSO_KWARGS, get_uid_by_username, is_admin, list_chat_ids, mtype_emoji
+from history.utils import TURSO_KWARGS, check_save_history, get_uid_by_username, is_admin, list_chat_ids, mtype_emoji
 from llm.utils import convert_html
 from messages.parser import parse_msg
 from messages.progress import modify_progress
 from messages.sender import send2tg
 from messages.utils import blockquote, equal_prefix, smart_split, startswith_prefix
-from permission import check_save_history, slim_cid
+from permission import slim_cid
 from publish import publish_telegraph
 
 HELP = f"""🗣**查询当前对话聊天记录**
src/history/sync.py
@@ -12,7 +12,7 @@ from history.turso import backup_chat_history_to_turso, sync_history_to_turso
 async def sync_chat_history(client: Client, message: Message) -> None:
     if not HISTORY.ENABLE:
         return
-    if HISTORY.ENGINE.upper() == "D1":
+    if HISTORY.ENGINE.upper() == "D1":  # Deprecated
         await sync_history_to_d1(client, message)
     if HISTORY.ENGINE.upper() == "TURSO":
         await sync_history_to_turso(client, message)
src/history/turso.py
@@ -12,9 +12,8 @@ from pyrogram.types import Message
 
 from config import DOWNLOAD_DIR, HISTORY, TZ, cache, cutter
 from database.turso import insert_statement, turso_create_table, turso_exec, turso_list_tables
-from history.utils import TURSO_KWARGS, chat_info
+from history.utils import TURSO_KWARGS, chat_info, check_save_history, fine_grained_check
 from messages.parser import parse_msg
-from permission import check_save_history
 from utils import i_am_bot, nowdt
 
 # ruff: noqa: S608
@@ -27,7 +26,7 @@ async def sync_history_to_turso(client: Client, message: Message) -> None:
     if not HISTORY.TURSO_ENABLE:
         return
     info = parse_msg(message, silent=True)
-    if not check_save_history(info["ctype"], info["cid"]):
+    if not check_save_history(info["ctype"], info["cid"]) or not fine_grained_check(info):
         return
     table_name = await get_table_name(client, info["cid"])
     records = {
@@ -73,6 +72,8 @@ async def backup_chat_history_to_turso(client: Client, chat_id: str | int, hours
         info = parse_msg(message, silent=True)
         if info["mid"] in saved_mids:
             continue
+        if not fine_grained_check(info):
+            continue
         if info["time"] < begin_time:
             break
         records = {
@@ -130,11 +131,7 @@ async def upload_exported_history_to_turso(client: Client, path: str | Path | No
         data = json.load(f)
         logger.info(f"Found {len(data['messages'])} messages in json file")
     mtypes = {
-        "text": "text",
-        "photo": "photo",
-        "animation": "animation",
         "audio_file": "audio",
-        "sticker": "sticker",
         "voice_message": "voice",
         "video_message": "video",
         "video_file": "video",
@@ -159,7 +156,12 @@ async def upload_exported_history_to_turso(client: Client, path: str | Path | No
                 info["media_type"] = "photo"
             if "video/" in info.get("mime_type", ""):
                 info["media_type"] = "video_file"
-
+        mtype = info.get("media_type", "text")
+        content = parse_text(info.get("text", []))
+        urls = parse_urls(info.get("text_entities", []))
+        # fine-grained check requires the keys: ["cid", "mtype", "text", "entity_urls"]
+        if not fine_grained_check({"cid": data["id"], "mtype": mtype, "text": content, "entity_urls": urls}):
+            continue
         dt = datetime.fromtimestamp(int(info["date_unixtime"]), tz=ZoneInfo(TZ))
         uid = int(info["from_id"].removeprefix("user").removeprefix("channel"))
         user = info["from"] or info["from_id"].removeprefix("user").removeprefix("channel")
@@ -167,15 +169,14 @@ async def upload_exported_history_to_turso(client: Client, path: str | Path | No
             user = ""
             uid = 1
 
-        content = parse_text(info.get("text", []))
         records = {
             "mid": info["id"],
-            "mtype": mtypes[info.get("media_type", "text")],
+            "mtype": mtypes.get(mtype, mtype),
             "time": dt.strftime("%Y-%m-%d %H:%M:%S"),
             "fullname": user,
-            "content": parse_text(info.get("text", [])),
+            "content": content,
             "filename": info.get("file_name", ""),
-            "urls": parse_urls(info.get("text_entities", [])),
+            "urls": urls,
             "reply": info.get("reply_to_message_id"),
             "mime": info.get("mime_type", ""),
             "user": user.replace(" ", ""),
src/history/utils.py
@@ -1,6 +1,7 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
 import contextlib
+import os
 import string
 
 from loguru import logger
@@ -12,7 +13,7 @@ from config import DB, HISTORY, TID, cache
 from database.turso import turso_list_tables
 from messages.sender import send2tg
 from permission import slim_cid
-from utils import to_int
+from utils import find_url, to_int, true
 
 TURSO_KWARGS: dict = {
     "db_name": HISTORY.TURSO_DATABASE,
@@ -22,9 +23,66 @@ TURSO_KWARGS: dict = {
 }
 
 
+@cache.memoize(ttl=0)
+def check_save_history(ctype: str, cid: int | str) -> bool:
+    # ruff: noqa: SIM103
+    cid = slim_cid(cid)
+    if true(os.getenv(f"HISTORY_IGNORE_{cid}")):
+        return False
+    if true(os.getenv(f"HISTORY_INCLUDE_{cid}")):
+        return True
+    if cid in HISTORY.IGNORE_CHATS.split(","):
+        return False
+    if cid in HISTORY.INCLUDE_CHATS.split(","):
+        return True
+    if ctype == "PRIVATE":
+        if str(HISTORY.INCLUDE_PRIVATES).lower() == "all" or cid in HISTORY.INCLUDE_PRIVATES.split(","):
+            return True
+        return False
+    if ctype == "BOT":
+        if str(HISTORY.INCLUDE_BOTS).lower() == "all" or cid in HISTORY.INCLUDE_BOTS.split(","):
+            return True
+        return False
+    if ctype in ["GROUP", "SUPERGROUP"]:
+        if str(HISTORY.INCLUDE_GROUPS).lower() == "all" or cid in HISTORY.INCLUDE_GROUPS.split(","):
+            return True
+        return False
+    if ctype == "CHANNEL":
+        if str(HISTORY.INCLUDE_CHANNELS).lower() == "all" or cid in HISTORY.INCLUDE_CHANNELS.split(","):
+            return True
+        return False
+    return False
+
+
+def fine_grained_check(info: dict) -> bool:
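+    """Return whether a message should be saved; some chats do not need every kind of message kept.
+
+    These fine-grained checks can only be configured through environment variables.
+    Currently supported:
+        HISTORY_{cid}_MUST_MTYPE:    save only messages whose type appears in this value (case-insensitive)
+        HISTORY_{cid}_MUST_HAVE_TEXT:  save only messages that contain text
+        HISTORY_{cid}_SKIP_URL: skip messages that contain links
+        HISTORY_{cid}_SKIP_KEYWORDS: skip messages that contain any of the keywords (a comma-separated string)
+    Example: for the chat with `chat_id = 1234`, enable `HISTORY_1234_MUST_HAVE_TEXT` to skip messages without text.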
+    """
+    # ruff: noqa: SIM103
+    cid = slim_cid(info["cid"])
+    if (mtype := os.getenv(f"HISTORY_{cid}_MUST_MTYPE")) and info["mtype"].lower() not in mtype.lower():
+        return False
+    if true(os.getenv(f"HISTORY_{cid}_MUST_HAVE_TEXT")) and not info["text"]:
+        return False
+    if true(os.getenv(f"HISTORY_{cid}_SKIP_URL")) and (find_url(info["text"]) or info.get("entity_urls")):
+        return False
+    if os.getenv(f"HISTORY_{cid}_SKIP_KEYWORDS"):
+        keywords = [x.strip() for x in os.environ[f"HISTORY_{cid}_SKIP_KEYWORDS"].split(",") if x.strip()]
+        if any(x in info["text"] for x in keywords):
+            return False
+    return True
+
+
 async def chat_info(client: Client, chat_id: int) -> Chat:
-    if cache.get(f"chat-info-{chat_id}"):
-        return cache.get(f"chat-info-{chat_id}")
+    if cache.get(f"chat-info-{slim_cid(chat_id)}"):
+        return cache.get(f"chat-info-{slim_cid(chat_id)}")
     if chat_id == 1:
         return Chat(id=1)
     try:
@@ -34,7 +92,7 @@ async def chat_info(client: Client, chat_id: int) -> Chat:
         return await chat_info(client, int(f"-100{chat_id}"))
     except Exception:
         chat = Chat(id=1)
-    cache.set(f"chat-info-{chat_id}", chat, ttl=3600)  # cache for 1 hour
+    cache.set(f"chat-info-{slim_cid(chat_id)}", chat, ttl=3600)  # cache for 1 hour
     return chat
 
 
src/permission.py
@@ -7,7 +7,7 @@ from loguru import logger
 from pyrogram.client import Client
 from pyrogram.types import Message
 
-from config import ENABLE, HISTORY, TID, cache
+from config import ENABLE, TID, cache
 from utils import i_am_bot, to_int, true
 
 
@@ -189,36 +189,5 @@ def check_service(cid: int | str, ctype: str) -> dict:
     return permission
 
 
-@cache.memoize(ttl=0)
-def check_save_history(ctype: str, cid: int | str) -> bool:
-    # ruff: noqa: SIM103
-    cid = slim_cid(cid)
-    if true(os.getenv(f"HISTORY_IGNORE_{cid}")):
-        return False
-    if true(os.getenv(f"HISTORY_INCLUDE_{cid}")):
-        return True
-    if cid in HISTORY.IGNORE_CHATS.split(","):
-        return False
-    if cid in HISTORY.INCLUDE_CHATS.split(","):
-        return True
-    if ctype == "PRIVATE":
-        if str(HISTORY.INCLUDE_PRIVATES).lower() == "all" or cid in HISTORY.INCLUDE_PRIVATES.split(","):
-            return True
-        return False
-    if ctype == "BOT":
-        if str(HISTORY.INCLUDE_BOTS).lower() == "all" or cid in HISTORY.INCLUDE_BOTS.split(","):
-            return True
-        return False
-    if ctype in ["GROUP", "SUPERGROUP"]:
-        if str(HISTORY.INCLUDE_GROUPS).lower() == "all" or cid in HISTORY.INCLUDE_GROUPS.split(","):
-            return True
-        return False
-    if ctype == "CHANNEL":
-        if str(HISTORY.INCLUDE_CHANNELS).lower() == "all" or cid in HISTORY.INCLUDE_CHANNELS.split(","):
-            return True
-        return False
-    return False
-
-
 def slim_cid(cid: int | str) -> str:
     return str(cid).strip().removeprefix("-100")