Commit 7745834
src/history/query.py
@@ -14,16 +14,7 @@ from database.d1 import query_d1
from database.turso import turso_exec, turso_parse_resp
from history.d1 import get_d1_chatinfo, save_chatinfo_to_d1
from history.turso import get_turso_chatinfo, save_chatinfo_to_turso
-from history.utils import (
- TURSO_KWARGS,
- check_save_history,
- filter_response,
- generate_query,
- get_chat,
- get_user_from_chat,
- is_admin,
- list_chat_ids,
-)
+from history.utils import TURSO_KWARGS, check_save_history, filter_response, get_chat, get_user_from_chat, is_admin, keyword_query, list_chat_ids
from llm.utils import convert_html
from messages.parser import parse_chat, parse_msg
from messages.progress import modify_progress
@@ -120,7 +111,6 @@ async def query_chat_history(client: Client, message: Message, **kwargs):
if not texts:
await modify_progress(text=caption + "\n⚠️未匹配任何记录", force_update=True, **kwargs)
return
-
if len(texts) < 20480 and len(await smart_split(texts)) == 1:
await modify_progress(message=status_msg, text=blockquote(texts), force_update=True, **kwargs)
return
@@ -172,10 +162,6 @@ def parse_queries(texts: str, qtype: str) -> tuple[str, str, str, str, str]:
user = matched.group(1)
keyword = re.sub(rf"^@{user}", "", texts).lstrip() # remove user
- # error handling
- if not keyword:
- error = f"查询格式有误, 必须包含 `关键词`\n请发送 `{PREFIX.HISTORY}` 命令查看帮助"
-
if qtype == "hist":
if not any((match_time, user, keyword)):
error = f"查询格式有误, 请发送 `{PREFIX.HISTORY}` 命令查看帮助"
@@ -206,10 +192,10 @@ async def query_history(
分词结果可能并不准确, 但可以满足大部分需求.
我们还需要对返回的结果进行进一步过滤, 以精确匹配.
"""
- search_query = generate_query(keyword)
- if not search_query:
- return {}
- sql = f"SELECT T.mid, T.mtype, T.time, T.fullname, T.content FROM '{cinfo['tablename']}' AS T JOIN fts_{cinfo['cid']} AS FTS ON T.mid = FTS.rowid WHERE FTS.segmented MATCH '{search_query}'"
+ conditions = []
+ if keyword:
+ conditions.append(keyword_query(keyword))
+
if match_time:
begin = "1970-01-01 00:00:00"
end = nowstr(TZ)
@@ -222,16 +208,30 @@ async def query_history(
elif len(match_time) == 10: # 2025-01-01
begin = f"{match_time} 00:00:00"
end = f"{match_time} 23:59:59"
- sql += f" AND T.time >= '{begin}' AND T.time <= '{end}'"
+ conditions.append(f"T.time >= '{begin}' AND T.time <= '{end}'")
if user:
# 由于username可以修改, 我们优先使用UID进行匹配
real_cid = cinfo["chandle"] if cinfo.get("chandle") else cinfo["cid"] if cinfo["ctype"] in ["BOT", "PRIVATE"] else f"-100{cinfo['cid']}"
if uid := await get_uid_by_username(client, real_cid, user, engine):
- sql += f" AND T.uid = {uid}"
+ conditions.append(f"T.uid = {uid}")
else:
- sql += f" AND T.user = '{user}'"
- sql += " ORDER BY T.mid DESC"
+ conditions.append(f"T.user = '{user}'")
+ limit = 200
+ if conditions:
+ condition = " AND ".join(conditions)
+ sql = f"SELECT T.mid, T.mtype, T.time, T.fullname, T.content FROM '{cinfo['tablename']}' AS T JOIN fts_{cinfo['cid']} AS FTS ON T.mid = FTS.rowid WHERE {condition} ORDER BY T.mid DESC"
+ else:
+ sql = f"SELECT T.mid, T.mtype, T.time, T.fullname, T.content FROM '{cinfo['tablename']}' AS T JOIN fts_{cinfo['cid']} AS FTS ON T.mid = FTS.rowid ORDER BY T.mid DESC"
+ if keyword or match_time:
+ limit = 99999
+ sql += f" LIMIT {limit}"
logger.info(sql)
+ limit_to_single_msg = False
+ if not any((match_time, user, keyword)):
+ limit_to_single_msg = True
+ elif user and (not any((match_time, keyword))): # only user
+ limit_to_single_msg = True
+
if engine == "turso":
resp = await turso_exec([{"type": "execute", "stmt": {"sql": sql}}], silent=True, retry=2, **TURSO_KWARGS)
filterd = filter_response(turso_parse_resp(resp), keyword)
@@ -255,8 +255,16 @@ async def query_history(
end = min(len(content), idx + len(keyword) + 45)
end_suffix = "..." if end != len(content) else ""
content = f"{begin_prefix}{row['content'][begin:end]}{end_suffix}"
- texts += f"\n👤[{username}]({url}) {row['time']}{emoji}:\n{content}\n"
- count += 1
+ entry = f"\n👤[{username}]({url}) {row['time']}{emoji}:\n{content}\n"
+ if limit_to_single_msg:
+ if len(await smart_split(texts + entry)) == 1:
+ texts += entry
+ count += 1
+ else:
+ break
+ else:
+ texts += entry
+ count += 1
return {"texts": texts.strip(), "full_texts": full_texts.strip(), "count": count}
src/history/utils.py
@@ -166,7 +166,7 @@ async def get_user_from_chat(client: Client, uid: int | str, cid: int | str) ->
return user
-def generate_query(keyword: str) -> str:
+def keyword_query(keyword: str) -> str:
"""Generate search query based on keyword."""
# ruff: noqa: RUF001
punctuation = "!#$&*+,-./:;<=>?@[\\]^_`{|}~" + ",。?!:;“”‘’《》"
@@ -189,12 +189,16 @@ def generate_query(keyword: str) -> str:
final.append(f'"{word}"')
if i != length - 1 and segmented[i + 1].strip() not in ["OR", "AND", "NOT", ")"]:
final.append("AND")
- return " ".join(final)
+ return f"""FTS.segmented MATCH '{" ".join(final)}'"""
def filter_response(resp: list[dict], keyword: str) -> list[dict]:
+ """Filter response by keyword."""
+ filtered = [row for row in resp if row["content"]] # remove empty keywords messages
if any(x in keyword for x in ["OR", "AND", "NOT", "(", ")"]): # keyword is a SQL query
- return resp
+ return filtered
+ if not keyword:
+ return filtered
# remove consecutive whitespace
while " " in keyword:
keyword = keyword.replace(" ", " ")
@@ -203,4 +207,4 @@ def filter_response(resp: list[dict], keyword: str) -> list[dict]:
re_keyword = keyword.replace(" ", "(.*?)")
pattern = re.compile(rf"{re_keyword}", flags=re.IGNORECASE)
- return [row for row in resp if pattern.search(row["content"])]
+ return [row for row in filtered if pattern.search(row["content"])]