Commit 6e8ce40
src/history/query.py
@@ -1,7 +1,6 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import re
-import string
from io import BytesIO
from loguru import logger
@@ -9,10 +8,10 @@ from pyrogram.client import Client
from pyrogram.parser.markdown import BLOCKQUOTE_EXPANDABLE_DELIM, BLOCKQUOTE_EXPANDABLE_END_DELIM
from pyrogram.types import Message
-from config import PREFIX, TZ, cache, cutter
+from config import PREFIX, TZ, cache
from database.turso import turso_exec, turso_parse_resp
from history.turso import get_turso_chatinfo, get_user, save_chatinfo_to_turso
-from history.utils import TURSO_KWARGS, check_save_history, get_chat, is_admin, list_chat_ids
+from history.utils import TURSO_KWARGS, check_save_history, filter_response, generate_query, get_chat, is_admin, list_chat_ids
from llm.utils import convert_html
from messages.parser import parse_chat, parse_msg
from messages.progress import modify_progress
@@ -182,11 +181,17 @@ def parse_queries(texts: str, qtype: str) -> tuple[str, str, str, str, str]:
async def query_turso(client: Client, cinfo: dict[str, str], match_time: str, user: str, keyword: str) -> dict:
- """Query chat history from Turso."""
- keyword = keyword.replace(",", " ") # comma can not be used in query
- segmented = [x for x in cutter.cutword(keyword) if x not in string.whitespace]
- texts_to_match = " ".join(segmented)
- sql = f"SELECT T.mid, T.mtype, T.time, T.fullname, T.content FROM '{cinfo['tablename']}' AS T JOIN fts_{cinfo['cid']} AS FTS ON T.mid = FTS.rowid WHERE FTS.segmented MATCH '{texts_to_match}'"
+ """Query chat history from Turso.
+
+ 由于LIKE查询会扫描整个表, 速度较慢, 而且会快速消耗读取数量配额, 因此我们使用FTS5搜索.
+ 由于FTS5不支持中文匹配, 且远端数据库不支持icu分词器, 所以在插入文本时手动进行了分词 (基于`cutword`库)
+ 分词结果可能并不准确, 但可以满足大部分需求.
+ 我们还需要对返回的结果进行进一步过滤, 以精确匹配.
+ """
+ search_query = generate_query(keyword)
+ if not search_query:
+ return {}
+ sql = f"SELECT T.mid, T.mtype, T.time, T.fullname, T.content FROM '{cinfo['tablename']}' AS T JOIN fts_{cinfo['cid']} AS FTS ON T.mid = FTS.rowid WHERE FTS.segmented MATCH '{search_query}'"
if match_time:
begin = "1970-01-01 00:00:00"
end = nowstr(TZ)
@@ -210,10 +215,11 @@ async def query_turso(client: Client, cinfo: dict[str, str], match_time: str, us
sql += " ORDER BY T.mid DESC"
logger.info(sql)
resp = await turso_exec([{"type": "execute", "stmt": {"sql": sql}}], silent=True, retry=2, **TURSO_KWARGS)
+ filterd = filter_response(turso_parse_resp(resp), keyword)
full_texts = ""
texts = "" # long message will be trimmed
count = 0
- for row in turso_parse_resp(resp):
+ for row in filterd:
url = f"https://t.me/{cinfo['chandle']}/{row['mid']}" if cinfo["chandle"] else f"https://t.me/c/{cinfo['cid']}/{row['mid']}"
username = row["fullname"] or "消息链接"
emoji = MTYPE_EMOJI[row["mtype"]] if row["mtype"] != "text" else ""
src/history/utils.py
@@ -2,13 +2,15 @@
# -*- coding: utf-8 -*-
import contextlib
import os
+import re
+import string
from loguru import logger
from pyrogram.client import Client
from pyrogram.errors import PeerIdInvalid
from pyrogram.types import Chat, Message, User
-from config import DB, HISTORY, TID, cache
+from config import DB, HISTORY, TID, cache, cutter
from database.turso import turso_exec, turso_parse_resp
from messages.sender import send2tg
from others.emoji import CTYPE_EMOJI
@@ -148,3 +150,43 @@ async def get_user_from_chat(client: Client, uid: int | str, cid: int | str) ->
user = member.user
break
return user
+
+
+def generate_query(keyword: str) -> str:
+ """Generate search query based on keyword."""
+ # ruff: noqa: RUF001
+ punctuation = "!#$&*+,-./:;<=>?@[\\]^_`{|}~" + ",。?!:;“”‘’《》"
+ for punc in punctuation: # remove pucntuation
+ keyword = keyword.replace(punc, " ")
+ keyword = keyword.replace("(", "(")
+ keyword = keyword.replace(")", ")")
+ # remove consecutive whitespace
+ while " " in keyword:
+ keyword = keyword.replace(" ", " ")
+ # remove leading and trailing whitespace
+ keyword = keyword.strip()
+ segmented = [x for x in cutter.cutword(keyword) if x not in string.whitespace]
+ final = []
+ length = len(segmented)
+ for i, word in enumerate(segmented):
+ if word.strip() in ["OR", "AND", "NOT", "(", ")"]:
+ final.append(word)
+ continue
+ final.append(f'"{word}"')
+ if i != length - 1 and segmented[i + 1].strip() not in ["OR", "AND", "NOT", ")"]:
+ final.append("AND")
+ return " ".join(final)
+
+
+def filter_response(resp: list[dict], keyword: str) -> list[dict]:
+ if any(x in keyword for x in ["OR", "AND", "NOT", "(", ")"]): # keyword is a SQL query
+ return resp
+ # remove consecutive whitespace
+ while " " in keyword:
+ keyword = keyword.replace(" ", " ")
+ # remove leading and trailing whitespace
+ keyword = keyword.strip()
+
+ re_keyword = keyword.replace(" ", "(.*?)")
+ pattern = re.compile(rf"{re_keyword}", flags=re.IGNORECASE)
+ return [row for row in resp if pattern.search(row["content"])]