Commit 88d4d0d

benny-dou <60535774+benny-dou@users.noreply.github.com>
2025-06-09 18:05:55
feat(history): add `/history` query feature
1 parent ad2629f
src/database/turso.py
@@ -130,6 +130,11 @@ async def turso_exec(
     if not db_url:
         return {}
     headers = {"authorization": f"Bearer {group_token}", "content-type": "application/json"}
+
+    for stmt in statements:
+        if (sql := stmt.get("sql")) and not sql.endswith(";"):
+            stmt["sql"] += ";"
+
     if statements[-1] != {"type": "close"}:
         statements.append({"type": "close"})
     if not silent:
src/history/query.py
@@ -0,0 +1,198 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+import re
+from io import BytesIO
+
+from glom import glom
+from loguru import logger
+from pyrogram.client import Client
+from pyrogram.parser.markdown import BLOCKQUOTE_EXPANDABLE_DELIM, BLOCKQUOTE_EXPANDABLE_END_DELIM
+from pyrogram.types import Message
+
+from config import PREFIX
+from database.turso import turso_exec
+from history.turso import get_table_name
+from history.utils import TURSO_KWARGS, get_uid_by_username, is_admin, list_chat_ids, mtype_emoji
+from llm.utils import convert_html
+from messages.parser import parse_msg
+from messages.progress import modify_progress
+from messages.sender import send2tg
+from messages.utils import blockquote, equal_prefix, smart_split, startswith_prefix
+from permission import check_save_history, slim_cid
+from publish import publish_telegraph
+
+# User-facing help text, sent when the bare history command is issued without a query.
+HELP = f"""🗣**查询当前对话聊天记录**
+`/hist` 使用说明:
+1.`/hist + 关键词`
+2.`/hist + 日期 + 关键词`
+3.`/hist + @用户名 + 关键词`
+4.`/hist + 日期 + @用户名 + 关键词` (日期需放在最前面)
+示例:
+{BLOCKQUOTE_EXPANDABLE_DELIM}`/hist 你好`: 查询包含“你好”关键词的记录
+`/hist 2025-01-01 你好`: 查询2025-01-01日包含“你好”的记录
+`/hist @张三 你好`: 查询用户【张三】包含“你好”的记录
+`/hist 2025 @张三 你好`: 查询2025年用户【张三】包含“你好”的记录
+
+注意:
+- 用户名和关键词需要区分大小写
+- 用户名可以为昵称 (Name)、用户名 (@username)
+- 如果用户名中有空格, 请去除空格。例如: 想指定用户为John Doe请使用 `@JohnDoe`
+{BLOCKQUOTE_EXPANDABLE_END_DELIM}
+`/history` 使用说明:
+查询所有对话的聊天记录
+但出于隐私考虑, 本命令会限制使用权限
+`/history + #ChatID` + [日期]+[用户名]+[关键词]
+`/history --list`: 列出所有ChatID
+"""
+
+
+async def query_chat_history(client: Client, message: Message, **kwargs):
+    """Handle the `/hist` and `/history` chat-history query commands.
+
+    `/hist` searches the chat the command was sent from. `/history` is
+    restricted to admins and searches the chat given as `#ChatID`, or lists
+    the known chats when called as `/history --list`.
+
+    Args:
+        client: Pyrogram client used for replies and document uploads.
+        message: Incoming message that may carry a history command.
+        **kwargs: Forwarded to `send2tg` and `modify_progress`.
+    """
+    info = parse_msg(message, silent=True)
+    admin_call = is_admin(info["uid"], info["handle"])
+    # Saving history is disabled for this chat: only admins may still query.
+    if not check_save_history(info["ctype"], info["cid"]) and not admin_call:
+        return
+    if not startswith_prefix(info["text"], prefix=PREFIX.HISTORY):
+        return
+    if equal_prefix(info["text"], prefix=PREFIX.HISTORY):
+        await send2tg(client, message, texts=HELP, **kwargs)
+        return
+    if startswith_prefix(info["text"], prefix="/history") and not admin_call:
+        await send2tg(client, message, texts="⚠️您无权使用此命令", **kwargs)
+        return
+
+    if info["text"].strip() == "/history --list":
+        await list_chat_ids(client, message)
+        return
+    qtype = "history" if startswith_prefix(info["text"].strip(), prefix="/history") else "hist"
+
+    chat_id, match_time, user, keyword, error = parse_queries(info["text"], qtype)
+    if error:
+        await send2tg(client, message, texts=error, **kwargs)
+        return
+    # `/hist` never carries a ChatID, so it always targets the current chat.
+    # (qtype is "hist" or "history"; it never includes the leading slash.)
+    if qtype == "hist":
+        chat_id = slim_cid(info["cid"])
+    table_name = await get_table_name(client, chat_id)
+    chat_title = "".join(table_name.split("-")[1:])
+
+    caption = "📖**查询聊天记录**:"
+    caption += f"\n🆔会话: {chat_title}"
+    caption += f"\n🕒日期: {match_time}"
+    caption += f"\n👤用户: {user}"
+    caption += f"\n🔤关键词: {keyword}"
+    status_msg = (await send2tg(client, message, texts=caption, **kwargs))[0]
+    kwargs["progress"] = status_msg
+
+    results = await query_turso(client, table_name, match_time, user, keyword)
+    texts = results.get("texts", "")
+    count = results.get("count", 0)
+
+    if not texts:
+        await modify_progress(text=caption + "\n⚠️未匹配任何记录", force_update=True, **kwargs)
+        return
+
+    # Short results that fit in a single split chunk are shown inline in the status message.
+    if len(texts) < 20480 and len(await smart_split(texts)) == 1:
+        await modify_progress(message=status_msg, text=blockquote(texts), force_update=True, **kwargs)
+        return
+
+    # Shared by the Telegraph page title and the attachment file name.
+    title = f"【{chat_title}】{user}{match_time} {keyword}"
+    caption += f"\n#️⃣消息数: {count}"
+    # Only results shorter than 1,000,000 characters get a Telegraph instant-view link.
+    if len(texts) < 1000000 and (telegraph_url := await publish_telegraph(title=title, html=convert_html(texts), author=user or chat_title)):
+        caption += f"\n⚡️[即时预览]({telegraph_url})"
+    # Send the full result as a UTF-8 encoded .txt document.
+    with BytesIO(texts.encode("utf-8")) as f:
+        await client.send_document(info["cid"], f, file_name=f"{title}.txt", caption=caption)
+
+    await modify_progress(message=status_msg, del_status=True, **kwargs)
+
+
+def parse_queries(texts: str, qtype: str) -> tuple[str, str, str, str, str]:
+    """Parse a `/hist` or `/history` command into its query components.
+
+    The expected order after the command is: `#chat_id`, date, `@user`,
+    keyword. Every part except the keyword is optional at the parsing stage;
+    which parts are required or forbidden depends on `qtype`.
+
+    Args:
+        texts: Full message text, starting with `/hist` or `/history`.
+        qtype: Either "hist" (current chat) or "history" (explicit chat id).
+
+    Returns:
+        chat_id, match_time, user, keyword, error
+        (`error` is an empty string when the query is valid).
+    """
+    # ruff: noqa: SIM114
+    chat_id = ""
+    match_time = ""
+    user = ""
+    error = ""
+    # Strip the command itself unconditionally; previously it was only removed
+    # together with a `#chat_id`, so plain `/hist keyword` kept its prefix.
+    texts = re.sub(r"^/hist(?:ory)?", "", texts.strip(), count=1).lstrip()
+
+    # #chat_id -- an optional `-100` prefix is discarded; a bare `#` is dropped too
+    if matched := re.match(r"#(?:-100)?(\d*)", texts):
+        chat_id = matched.group(1)
+        texts = texts[matched.end() :].lstrip()
+
+    # 2025-01-01, 2025-01 or 2025 (the longest form wins)
+    if matched := re.match(r"\d{4}(?:-\d{2}(?:-\d{2})?)?", texts):
+        match_time = matched.group()
+        texts = texts[matched.end() :].lstrip()
+
+    # @张三 你好
+    # @张三
+    if matched := re.match(r"@(\w*)", texts):
+        user = matched.group(1)
+        texts = texts[matched.end() :]
+    keyword = texts.lstrip()
+
+    # error handling (later checks intentionally override earlier ones)
+    if not keyword:
+        error = f"查询格式有误, 必须包含 `关键词`\n请发送 `{PREFIX.HISTORY}` 命令查看帮助"
+
+    if qtype == "hist":
+        if not any((match_time, user, keyword)):
+            error = f"查询格式有误, 请发送 `{PREFIX.HISTORY}` 命令查看帮助"
+        if chat_id:
+            error = "`/hist` 命令不支持指定ChatID, 仅支持查询当前对话聊天记录"
+
+    if qtype == "history":
+        if not any((chat_id, match_time, user, keyword)):
+            error = f"查询格式有误, 请发送 `{PREFIX.HISTORY}` 命令查看帮助"
+        if not chat_id:
+            error = "`/history` 命令需要指定 ChatID\n`/history --list`: 列出所有ChatID"
+
+    return chat_id, match_time, user, keyword, error
+
+
+async def query_turso(client: Client, table_name: str, match_time: str, user: str, keyword: str) -> dict:
+    """Query chat history from Turso with a full-text match on `keyword`.
+
+    Args:
+        client: Pyrogram client, used to resolve `user` to a Telegram uid.
+        table_name: History table name; the part before the first "-" is the chat id.
+        match_time: Optional date filter: "YYYY", "YYYY-MM" or "YYYY-MM-DD".
+        user: Optional user filter (matched by uid when resolvable, else by the `user` column).
+        keyword: Full-text search expression matched against the `segmented` column.
+
+    Returns:
+        {"texts": formatted matches joined into one string, "count": number of matches}
+    """
+    # ruff: noqa: S608
+    cid = int(table_name.split("-")[0])
+    # Identifiers cannot be bound as parameters, so the table name is inlined
+    # with embedded quotes doubled; every user-supplied value is bound via `args`.
+    safe_table = table_name.replace("'", "''")
+    sql = f"SELECT T.mid, T.mtype, T.time, T.fullname, T.content FROM '{safe_table}' AS T JOIN fts_{cid} AS FTS ON T.mid = FTS.rowid WHERE FTS.segmented MATCH ?"
+    args: list[dict] = [{"type": "text", "value": keyword}]
+    if match_time:
+        # Suffixes that expand a partial date into an inclusive [begin, end] range.
+        # "-31" is a valid upper bound for every month because times are compared as text.
+        bounds = {
+            4: ("-01-01 00:00:00", "-12-31 23:59:59"),  # 2025
+            7: ("-01 00:00:00", "-31 23:59:59"),  # 2025-01
+            10: (" 00:00:00", " 23:59:59"),  # 2025-01-01
+        }
+        if suffixes := bounds.get(len(match_time)):
+            sql += " AND T.time >= ? AND T.time <= ?"
+            args.append({"type": "text", "value": match_time + suffixes[0]})
+            args.append({"type": "text", "value": match_time + suffixes[1]})
+        else:
+            logger.warning(f"Ignoring unsupported date filter: {match_time}")
+    if user:
+        # Usernames can change, so matching by UID is preferred when it can be resolved.
+        if uid := await get_uid_by_username(client, user):
+            sql += " AND T.uid = ?"
+            args.append({"type": "integer", "value": str(uid)})
+        else:
+            sql += " AND T.user = ?"
+            args.append({"type": "text", "value": user})
+    logger.info(f"{sql} | args={args}")
+    resp = await turso_exec([{"type": "execute", "stmt": {"sql": sql, "args": args}}], silent=True, retry=2, **TURSO_KWARGS)
+
+    # parse turso response
+    cols = glom(resp, "results.0.response.result.cols", default=[])
+    rows = glom(resp, "results.0.response.result.rows", default=[])
+    entries = []
+    for row in rows:
+        row_info = {col["name"]: x["value"] for x, col in zip(row, cols, strict=True)}
+        url = f"https://t.me/c/{cid}/{row_info['mid']}"
+        entries.append(f"👤[{row_info['fullname']}]({url}) {row_info['time']}{mtype_emoji(row_info['mtype'])}:\n{row_info['content']}\n\n")
+    return {"texts": "".join(entries).strip(), "count": len(entries)}
src/history/turso.py
@@ -8,11 +8,11 @@ from zoneinfo import ZoneInfo
 from glom import Coalesce, flatten, glom
 from loguru import logger
 from pyrogram.client import Client
-from pyrogram.errors import PeerIdInvalid
-from pyrogram.types import Chat, Message
+from pyrogram.types import Message
 
-from config import DB, DOWNLOAD_DIR, HISTORY, TZ, cache, cutter
+from config import DOWNLOAD_DIR, HISTORY, TZ, cache, cutter
 from database.turso import insert_statement, turso_create_table, turso_exec, turso_list_tables
+from history.utils import TURSO_KWARGS, chat_info
 from messages.parser import parse_msg
 from permission import check_save_history
 from utils import i_am_bot, nowdt
@@ -21,12 +21,6 @@ from utils import i_am_bot, nowdt
 
 DB_COLUMNS = "mid INTEGER PRIMARY KEY, mtype TEXT, time TEXT NOT NULL, fullname TEXT, content TEXT, filename TEXT, urls TEXT, reply INTEGER, mime TEXT, user TEXT, uid INTEGER, segmented TEXT"
 INDEX_NAMES = ["time", "user", "uid"]
-TURSO_KWARGS: dict = {
-    "db_name": HISTORY.TURSO_DATABASE,
-    "username": HISTORY.TURSO_USERNAME or DB.TURSO_USERNAME,
-    "api_token": HISTORY.TURSO_API_TOKEN or DB.TURSO_API_TOKEN,
-    "group_token": HISTORY.TURSO_GROUP_TOKEN or DB.TURSO_GROUP_TOKEN,
-}
 
 
 async def sync_history_to_turso(client: Client, message: Message) -> None:
@@ -203,22 +197,6 @@ async def upload_exported_history_to_turso(client: Client, path: str | Path | No
             logger.success(f"Synced {num_success} messages to Turso, {min(sync_ids)} -> {max(sync_ids)}")
 
 
-async def chat_info(client: Client, chat_id: int) -> Chat:
-    if cache.get(f"chat-info-{chat_id}"):
-        return cache.get(f"chat-info-{chat_id}")
-    if chat_id == 1:
-        return Chat(id=1)
-    try:
-        logger.debug(f"Getting chat info for {chat_id}")
-        chat = await client.get_chat(int(chat_id))
-    except PeerIdInvalid:
-        return await chat_info(client, int(f"-100{chat_id}"))
-    except Exception:
-        chat = Chat(id=1)
-    cache.set(f"chat-info-{chat_id}", chat, ttl=3600)  # cache for 1 hour
-    return chat
-
-
 async def get_table_name(client: Client, chat_id: str | int) -> str:
     """Get table name by chat id."""
     if cache.get(f"tablename-{chat_id}"):
src/history/utils.py
@@ -0,0 +1,94 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+import contextlib
+import string
+
+from loguru import logger
+from pyrogram.client import Client
+from pyrogram.errors import PeerIdInvalid
+from pyrogram.types import Chat, Message, User
+
+from config import DB, HISTORY, TID, cache
+from database.turso import turso_list_tables
+from messages.sender import send2tg
+from permission import slim_cid
+from utils import to_int
+
+# Keyword arguments unpacked into the Turso helper calls for the chat-history database.
+# Each credential falls back to the shared DB.* value when the HISTORY.* one is empty.
+TURSO_KWARGS: dict = {
+    "db_name": HISTORY.TURSO_DATABASE,
+    "username": HISTORY.TURSO_USERNAME or DB.TURSO_USERNAME,
+    "api_token": HISTORY.TURSO_API_TOKEN or DB.TURSO_API_TOKEN,
+    "group_token": HISTORY.TURSO_GROUP_TOKEN or DB.TURSO_GROUP_TOKEN,
+}
+
+
+async def chat_info(client: Client, chat_id: int) -> Chat:
+    """Return the `Chat` for `chat_id`, cached for one hour.
+
+    A positive id that Telegram rejects with `PeerIdInvalid` is retried once
+    with the `-100` prefix. Any other failure yields the placeholder
+    `Chat(id=1)` instead of raising.
+
+    Args:
+        client: Pyrogram client used to fetch the chat.
+        chat_id: Chat id, with or without the `-100` prefix.
+    """
+    cache_key = f"chat-info-{chat_id}"
+    if cached := cache.get(cache_key):
+        return cached
+    if chat_id == 1:
+        return Chat(id=1)
+    try:
+        logger.debug(f"Getting chat info for {chat_id}")
+        chat = await client.get_chat(int(chat_id))
+    except PeerIdInvalid:
+        # Only an unprefixed (positive) id can be retried: prefixing an already
+        # negative id would build "-100-100..." and make int() raise ValueError.
+        if not str(chat_id).startswith("-"):
+            return await chat_info(client, int(f"-100{chat_id}"))
+        chat = Chat(id=1)
+    except Exception:
+        chat = Chat(id=1)
+    cache.set(cache_key, chat, ttl=3600)  # cache for 1 hour
+    return chat
+
+
+async def list_chat_ids(client: Client, message: Message):
+    """Reply with one `/history #<chat_id>` line per history table stored in Turso.
+
+    Args:
+        client: Pyrogram client used to send the reply.
+        message: Message to reply to.
+    """
+    table_names = await turso_list_tables(**TURSO_KWARGS, silent=True)
+    lines = []
+    for table_name in table_names:
+        # Tables prefixed with "fts_" are skipped; they are not chat tables.
+        if table_name.startswith("fts_"):
+            continue
+        # Split on the first "-" only: the title part may itself contain dashes.
+        cid, _, ctitle = table_name.partition("-")
+        lines.append(f"`/history #{cid}`: {ctitle}\n")
+    await send2tg(client, message, texts="".join(lines))
+
+
+def is_admin(uid: int, handle: str) -> bool:
+    """Return True if the user is listed in `TID.ADMIN`.
+
+    `TID.ADMIN` is a comma-separated list whose entries are either a user id
+    or an `@username` (compared case-insensitively).
+
+    Args:
+        uid: Telegram user id of the caller.
+        handle: Caller's username without the leading "@"; may be empty or None.
+    """
+    # Users without a username must not crash the check nor match a bare "@" entry.
+    handle = (handle or "").lower()
+    for admin in (entry.strip() for entry in TID.ADMIN.split(",")):
+        if not admin:
+            continue
+        if admin.startswith("@") and handle and admin[1:].lower() == handle:
+            return True
+        if slim_cid(admin) == slim_cid(uid):
+            return True
+    return False
+
+
+async def get_uid_by_username(client: Client, username: str) -> int:
+    """Get Telegram user id by username.
+
+    Support formats of `username`:
+        handle (a-z, A-Z, 0-9, _)
+
+    Args:
+        client: Pyrogram client used for the lookup.
+        username: Handle without the leading "@", or a numeric user id as a string.
+
+    Returns:
+        The user id, or 0 when `username` is not a valid handle or cannot be resolved.
+    """
+    cache_key = f"get_uid_by_username-{username}"
+    if cached := cache.get(cache_key):
+        return cached
+    # Built once per call; "_" is included to match the documented handle format.
+    allowed = set(string.ascii_letters + string.digits + "_")
+    if username and all(ch in allowed for ch in username):
+        logger.debug(f"Getting uid by username: {username}")
+        # Best-effort lookup: any API failure is treated as "not found".
+        with contextlib.suppress(Exception):
+            user = await client.get_users(to_int(username))
+            if isinstance(user, User):
+                cache.set(cache_key, user.id, ttl=0)
+                return user.id
+    cache.set(cache_key, 0, ttl=0)
+    return 0
+
+
+def mtype_emoji(mtype: str) -> str:
+    """Return "|<emoji>" for a known media type, or an empty string otherwise."""
+    emoji = {
+        "audio": "🎧",
+        "document": "📔",
+        "photo": "🏞",
+        "sticker": "🎨",
+        "video": "🎥",
+        "video_note": "🎥",
+        "animation": "✨",
+        "voice": "🎤",
+        "web_page": "🌐",
+    }.get(mtype)
+    return f"|{emoji}" if emoji else ""
src/config.py
@@ -43,6 +43,7 @@ class ENABLE:  # see fine-grained permission in `src/permission.py`
     GPT = os.getenv("ENABLE_GPT", "1").lower() in ["1", "y", "yes", "t", "true", "on"]
     INSTAGRAM = os.getenv("ENABLE_INSTAGRAM", "1").lower() in ["1", "y", "yes", "t", "true", "on"]
     OCR = os.getenv("ENABLE_OCR", "1").lower() in ["1", "y", "yes", "t", "true", "on"]
+    HISTORY = os.getenv("ENABLE_HISTORY", "1").lower() in ["1", "y", "yes", "t", "true", "on"]
     PRICE = os.getenv("ENABLE_PRICE", "1").lower() in ["1", "y", "yes", "t", "true", "on"]
     SEARCH_YOUTUBE = os.getenv("ENABLE_SEARCH_YOUTUBE", "1").lower() in ["1", "y", "yes", "t", "true", "on"]
     SEARCH_GOOGLE = os.getenv("ENABLE_SEARCH_GOOGLE", "1").lower() in ["1", "y", "yes", "t", "true", "on"]
@@ -72,19 +73,20 @@ class PREFIX:
     AUDIO = os.getenv("PREFIX_AUDIO", "/audio").lower()
     CONVERT = os.getenv("PREFIX_CONVERT", "/convert").lower()  # convert image file to photo
     GPT = os.getenv("PREFIX_GPT", "/ai,/gpt,/gemini,/ds,/qwen,/doubao,/grok").lower()
-    SUBTITLE = os.getenv("PREFIX_SUBTITLE", "/subtitle,/sub").lower()
-    WGET = os.getenv("PREFIX_WGET", "/wget,/curl").lower()
+    SUBTITLE = os.getenv("PREFIX_SUBTITLE", "/subtitle, /sub").lower()
+    WGET = os.getenv("PREFIX_WGET", "/wget, /curl").lower()
     OCR = os.getenv("PREFIX_OCR", "/ocr").lower()
     PRICE = os.getenv("PREFIX_PRICE", "/price").lower()  # unify crypto, stock
     CRYPTO = os.getenv("PREFIX_CRYPTO", "/crypto").lower()  # crypto only
     STOCK = os.getenv("PREFIX_STOCK", "/stock").lower()  # stock only
     COMBINATION = os.getenv("PREFIX_COMBINATION", "/combine").lower()
     VOICE = os.getenv("PREFIX_VOICE", "/voice").lower()
-    SEARCH_YOUTUBE = os.getenv("PREFIX_SEARCH_YOUTUBE", "/youtube,/ytb").lower()
+    SEARCH_YOUTUBE = os.getenv("PREFIX_SEARCH_YOUTUBE", "/youtube, /ytb").lower()
     SEARCH_GOOGLE = os.getenv("PREFIX_SEARCH_GOOGLE", "/google").lower()
     GENIMG = os.getenv("PREFIX_GENIMG", "/gen").lower()
     DANMU = os.getenv("PREFIX_DANMU", "/danmu").lower()
     FAYAN = os.getenv("PREFIX_FAYAN", "/fa").lower()
+    HISTORY = "/history, /hist"
 
 
 class API:
@@ -175,7 +177,7 @@ class COOKIE:  # See: https://github.com/easychen/CookieCloud
 
 
 class TID:  # see more TID usecase in `src/permission.py`
-    ADMIN = os.getenv("TID_ADMIN", "")
+    ADMIN = os.getenv("TID_ADMIN", "")  # comma separated userid or @username
     # back up ytdlp audio if the user does not request it
     CHANNEL_YTDLP_BACKUP = os.getenv("TID_CHANNEL_YTDLP_BACKUP", "me")
     DAILY_SUMMARY = os.getenv("TID_DAILY_SUMMARY", "{}")  # {"source-chat-id": "target-chat-id"}, e.g. '{"-1001234567890": "-1009876543210"}'
src/handler.py
@@ -11,6 +11,7 @@ from bridge.ocr import send_to_ocr_bridge
 from config import ENABLE, PREFIX, PROXY
 from danmu.entrypoint import query_danmu
 from database.database import del_db
+from history.query import query_chat_history
 from history.sync import sync_chat_history
 from llm.gpt import gpt_response
 from llm.summary import ai_summary
@@ -51,6 +52,7 @@ async def handle_utilities(
     save_history: bool = True,
     google: bool = True,
     ocr: bool = True,
+    history: bool = True,
     price: bool = True,
     subtitle: bool = True,
     summary: bool = True,
@@ -77,6 +79,7 @@ async def handle_utilities(
         save_history (bool, optional): Enable save chat history. Defaults to True.
         google (bool, optional): Enable Google Search. Defaults to True.
         ytb (bool, optional): Enable YouTube Search. Defaults to True.
+        history (bool, optional): Enable History Search. Defaults to True.
         subtitle (bool, optional): Enable YouTube subtitle. Defaults to True.
         wget (bool, optional): Enable WGET. Defaults to True.
         ocr (bool, optional): Enable OCR. Defaults to True.
@@ -113,6 +116,8 @@ async def handle_utilities(
         await send_to_ocr_bridge(client, message, **kwargs)  # /ocr
     if price:
         await get_asset_price(client, message, **kwargs)  # /price
+    if history:
+        await query_chat_history(client, message, **kwargs)  # /history
     if summary:
         await ai_summary(client, message, **kwargs)  # /summary
     if danmu:
@@ -338,6 +343,8 @@ def get_social_media_help(chat_id: int | str, ctype: str, prefix: str):
         msg += f"\n💵**查询价格**: `{PREFIX.PRICE}` + Symbol"
     if permission["subtitle"]:
         msg += f"\n📃**提取字幕**: `{PREFIX.SUBTITLE}` + B站或油管链接"
+    if permission["history"]:
+        msg += f"\n🗣**查询聊天记录**: 发送 `{PREFIX.HISTORY}` 查看详细教程"
     if permission["wget"]:
         msg += f"\n⏬**下载文件**: `{PREFIX.WGET}` + URL"
     if permission["ytb"]:
src/permission.py
@@ -124,6 +124,7 @@ def check_service(cid: int | str, ctype: str) -> dict:
         "wechat": True,
         "reddit": True,
         "ytdlp": True,
+        "history": True,
     }
 
     if ctype == "PRIVATE":
@@ -171,6 +172,8 @@ def check_service(cid: int | str, ctype: str) -> dict:
         permission["raw_img"] = False
     if not ENABLE.QUERY_DANMU:
         permission["danmu"] = False
+    if not ENABLE.HISTORY:
+        permission["history"] = False
 
     """
     Set specific service
@@ -194,6 +197,7 @@ def check_service(cid: int | str, ctype: str) -> dict:
     return permission
 
 
+@cache.memoize(ttl=0)
 def check_save_history(ctype: str, cid: int | str) -> bool:
     # ruff: noqa: SIM103
     cid = slim_cid(cid)
src/utils.py
@@ -286,6 +286,8 @@ def parse_time(timestr: str) -> dict[str, int]:
     {"year": int, "month": int, "day": int, "hour": int, "minute": int, "second": int}
     """
     res = {"year": 0, "month": 0, "day": 0, "hour": 0, "minute": 0, "second": 0}
+    if not timestr:
+        return {}
     if len(timestr) not in [4, 6, 7, 8, 10, 14, 15, 19]:
         logger.warning(f"Invalid time format: {timestr}")
         return res