Commit 6e8ce40

benny-dou <60535774+benny-dou@users.noreply.github.com>
2025-07-21 02:45:14
chore(history): improve query accuracy
1 parent c81953f
Changed files (2)
src
src/history/query.py
@@ -1,7 +1,6 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
 import re
-import string
 from io import BytesIO
 
 from loguru import logger
@@ -9,10 +8,10 @@ from pyrogram.client import Client
 from pyrogram.parser.markdown import BLOCKQUOTE_EXPANDABLE_DELIM, BLOCKQUOTE_EXPANDABLE_END_DELIM
 from pyrogram.types import Message
 
-from config import PREFIX, TZ, cache, cutter
+from config import PREFIX, TZ, cache
 from database.turso import turso_exec, turso_parse_resp
 from history.turso import get_turso_chatinfo, get_user, save_chatinfo_to_turso
-from history.utils import TURSO_KWARGS, check_save_history, get_chat, is_admin, list_chat_ids
+from history.utils import TURSO_KWARGS, check_save_history, filter_response, generate_query, get_chat, is_admin, list_chat_ids
 from llm.utils import convert_html
 from messages.parser import parse_chat, parse_msg
 from messages.progress import modify_progress
@@ -182,11 +181,17 @@ def parse_queries(texts: str, qtype: str) -> tuple[str, str, str, str, str]:
 
 
 async def query_turso(client: Client, cinfo: dict[str, str], match_time: str, user: str, keyword: str) -> dict:
-    """Query chat history from Turso."""
-    keyword = keyword.replace(",", " ")  # comma can not be used in query
-    segmented = [x for x in cutter.cutword(keyword) if x not in string.whitespace]
-    texts_to_match = " ".join(segmented)
-    sql = f"SELECT T.mid, T.mtype, T.time, T.fullname, T.content FROM '{cinfo['tablename']}' AS T JOIN fts_{cinfo['cid']} AS FTS ON T.mid = FTS.rowid WHERE FTS.segmented MATCH '{texts_to_match}'"
+    """Query chat history from Turso.
+
+    由于LIKE查询会扫描整个表, 速度较慢, 而且会快速消耗读取数量配额, 因此我们使用FTS5搜索.
+    由于FTS5不支持中文匹配, 且远端数据库不支持icu分词器, 所以在插入文本时手动进行了分词 (基于`cutword`库)
+    分词结果可能并不准确, 但可以满足大部分需求.
+    我们还需要对返回的结果进行进一步过滤, 以精确匹配.
+    """
+    search_query = generate_query(keyword)
+    if not search_query:
+        return {}
+    sql = f"SELECT T.mid, T.mtype, T.time, T.fullname, T.content FROM '{cinfo['tablename']}' AS T JOIN fts_{cinfo['cid']} AS FTS ON T.mid = FTS.rowid WHERE FTS.segmented MATCH '{search_query}'"
     if match_time:
         begin = "1970-01-01 00:00:00"
         end = nowstr(TZ)
@@ -210,10 +215,11 @@ async def query_turso(client: Client, cinfo: dict[str, str], match_time: str, us
     sql += " ORDER BY T.mid DESC"
     logger.info(sql)
     resp = await turso_exec([{"type": "execute", "stmt": {"sql": sql}}], silent=True, retry=2, **TURSO_KWARGS)
+    filterd = filter_response(turso_parse_resp(resp), keyword)
     full_texts = ""
     texts = ""  # long message will be trimmed
     count = 0
-    for row in turso_parse_resp(resp):
+    for row in filterd:
         url = f"https://t.me/{cinfo['chandle']}/{row['mid']}" if cinfo["chandle"] else f"https://t.me/c/{cinfo['cid']}/{row['mid']}"
         username = row["fullname"] or "消息链接"
         emoji = MTYPE_EMOJI[row["mtype"]] if row["mtype"] != "text" else ""
src/history/utils.py
@@ -2,13 +2,15 @@
 # -*- coding: utf-8 -*-
 import contextlib
 import os
+import re
+import string
 
 from loguru import logger
 from pyrogram.client import Client
 from pyrogram.errors import PeerIdInvalid
 from pyrogram.types import Chat, Message, User
 
-from config import DB, HISTORY, TID, cache
+from config import DB, HISTORY, TID, cache, cutter
 from database.turso import turso_exec, turso_parse_resp
 from messages.sender import send2tg
 from others.emoji import CTYPE_EMOJI
@@ -148,3 +150,43 @@ async def get_user_from_chat(client: Client, uid: int | str, cid: int | str) ->
                     user = member.user
                     break
     return user
+
+
+def generate_query(keyword: str) -> str:
+    """Build an FTS5 MATCH expression from a user-supplied keyword.
+
+    Punctuation is replaced by spaces, the text is segmented with
+    ``cutter.cutword`` and every resulting term is emitted as a double-quoted
+    FTS5 string.  An explicit ``AND`` is inserted after a term unless the next
+    token is ``OR`` / ``AND`` / ``NOT`` or a closing parenthesis, so operators
+    and parentheses typed by the user are passed through unchanged.
+
+    Args:
+        keyword: Raw search text typed by the user.
+
+    Returns:
+        The MATCH expression, or an empty string when nothing searchable
+        is left after cleaning.
+    """
+    # ruff: noqa: RUF001
+    # Quotes are stripped too: a `"` would terminate the double-quoted FTS5 string
+    # built below, and a `'` would terminate the SQL literal the caller embeds
+    # the expression in.
+    ascii_punctuation = "!#$&*+,-./:;<=>?@[\\]^_`{|}~" + "\"'"
+    # NOTE(review): the original replace calls mapped "(" to itself; the RUF001
+    # suppression suggests full-width characters were intended, so they are spelled
+    # as escapes here - confirm against the repository copy.
+    wide_punctuation = "\uff0c\u3002\uff1f\uff01\uff1a\uff1b\u201c\u201d\u2018\u2019\u300a\u300b"
+    table = {ord(punc): " " for punc in ascii_punctuation + wide_punctuation}
+    table[0xFF08] = "("  # full-width left parenthesis
+    table[0xFF09] = ")"  # full-width right parenthesis
+    # one pass for all replacements, then collapse/trim whitespace
+    keyword = " ".join(keyword.translate(table).split())
+    if not keyword:
+        return ""
+    segmented = [x for x in cutter.cutword(keyword) if x not in string.whitespace]
+    operators = ("OR", "AND", "NOT")
+    final = []
+    last = len(segmented) - 1
+    for i, word in enumerate(segmented):
+        if word.strip() in (*operators, "(", ")"):
+            final.append(word)
+            continue
+        final.append(f'"{word}"')
+        # join adjacent terms explicitly unless the user supplied the connective
+        if i != last and segmented[i + 1].strip() not in (*operators, ")"):
+            final.append("AND")
+    return " ".join(final)
+
+
+def filter_response(resp: list[dict], keyword: str) -> list[dict]:
+    """Drop rows whose content does not literally contain the keyword.
+
+    When the keyword uses query syntax (a standalone ``OR`` / ``AND`` / ``NOT``
+    or a parenthesis) the rows are returned untouched.  Otherwise the
+    whitespace-separated parts of the keyword must appear in ``row["content"]``
+    in the given order, compared case-insensitively.
+
+    Args:
+        resp: Parsed result rows; each row is read through ``row["content"]``.
+        keyword: Raw keyword as typed by the user.
+
+    Returns:
+        ``resp`` itself for query-syntax keywords, else a new filtered list.
+    """
+    # Only standalone operators count, so plain words such as "ORANGE" or
+    # "ANDROID" are still filtered as literal text.
+    if re.search(r"[()]|(?<![A-Za-z0-9_])(?:OR|AND|NOT)(?![A-Za-z0-9_])", keyword):
+        return resp
+    # split() both trims the keyword and collapses runs of whitespace
+    parts = keyword.split()
+    # escape every part so that text like "c++" or "a.b" is matched literally
+    re_keyword = "(.*?)".join(re.escape(part) for part in parts)
+    pattern = re.compile(re_keyword, flags=re.IGNORECASE)
+    # a row without text content can never match a literal keyword
+    return [row for row in resp if pattern.search(row["content"] or "")]