Commit 7745834

benny-dou <60535774+benny-dou@users.noreply.github.com>
2025-08-26 10:52:56
feat(history): allow empty keyword query
1 parent 7f322a4
Changed files (2)
src
src/history/query.py
@@ -14,16 +14,7 @@ from database.d1 import query_d1
 from database.turso import turso_exec, turso_parse_resp
 from history.d1 import get_d1_chatinfo, save_chatinfo_to_d1
 from history.turso import get_turso_chatinfo, save_chatinfo_to_turso
-from history.utils import (
-    TURSO_KWARGS,
-    check_save_history,
-    filter_response,
-    generate_query,
-    get_chat,
-    get_user_from_chat,
-    is_admin,
-    list_chat_ids,
-)
+from history.utils import TURSO_KWARGS, check_save_history, filter_response, get_chat, get_user_from_chat, is_admin, keyword_query, list_chat_ids
 from llm.utils import convert_html
 from messages.parser import parse_chat, parse_msg
 from messages.progress import modify_progress
@@ -120,7 +111,6 @@ async def query_chat_history(client: Client, message: Message, **kwargs):
     if not texts:
         await modify_progress(text=caption + "\n⚠️未匹配任何记录", force_update=True, **kwargs)
         return
-
     if len(texts) < 20480 and len(await smart_split(texts)) == 1:
         await modify_progress(message=status_msg, text=blockquote(texts), force_update=True, **kwargs)
         return
@@ -172,10 +162,6 @@ def parse_queries(texts: str, qtype: str) -> tuple[str, str, str, str, str]:
         user = matched.group(1)
     keyword = re.sub(rf"^@{user}", "", texts).lstrip()  # remove user
 
-    # error handling
-    if not keyword:
-        error = f"查询格式有误, 必须包含 `关键词`\n请发送 `{PREFIX.HISTORY}` 命令查看帮助"
-
     if qtype == "hist":
         if not any((match_time, user, keyword)):
             error = f"查询格式有误, 请发送 `{PREFIX.HISTORY}` 命令查看帮助"
@@ -206,10 +192,10 @@ async def query_history(
     分词结果可能并不准确, 但可以满足大部分需求.
     我们还需要对返回的结果进行进一步过滤, 以精确匹配.
     """
-    search_query = generate_query(keyword)
-    if not search_query:
-        return {}
-    sql = f"SELECT T.mid, T.mtype, T.time, T.fullname, T.content FROM '{cinfo['tablename']}' AS T JOIN fts_{cinfo['cid']} AS FTS ON T.mid = FTS.rowid WHERE FTS.segmented MATCH '{search_query}'"
+    conditions = []
+    if keyword:
+        conditions.append(keyword_query(keyword))
+
     if match_time:
         begin = "1970-01-01 00:00:00"
         end = nowstr(TZ)
@@ -222,16 +208,30 @@ async def query_history(
         elif len(match_time) == 10:  # 2025-01-01
             begin = f"{match_time} 00:00:00"
             end = f"{match_time} 23:59:59"
-        sql += f" AND T.time >= '{begin}' AND T.time <= '{end}'"
+        conditions.append(f"T.time >= '{begin}' AND T.time <= '{end}'")
     if user:
         # 由于username可以修改, 我们优先使用UID进行匹配
         real_cid = cinfo["chandle"] if cinfo.get("chandle") else cinfo["cid"] if cinfo["ctype"] in ["BOT", "PRIVATE"] else f"-100{cinfo['cid']}"
         if uid := await get_uid_by_username(client, real_cid, user, engine):
-            sql += f" AND T.uid = {uid}"
+            conditions.append(f"T.uid = {uid}")
         else:
-            sql += f" AND T.user = '{user}'"
-    sql += " ORDER BY T.mid DESC"
+            conditions.append(f"T.user = '{user}'")
+    limit = 200
+    if conditions:
+        condition = " AND ".join(conditions)
+        sql = f"SELECT T.mid, T.mtype, T.time, T.fullname, T.content FROM '{cinfo['tablename']}' AS T JOIN fts_{cinfo['cid']} AS FTS ON T.mid = FTS.rowid WHERE {condition}  ORDER BY T.mid DESC"
+    else:
+        sql = f"SELECT T.mid, T.mtype, T.time, T.fullname, T.content FROM '{cinfo['tablename']}' AS T JOIN fts_{cinfo['cid']} AS FTS ON T.mid = FTS.rowid ORDER BY T.mid DESC"
+    if keyword or match_time:
+        limit = 99999
+    sql += f" LIMIT {limit}"
     logger.info(sql)
+    limit_to_single_msg = False
+    if not any((match_time, user, keyword)):
+        limit_to_single_msg = True
+    elif user and (not any((match_time, keyword))):  # only user
+        limit_to_single_msg = True
+
     if engine == "turso":
         resp = await turso_exec([{"type": "execute", "stmt": {"sql": sql}}], silent=True, retry=2, **TURSO_KWARGS)
         filterd = filter_response(turso_parse_resp(resp), keyword)
@@ -255,8 +255,16 @@ async def query_history(
             end = min(len(content), idx + len(keyword) + 45)
             end_suffix = "..." if end != len(content) else ""
             content = f"{begin_prefix}{row['content'][begin:end]}{end_suffix}"
-        texts += f"\n👤[{username}]({url}) {row['time']}{emoji}:\n{content}\n"
-        count += 1
+        entry = f"\n👤[{username}]({url}) {row['time']}{emoji}:\n{content}\n"
+        if limit_to_single_msg:
+            if len(await smart_split(texts + entry)) == 1:
+                texts += entry
+                count += 1
+            else:
+                break
+        else:
+            texts += entry
+            count += 1
     return {"texts": texts.strip(), "full_texts": full_texts.strip(), "count": count}
 
 
src/history/utils.py
@@ -166,7 +166,7 @@ async def get_user_from_chat(client: Client, uid: int | str, cid: int | str) ->
     return user
 
 
-def generate_query(keyword: str) -> str:
+def keyword_query(keyword: str) -> str:
     """Generate search query based on keyword."""
     # ruff: noqa: RUF001
     punctuation = "!#$&*+,-./:;<=>?@[\\]^_`{|}~" + ",。?!:;“”‘’《》"
@@ -189,12 +189,16 @@ def generate_query(keyword: str) -> str:
         final.append(f'"{word}"')
         if i != length - 1 and segmented[i + 1].strip() not in ["OR", "AND", "NOT", ")"]:
             final.append("AND")
-    return " ".join(final)
+    return f"""FTS.segmented MATCH '{" ".join(final)}'"""
 
 
 def filter_response(resp: list[dict], keyword: str) -> list[dict]:
+    """Filter response by keyword."""
+    filtered = [row for row in resp if row["content"]]  # drop rows whose content is empty
     if any(x in keyword for x in ["OR", "AND", "NOT", "(", ")"]):  # keyword is a SQL query
-        return resp
+        return filtered
+    if not keyword:
+        return filtered
     # remove consecutive whitespace
     while "  " in keyword:
         keyword = keyword.replace("  ", " ")
@@ -203,4 +207,4 @@ def filter_response(resp: list[dict], keyword: str) -> list[dict]:
 
     re_keyword = keyword.replace(" ", "(.*?)")
     pattern = re.compile(rf"{re_keyword}", flags=re.IGNORECASE)
-    return [row for row in resp if pattern.search(row["content"])]
+    return [row for row in filtered if pattern.search(row["content"])]