Commit a2d4a0b
src/database/d1.py
@@ -34,19 +34,102 @@ async def create_d1_database(
@cache.memoize(ttl=0)
-async def create_d1_table(table_name: str | float, columns: str, db_name: str = "bennybot", *, silent: bool = False) -> None:
- """Create D1 database and return DatabaseID."""
+async def create_d1_table(
+ table_name: str | float,
+ columns: str,
+ *,
+ idx_cols: list[str] | None = None,
+ idx_prefix: str = "idx_",
+ fts_on_col: str | None = None,
+ fts_index_col: str = "segmented",
+ fts_name: str | None = None,
+ db_name: str = "bennybot",
+ silent: bool = False,
+) -> None:
+ """Create a D1 table.
+
+    If `idx_cols` is provided, create indexes for these columns.
+
+ idx_cols should be a list of strings, the created index names prefixed by `idx_prefix`
+ for example:
+ idx_prefix = "idx_"
+ idx_cols = ["uid", "time"]
+ indexs = ["idx_uid", "idx_time"]
+
+ # create FTS table for Chinese search
+    If `fts_on_col` is provided, create an FTS5 table with `fts_on_col` as the content rowid column.
+ the `fts_index_col` is the column used for FTS5 indexing.
+ """
database_id = await create_d1_database(db_name, silent=silent)
if not database_id:
return
tables = await list_d1_tables(db_name, silent=silent)
- if table_name in tables:
- return
+ if table_name not in tables:
+ sql = f'CREATE TABLE IF NOT EXISTS "{table_name}" ({columns});'
+ resp = await query_d1(sql, database_id, silent=silent)
+ if resp.get("success"):
+ logger.success(f"Create Table {table_name} in D1 database {db_name}")
- sql = f'CREATE TABLE IF NOT EXISTS "{table_name}" ({columns});'
- await query_d1(sql, database_id, silent=silent)
- if not silent:
- logger.success(f"Create Table {table_name} in D1 database {db_name}")
+    # create indexes if idx_cols is not None
+ if idx_cols is not None:
+ resp = await query_d1("SELECT name FROM sqlite_master WHERE type='index';", db_id=database_id, silent=silent)
+ indexs = glom(resp, "result.0.results.*.name", default=[])
+ for idx_name in idx_cols:
+ if idx_name not in columns:
+ logger.warning(f"Index {idx_name} not in columns {columns}")
+ continue
+ if f"{idx_prefix}{idx_name}" not in indexs:
+ resp = await query_d1(f'CREATE INDEX IF NOT EXISTS "{idx_prefix}{idx_name}" ON "{table_name}"({idx_name})', db_id=database_id, silent=silent)
+ if resp.get("success"):
+ logger.success(f'Create Index "{idx_prefix}{idx_name}" of table "{table_name}" in D1 database "{db_name}"')
+
+ if fts_on_col is not None:
+ # 列出所有虚拟表
+ resp = await query_d1('SELECT name FROM pragma_table_list WHERE type="virtual";', db_id=database_id, silent=silent)
+ virtual_tables = flatten(glom(resp, "result.*.results.*.name", default=[]))
+ """创建 FTS5 虚拟表
+ -- content=table_name 指明关联的原表
+ -- content_rowid=fts_on_col 指明原表的行 ID 列是 fts_on_col
+ -- fts_index_col 是我们要索引的列
+ -- tokenize='unicode61' 使用 unicode61 分词器, 对多种语言支持更好
+ """
+ fts_table = f"fts_{table_name}" if fts_name is None else f"fts_{fts_name}"
+
+ if fts_table not in virtual_tables:
+ logger.debug(f"Creating FTS5 virtual table for {table_name}")
+ sql = f"CREATE VIRTUAL TABLE IF NOT EXISTS '{fts_table}' USING fts5({fts_index_col}, content='{table_name}', content_rowid={fts_on_col}, tokenize='unicode61');"
+ await query_d1(sql, db_id=database_id, silent=silent)
+
+ """将现有数据从原表复制到 FTS 表
+ 注意, 我们在这里插入的是 rowid (它会对应到 content_rowid=fts_on_col 指定的列) 和 content
+ 从原表中选择 fts_on_col 和 segmented 列。fts_on_col 列的值会被插入到 FTS 表中对应原表 rowid (或 content_rowid) 的位置。
+ """
+ sql = f"INSERT INTO '{fts_table}' (rowid, {fts_index_col}) SELECT {fts_on_col}, {fts_index_col} FROM '{table_name}' WHERE {fts_on_col} NOT IN (SELECT rowid FROM '{fts_table}');"
+ await query_d1(sql, db_id=database_id, silent=silent)
+
+ # 列出所有触发器
+ resp = await query_d1('SELECT name FROM sqlite_master WHERE type="trigger";', db_id=database_id, silent=silent)
+ triggers = flatten(glom(resp, "result.*.results.*.name", default=[]))
+ """维护 FTS 表
+ 为了让 FTS 表与原表保持同步, 需要在原表上创建触发器。
+ 在原表插入、删除、更新时, 同步更新 FTS 表
+ """
+ trigger_prefix = f"trigger_{table_name}" if fts_name is None else f"trigger_{fts_name}"
+        # 创建触发器, 在原表插入数据时, 同步向 FTS 表插入
+ if f"{trigger_prefix}_ai" not in triggers:
+ sql = f"CREATE TRIGGER IF NOT EXISTS '{trigger_prefix}_ai' AFTER INSERT ON '{table_name}' BEGIN INSERT INTO '{fts_table}' (rowid, {fts_index_col}) VALUES (NEW.{fts_on_col}, NEW.{fts_index_col}); END;"
+ await query_d1(sql, db_id=database_id, silent=silent)
+
+ # 创建触发器, 在原表删除数据时, 同步从 FTS 表删除
+ if f"{trigger_prefix}_ad" not in triggers:
+ sql = f"CREATE TRIGGER IF NOT EXISTS '{trigger_prefix}_ad' AFTER DELETE ON '{table_name}' BEGIN DELETE FROM '{fts_table}' WHERE rowid = OLD.{fts_on_col}; END;"
+ await query_d1(sql, db_id=database_id, silent=silent)
+
+ # 创建触发器, 在原表更新数据时, 同步更新 FTS 表
+ # FTS5 的更新通常是先删除旧的, 再插入新的
+ if f"{trigger_prefix}_au" not in triggers:
+ sql = f"CREATE TRIGGER IF NOT EXISTS '{trigger_prefix}_au' AFTER UPDATE ON '{table_name}' BEGIN DELETE FROM '{fts_table}' WHERE rowid = OLD.{fts_on_col} AND OLD.{fts_index_col} <> NEW.{fts_index_col}; INSERT INTO '{fts_table}' (rowid, {fts_index_col}) SELECT NEW.{fts_on_col}, NEW.{fts_index_col} WHERE OLD.{fts_index_col} <> NEW.{fts_index_col}; END;"
+ await query_d1(sql, db_id=database_id, silent=silent)
@cache.memoize(ttl=600)
@@ -84,3 +167,18 @@ async def query_d1(
if not silent:
logger.trace(f"Query CF-D1: {payload}")
return await hx_req(api, "POST", json_data=payload, headers=headers, check_kv={"success": True}, proxy=PROXY.D1, max_retry=0, silent=silent)
+
+
+def insert_d1(table_name: str, records: dict, update_on_conflict: str = "") -> dict:
+ """Create a D1 insert SQL.
+
+ Returns:
+ dict: {"sql": sql, "params": params}
+ """
+ keys = ", ".join(records)
+ values = ", ".join(["?" for _ in range(len(records))])
+ sql = f"INSERT INTO '{table_name}' ({keys}) VALUES ({values});"
+ if update_on_conflict:
+ updates = ", ".join([f"{k} = EXCLUDED.{k}" for k in records if k != update_on_conflict])
+ sql = f"INSERT INTO '{table_name}' ({keys}) VALUES ({values}) ON CONFLICT ({update_on_conflict}) DO UPDATE SET {updates};"
+ return {"sql": sql, "params": list(records.values())}
src/history/d1.py
@@ -2,111 +2,142 @@
# -*- coding: utf-8 -*-
import asyncio
import json
+import os
from datetime import datetime, timedelta
from pathlib import Path
+from typing import Literal
from zoneinfo import ZoneInfo
-from glom import Coalesce, flatten, glom
+from glom import Coalesce, glom
from loguru import logger
from pyrogram.client import Client
-from pyrogram.errors import PeerIdInvalid
-from pyrogram.types import Chat, Message
+from pyrogram.types import Message
from config import DOWNLOAD_DIR, HISTORY, TZ, cache, cutter
-from database.d1 import create_d1_database, create_d1_table, query_d1
-from history.utils import check_save_history
-from messages.parser import parse_msg
-from utils import i_am_bot, nowdt
-
-DB_COLUMNS = "mid INTEGER PRIMARY KEY, mtype TEXT, time TEXT NOT NULL, user TEXT, content TEXT, filename TEXT, urls TEXT, reply INTEGER, mime TEXT, uid INTEGER, segmented TEXT"
-INDEX_NAMES = ["time", "uid"]
+from database.d1 import create_d1_table, insert_d1, query_d1
+from history.utils import CHAT_COLUMNS, MSG_COLUMNS, MSG_INDEXES, USER_COLUMNS, USER_INDEXES, check_save_history, fine_grained_check, get_chat
+from messages.parser import parse_chat, parse_msg
+from utils import i_am_bot, nowdt, slim_cid, to_int, true
async def sync_history_to_d1(client: Client, message: Message) -> None:
+ """Sync received messages to D1 database.
+
+ 1. save the user info to table `userinfo`
+ 2. save the chat info to table `chatinfo`
+ 3. save the message to table `{cid}-{ctitle}`
+ """
if not HISTORY.D1_ENABLE:
return
- info = parse_msg(message, silent=True)
- if not check_save_history(info["ctype"], info["cid"]):
+ info = parse_msg(message, silent=True, use_cache=False)
+ if not check_save_history(info["ctype"], info["cid"]) or not fine_grained_check(info) or message.service:
return
- table_name = await get_table_name(client, info["cid"])
+ await save_userinfo_to_d1(client, info)
+ chatinfo = await save_chatinfo_to_d1(client, info)
records = {
"mid": info["mid"],
+ "mtype": info["mtype"],
"time": info["time"],
- "user": info["full_name"],
+ "fullname": info["full_name"],
"content": message.content, # text or edited text
- "mtype": info["mtype"],
- "uid": info["uid"],
"filename": info["file_name"],
- "mime": info["mime_type"],
"urls": "\n\n".join(info["entity_urls"]),
"reply": message.reply_to_message_id,
+ "mime": info["mime_type"],
+ "user": info["full_name"].replace(" ", ""),
+ "handle": info["handle"],
+ "uid": info["uid"],
+ "gid": info["media_group_id"],
"segmented": " ".join(cutter.cutword(message.content)),
}
- keys = ", ".join(records)
- values = ", ".join(["?" for _ in range(len(records))])
- updates = ", ".join([f"{k} = EXCLUDED.{k}" for k in records if k != "mid"])
- sql = f'INSERT INTO "{table_name}" ({keys}) VALUES ({values}) ON CONFLICT (mid) DO UPDATE SET {updates};'
- await query_d1(sql, db_name=HISTORY.D1_DATABASE, params=list(records.values()), silent=True)
+ await query_d1(**insert_d1(chatinfo["tablename"], records, update_on_conflict="mid"), db_name=HISTORY.D1_DATABASE, silent=True)
+
+async def backup_chat_history_to_d1(
+ client: Client,
+ chat_id: str | int,
+ hours: float = HISTORY.BACKUP_CHATS_HOURS,
+ *,
+ start_from: Literal["latest", "oldest"] = "latest",
+ max_sync: float = float("inf"),
+) -> None:
+ """Backup chat history to D1 database.
-async def backup_chat_history_to_d1(client: Client, chat_id: str | int, hours: float = HISTORY.BACKUP_CHATS_HOURS) -> None:
+ If start_from is "oldest", find the minimum message id of this chat, then use this mid as `offset_id` to retrieve messages.
+ """
if not HISTORY.D1_ENABLE:
return
if await i_am_bot(client):
return
- chat = await chat_info(client, int(chat_id))
- table_name = await get_table_name(client, chat_id)
- # find message ids in this time range
+ chatinfo = await get_d1_chatinfo(chat_id)
+ if not chatinfo: # this chat is never synced
+ chat = await get_chat(client, to_int(chat_id))
+ chatinfo = await save_chatinfo_to_d1(client, parse_chat(chat))
+ if not chatinfo: # chat is deleted
+ return
+ if true(os.getenv(f"HISTORY_IGNORE_{chatinfo['cid']}")):
+ return
+ table_name = chatinfo["tablename"]
now = nowdt(TZ)
begin_dt = now - timedelta(hours=hours)
begin_time = begin_dt.strftime("%Y-%m-%d %H:%M:%S")
- end_time = now.strftime("%Y-%m-%d %H:%M:%S")
- sql = f'SELECT mid FROM "{table_name}" WHERE time >= "{begin_time}" AND time <= "{end_time}";'
- resp = await query_d1(sql, db_name=HISTORY.D1_DATABASE, silent=True)
- saved_mids = glom(resp, "result.0.results.*.mid", default=[])
- saved_mids = {int(x) for x in saved_mids}
- logger.info(f"Found {len(saved_mids)} messages in D1. Rows read: {glom(resp, 'result.0.meta.rows_read', default=1)}")
- concurrency = 200
+ if start_from == "oldest":
+ sql = f'SELECT mid FROM "{table_name}" ORDER BY mid ASC LIMIT 1'
+ resp = await query_d1(sql, db_name=HISTORY.D1_DATABASE, silent=True)
+ offset_id = glom(resp, "result.0.results.0.mid", default=1)
+ saved_mids = {int(offset_id)}
+ else:
+ # find message ids in this time range
+ end_time = now.strftime("%Y-%m-%d %H:%M:%S")
+ sql = f'SELECT mid FROM "{table_name}" WHERE time >= "{begin_time}" AND time <= "{end_time}"'
+ resp = await query_d1(sql, db_name=HISTORY.D1_DATABASE, silent=True)
+ saved_mids = glom(resp, "result.0.results.*.mid", default=[])
+ saved_mids = {int(x) for x in saved_mids}
+ offset_id = 0 # retrieve from latest message
+ logger.info(f"Found {len(saved_mids)} messages of {table_name} in D1")
+ concurrency = 100
+ num_sync = 0
tasks = []
- async for message in client.get_chat_history(chat.id): # type: ignore
- if not isinstance(message, Message) or message.empty:
+ real_cid = chatinfo["chandle"] or (int(chatinfo["cid"]) if chatinfo["ctype"] in ["BOT", "PRIVATE"] else int(f"-100{chatinfo['cid']}"))
+ async for message in client.get_chat_history(real_cid, offset_id=offset_id): # type: ignore
+ if not isinstance(message, Message) or message.empty or message.service or message.id in saved_mids:
continue
- info = parse_msg(message, silent=True)
- if info["mid"] in saved_mids:
+ info = parse_msg(message, silent=True, use_cache=False)
+ if not fine_grained_check(info):
continue
if info["time"] < begin_time:
break
+ if num_sync >= max_sync:
+ break
+ num_sync += 1
records = {
"mid": info["mid"],
- "time": info["time"],
- "user": info["full_name"],
- "content": info["text"],
"mtype": info["mtype"],
- "uid": info["uid"],
+ "time": info["time"],
+ "fullname": info["full_name"],
+ "content": message.content,
"filename": info["file_name"],
- "mime": info["mime_type"],
"urls": "\n\n".join(info["entity_urls"]),
"reply": message.reply_to_message_id,
+ "mime": info["mime_type"],
+ "user": info["full_name"].replace(" ", ""),
+ "handle": info["handle"],
+ "uid": info["uid"],
+ "gid": info["media_group_id"],
"segmented": " ".join(cutter.cutword(info["text"])),
}
- logger.trace(f"Syncing message {table_name}: {info['mid']}")
- keys = ", ".join(records)
- values = ", ".join(["?" for _ in range(len(records))])
- updates = ", ".join([f"{k} = EXCLUDED.{k}" for k in records if k != "mid"])
- sql = f'INSERT INTO "{table_name}" ({keys}) VALUES ({values}) ON CONFLICT (mid) DO UPDATE SET {updates};'
- tasks.append(query_d1(sql, db_name=HISTORY.D1_DATABASE, params=list(records.values()), silent=True))
+ logger.trace(f"Syncing {table_name}: {info['mid']}")
+ tasks.append(query_d1(**insert_d1(table_name, records, update_on_conflict="mid"), db_name=HISTORY.D1_DATABASE, silent=True))
if len(tasks) == concurrency:
res = await asyncio.gather(*tasks, return_exceptions=True)
num_success = sum(glom(res, "*.success"))
- if sync_ids := glom(res, "*.result.0.meta.last_row_id", default=0):
- logger.success(f"Synced {num_success} messages to D1, {min(sync_ids)} -> {max(sync_ids)}")
+ logger.success(f"Synced {num_success} messages to D1")
tasks = []
+
if tasks:
res = await asyncio.gather(*tasks, return_exceptions=True)
num_success = sum(glom(res, "*.success"))
- if sync_ids := glom(res, "*.result.0.meta.last_row_id", default=0):
- logger.success(f"Synced {num_success} messages to D1, {min(sync_ids)} -> {max(sync_ids)}")
- tasks = []
+ logger.success(f"Synced {num_success} messages to D1")
async def upload_exported_history_to_d1(client: Client, path: str | Path | None = None) -> None:
@@ -132,23 +163,36 @@ async def upload_exported_history_to_d1(client: Client, path: str | Path | None
with path.open("r") as f: # noqa: ASYNC230
data = json.load(f)
+ logger.info(f"Found {len(data['messages'])} messages in json file")
+    """Since the exported history does not have media_group_id,
+    we first process all messages and add media_group_id to each.
+ If two consecutive messages have the same `from_id` and `date_unixtime`,
+ and the message type is photo or video, these messages will be considered as a media group.
+ """
+ last_msg = {}
+ for idx, msg in enumerate(data["messages"]):
+ if all(msg.get(key) == last_msg.get(key) for key in ["from_id", "date_unixtime"]) and any(key in msg for key in ["photo", "thumbnail"]):
+ data["messages"][idx - 1]["media_group_id"] = glom(data["messages"][idx - 1], Coalesce("media_group_id", "id"))
+ data["messages"][idx]["media_group_id"] = glom(data["messages"][idx - 1], Coalesce("media_group_id", "id"))
+ last_msg = msg
+
mtypes = {
- "text": "text",
- "photo": "photo",
- "animation": "animation",
"audio_file": "audio",
- "sticker": "sticker",
"voice_message": "voice",
"video_message": "video",
"video_file": "video",
}
- table_name = await get_table_name(client, data["id"])
+ chat_id = data["id"]
+ chatinfo = await get_d1_chatinfo(chat_id)
+ if not chatinfo: # this chat is never synced
+ chat = await get_chat(client, int(chat_id))
+ chatinfo = await save_chatinfo_to_d1(client, parse_chat(chat))
+ table_name = chatinfo["tablename"]
# find all message_ids
- sql = f'SELECT mid FROM "{table_name}" ORDER BY mid;'
- resp = await query_d1(sql, db_name=HISTORY.D1_DATABASE, silent=True)
+ resp = await query_d1(f'SELECT mid FROM "{table_name}";', db_name=HISTORY.D1_DATABASE, silent=True)
saved_ids = glom(resp, "result.0.results.*.mid", default=[])
saved_ids = {int(x) for x in saved_ids}
- concurrency = 200
+ concurrency = 100
tasks = []
for info in [msg for msg in data["messages"] if msg["id"] not in saved_ids]: # type: ignore
if info["type"] != "message":
@@ -160,144 +204,168 @@ async def upload_exported_history_to_d1(client: Client, path: str | Path | None
info["media_type"] = "photo"
if "video/" in info.get("mime_type", ""):
info["media_type"] = "video_file"
-
+ mtype = info.get("media_type", "text")
+ content = parse_text(info.get("text", []))
+ urls = parse_urls(info.get("text_entities", []))
+ # fine-grained check requires key: ["cid", "mtype", "text", "entity_urls"]
+ if not fine_grained_check({"cid": data["id"], "mtype": mtype, "text": content, "entity_urls": urls}):
+ continue
dt = datetime.fromtimestamp(int(info["date_unixtime"]), tz=ZoneInfo(TZ))
uid = int(info["from_id"].removeprefix("user").removeprefix("channel"))
- user = info["from"]
+ user = info["from"] or info["from_id"].removeprefix("user").removeprefix("channel")
if user == data["name"] and data["type"] in ["public_channel", "private_channel"]: # user is not shown
user = ""
uid = 1
- content = parse_text(info.get("text", []))
records = {
"mid": info["id"],
+ "mtype": mtypes.get(mtype, mtype),
"time": dt.strftime("%Y-%m-%d %H:%M:%S"),
- "user": user,
- "content": parse_text(info.get("text", [])),
- "mtype": mtypes[info.get("media_type", "text")],
- "uid": uid,
+ "fullname": user,
+ "content": content,
"filename": info.get("file_name", ""),
- "mime": info.get("mime_type", ""),
- "urls": parse_urls(info.get("text_entities", [])),
+ "urls": urls,
"reply": info.get("reply_to_message_id"),
+ "mime": info.get("mime_type", ""),
+ "user": user.replace(" ", ""),
+ "handle": "", # TODO: parse handle
+ "uid": uid,
+ "gid": info.get("media_group_id", 0),
"segmented": " ".join(cutter.cutword(content)),
}
- logger.trace(f"Syncing message {table_name}: {info['id']}")
- keys = ", ".join(records)
- values = ", ".join(["?" for _ in range(len(records))])
- updates = ", ".join([f"{k} = EXCLUDED.{k}" for k in records if k != "mid"])
- sql = f'INSERT INTO "{table_name}" ({keys}) VALUES ({values}) ON CONFLICT (mid) DO UPDATE SET {updates};'
- tasks.append(query_d1(sql, db_name=HISTORY.D1_DATABASE, params=list(records.values()), silent=True))
+ # logger.debug(f"Syncing message {table_name}: {info['id']}")
+ tasks.append(query_d1(**insert_d1(table_name, records, update_on_conflict="mid"), db_name=HISTORY.D1_DATABASE, silent=True))
if len(tasks) == concurrency:
res = await asyncio.gather(*tasks, return_exceptions=True)
num_success = sum(glom(res, "*.success"))
- if sync_ids := glom(res, "*.result.0.meta.last_row_id", default=0):
- logger.success(f"Synced {num_success} messages to D1, {min(sync_ids)} -> {max(sync_ids)}")
+ logger.success(f"Synced {num_success} messages to D1")
tasks = []
if tasks:
res = await asyncio.gather(*tasks, return_exceptions=True)
num_success = sum(glom(res, "*.success"))
- if sync_ids := glom(res, "*.result.0.meta.last_row_id", default=0):
- logger.success(f"Synced {num_success} messages to D1, {min(sync_ids)} -> {max(sync_ids)}")
- tasks = []
-
-
-async def chat_info(client: Client, chat_id: int) -> Chat:
- if cache.get(f"chat-info-{chat_id}"):
- return cache.get(f"chat-info-{chat_id}")
- if chat_id == 1:
- return Chat(id=1)
- try:
- logger.debug(f"Getting chat info for {chat_id}")
- chat = await client.get_chat(int(chat_id))
- except PeerIdInvalid:
- return await chat_info(client, int(f"-100{chat_id}"))
- except Exception:
- chat = Chat(id=1)
- cache.set(f"chat-info-{chat_id}", chat, ttl=3600) # cache for 1 hour
- return chat
-
-
-async def get_table_name(client: Client, chat_id: str | int) -> str:
- """Get D1 table name by chat id."""
- if cache.get(f"tablename-{chat_id}"):
- return cache.get(f"tablename-{chat_id}")
- # get a default table name
- chat = await chat_info(client, int(chat_id))
- first_name = chat.first_name or ""
- last_name = chat.last_name or ""
- full_name = first_name + last_name
- chat_title = full_name or chat.title or ""
- slim_cid = str(chat_id).removeprefix("-100")
- default_name = f"{slim_cid}-{chat_title}".replace(" ", "")
-
- # get database id, and create database if not exists
- database_id = await create_d1_database(name=HISTORY.D1_DATABASE)
- if not database_id:
- await create_d1_table(default_name, DB_COLUMNS, HISTORY.D1_DATABASE)
- await create_table_index(slim_cid, database_id, default_name)
- return default_name
-
- # find the table name based on chat id
- sql = "SELECT name FROM sqlite_master WHERE type='table';"
- resp = await query_d1(sql, database_id)
- table_names = flatten(glom(resp, "result.*.results.*.name", default=[]))
- table_name = next((x for x in table_names if x.startswith(slim_cid + "-")), default_name)
- cache.set(f"tablename-{chat_id}", table_name, ttl=0)
-
- # create table and index
- await create_d1_table(table_name, DB_COLUMNS, HISTORY.D1_DATABASE)
- await create_table_index(slim_cid, database_id, table_name)
- return table_name
-
-
-async def create_table_index(slim_cid: str, database_id: str, table_name: str) -> None:
- # get all index names
- sql = "SELECT name FROM sqlite_master WHERE type='index';"
- resp = await query_d1(sql, database_id)
- indexs = flatten(glom(resp, "result.*.results.*.name", default=[]))
-
- # create index if not exists
- idx_names = [x for x in INDEX_NAMES if f"idx_{slim_cid}_{x}" not in indexs]
- if not idx_names:
- return
+ logger.success(f"Synced {num_success} messages to D1")
+
+
+async def get_d1_userinfo(uid: int, cid: int) -> dict:
+ """Get user info from table `userinfo`.
- # 创建标准索引
- for idx_name in idx_names:
- logger.debug(f"Creating index on {table_name} for {idx_name}")
- sql = f'CREATE INDEX IF NOT EXISTS idx_{slim_cid}_{idx_name} ON "{table_name}"({idx_name})'
- await query_d1(sql, database_id, silent=True)
-
- """创建 FTS5 虚拟表
- -- content=table_name 指明关联的原表
- -- content_rowid=mid 指明原表的行 ID 列是 mid
- -- segmented 是我们要索引的列
- -- tokenize='unicode61' 使用 unicode61 分词器
+ Returns:
+ uid, full_name, handle
"""
- sql = f"""CREATE VIRTUAL TABLE IF NOT EXISTS fts_{slim_cid} USING fts5(segmented, content="{table_name}", content_rowid=mid, tokenize="unicode61");"""
- await query_d1(sql, database_id, silent=True)
+ # create table
+ await create_d1_table(
+ "userinfo",
+ USER_COLUMNS,
+ db_name=HISTORY.D1_DATABASE,
+ idx_cols=USER_INDEXES,
+ idx_prefix="idx_userinfo_",
+ silent=True,
+ )
+ resp = await query_d1(f"SELECT * FROM userinfo WHERE uid={uid} AND cid={cid};", db_name=HISTORY.D1_DATABASE, silent=True)
+ return glom(resp, "result.0.results.0", default={})
- """将现有数据从原表复制到 FTS 表
- 注意, 我们在这里插入的是 rowid (它会对应到 content_rowid=mid 指定的列) 和 segmented
- 从原表中选择 mid 和 segmented 列。mid 列的值会被插入到 FTS 表中对应原表 rowid (或 content_rowid) 的位置。
+
+async def save_userinfo_to_d1(client: Client, minfo: dict) -> dict[str, str]:
+ """Save user info to table `userinfo`.
+
+ Args:
+ minfo (dict): parsed message info.
+
+ Returns:
+ uid, full_name, handle, tags
"""
- sql = f"INSERT INTO fts_{slim_cid} (rowid, segmented) SELECT mid, segmented FROM '{table_name}' WHERE mid NOT IN (SELECT rowid FROM fts_{slim_cid});"
- await query_d1(sql, database_id, silent=True)
+ uid = int(minfo["uid"])
+ cid = int(slim_cid(minfo["cid"]))
+ if uid == 1: # default user (user is unknown)
+ return {}
+    # Get user info from D1 and save it to cache
+ if not (cached := cache.get(f"userinfo-{uid}-{cid}")):
+ cached = await get_d1_userinfo(uid, cid)
+ cache.set(f"userinfo-{uid}-{cid}", cached, ttl=0)
- """维护 FTS 表
- 为了让 FTS 表与原表保持同步, 需要在原表上创建触发器。
- 在原表插入、删除、更新时, 同步更新 FTS 表
+ ctitle = minfo["ctitle"] or minfo["full_name"]
+ # if in private chats, we use the opponent's name as chat title
+ if minfo["ctype"] in ["BOT", "PRIVATE"]:
+ chat = await get_chat(client, minfo["cid"])
+ if chat.id != 0:
+ ctitle = parse_chat(chat)["ctitle"]
+
+ primary_key = uid if uid == cid else abs(uid - cid)
+ records = {
+ "ctitle": ctitle,
+ "full_name": minfo["full_name"],
+ "handle": minfo["handle"],
+ "tags": cached.get("tags", ""),
+ "name": minfo["full_name"].replace(" ", ""),
+ "uid": uid,
+ "cid": cid,
+ "id": primary_key,
+ }
+ if cached != records:
+ logger.info(f"Save user info: {records}")
+ cache.set(f"userinfo-{uid}-{cid}", records, ttl=0)
+ await query_d1(**insert_d1("userinfo", records, update_on_conflict="id"), db_name=HISTORY.D1_DATABASE, silent=True)
+ return records
+
+
+async def get_d1_chatinfo(cid: str | int) -> dict:
+ """Get chat info from table `chatinfo`.
+
+ Returns:
+ cid, ctype, ctitle, chandle
+ """
+ # create table
+ await create_d1_table("chatinfo", CHAT_COLUMNS, db_name=HISTORY.D1_DATABASE, silent=True)
+ resp = await query_d1(f"SELECT * FROM chatinfo WHERE cid='{slim_cid(cid)}' OR chandle='{cid}';", db_name=HISTORY.D1_DATABASE, silent=True)
+ return glom(resp, "result.0.results.0", default={})
+
+
+async def save_chatinfo_to_d1(client: Client, minfo: dict) -> dict[str, str]:
+ """Save chat info to table `chatinfo`.
+
+ Args:
+ minfo (dict): parsed message info.
+
+ Returns:
+ cid, ctype, ctitle, chandle, tablename, tags
"""
- # 创建触发器, 在原表删除数据时, 同步从 FTS 表删除
- sql = f"CREATE TRIGGER IF NOT EXISTS trigger_{slim_cid}_ai AFTER INSERT ON '{table_name}' BEGIN INSERT INTO fts_{slim_cid} (rowid, segmented) VALUES (NEW.mid, NEW.segmented); END;"
- await query_d1(sql, database_id, silent=True)
-
- # 创建触发器, 在原表删除数据时, 同步从 FTS 表删除
- sql = f"CREATE TRIGGER IF NOT EXISTS trigger_{slim_cid}_ad AFTER DELETE ON '{table_name}' BEGIN DELETE FROM fts_{slim_cid} WHERE rowid = OLD.mid; END;"
- await query_d1(sql, database_id, silent=True)
-
- # 创建触发器, 在原表更新数据时, 同步更新 FTS 表
- # FTS5 的更新通常是先删除旧的, 再插入新的
- sql = f"CREATE TRIGGER IF NOT EXISTS trigger_{slim_cid}_au AFTER UPDATE ON '{table_name}' BEGIN DELETE FROM fts_{slim_cid} WHERE rowid = OLD.mid AND OLD.segmented <> NEW.segmented; INSERT INTO fts_{slim_cid} (rowid, segmented) SELECT NEW.mid, NEW.segmented WHERE OLD.segmented <> NEW.segmented; END;"
- await query_d1(sql, database_id, silent=True)
+ cid = slim_cid(minfo["cid"])
+ if str(cid) == "0":
+ return {}
+    # Get chat info from D1 and save it to cache
+ if not (cached := cache.get(f"chatinfo-{cid}")):
+ cached = await get_d1_chatinfo(cid)
+ cache.set(f"chatinfo-{cid}", cached, ttl=0)
+
+ ctitle = minfo["ctitle"] or minfo["full_name"]
+ # if in private chats, we use the opponent's name as chat title
+ if minfo["ctype"] in ["BOT", "PRIVATE"]:
+ chat = await get_chat(client, minfo["cid"])
+ if chat.id != 0:
+ ctitle = parse_chat(chat)["ctitle"]
+
+ records = {
+ "cid": int(cid),
+ "ctype": minfo["ctype"],
+ "ctitle": ctitle,
+ "chandle": minfo["chandle"],
+ "tablename": cached.get("tablename", "") or f"{cid}-{ctitle}",
+ "tags": cached.get("tags", ""),
+ }
+ # create table for this chat
+ await create_d1_table(
+ records["tablename"],
+ MSG_COLUMNS,
+ idx_cols=MSG_INDEXES,
+ idx_prefix=f"idx_{cid}_",
+ fts_on_col="mid",
+ fts_name=cid,
+ db_name=HISTORY.D1_DATABASE,
+ silent=True,
+ )
+ if cached != records:
+ logger.info(f"Save chat info: {records}")
+ cache.set(f"chatinfo-{cid}", records, ttl=0)
+ await query_d1(**insert_d1("chatinfo", records, update_on_conflict="cid"), db_name=HISTORY.D1_DATABASE, silent=True)
+ return records
src/history/sync.py
@@ -1,12 +1,15 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
+from typing import Literal
+
+from glom import glom
from loguru import logger
from pyrogram.client import Client
from pyrogram.types import Message
from config import HISTORY, cache
from database.turso import turso_exec, turso_parse_resp
-from history.d1 import backup_chat_history_to_d1, sync_history_to_d1
+from history.d1 import backup_chat_history_to_d1, query_d1, sync_history_to_d1
from history.turso import backup_chat_history_to_turso, sync_history_to_turso
from history.utils import TURSO_KWARGS
@@ -14,16 +17,18 @@ from history.utils import TURSO_KWARGS
async def sync_chat_history(client: Client, message: Message) -> None:
if not HISTORY.ENABLE:
return
- if HISTORY.ENGINE.upper() == "D1": # Deprecated
- await sync_history_to_d1(client, message)
- if HISTORY.ENGINE.upper() == "TURSO":
+ if "TURSO" in HISTORY.ENGINE.upper():
await sync_history_to_turso(client, message)
+ if "D1" in HISTORY.ENGINE.upper():
+ await sync_history_to_d1(client, message)
async def backup_chat_history(
client: Client,
chats: str = HISTORY.PERIODICALLY_BACKUP_CHATS,
hours: float = HISTORY.BACKUP_CHATS_HOURS,
+ *,
+ start_from: Literal["latest", "oldest"] = "latest",
) -> None:
if not HISTORY.ENABLE:
return
@@ -33,16 +38,23 @@ async def backup_chat_history(
cache.set("backup_chat_history", 1, ttl=12 * 3600) # backup every 12 hours
# if `chats` is set to "full_table", backup all chats in `chatinfo` table
if chats == "full_table":
- resp = await turso_exec([{"type": "execute", "stmt": {"sql": "SELECT * FROM 'chatinfo';"}}], silent=True, retry=2, **TURSO_KWARGS)
- tables = turso_parse_resp(resp)
- chat_ids = [x["chandle"] or int(x["cid"]) for x in tables]
+ if "TURSO" in HISTORY.ENGINE.upper():
+ resp = await turso_exec([{"type": "execute", "stmt": {"sql": "SELECT * FROM 'chatinfo';"}}], silent=True, retry=2, **TURSO_KWARGS)
+ tables = turso_parse_resp(resp)
+ for cid in [x["chandle"] or int(x["cid"]) for x in tables]:
+ logger.info(f"Backup chat history to Turso: {cid}")
+ await backup_chat_history_to_turso(client, cid, hours, start_from=start_from)
+ if "D1" in HISTORY.ENGINE.upper():
+ resp = await query_d1("SELECT * FROM 'chatinfo';", db_name=HISTORY.D1_DATABASE, silent=True)
+ tables = glom(resp, "result.0.results", default=[])
+ for cid in [x["chandle"] or int(x["cid"]) for x in tables]:
+ logger.info(f"Backup chat history to D1: {cid}")
+ await backup_chat_history_to_d1(client, cid, hours, start_from=start_from)
else:
chat_ids = [x.strip() for x in chats.split(",") if x.strip()]
- if not chat_ids:
- return
- for cid in chat_ids:
- logger.info(f"Backup chat history: {cid}")
- if HISTORY.ENGINE.upper() == "D1":
- await backup_chat_history_to_d1(client, cid, hours)
- elif HISTORY.ENGINE.upper() == "TURSO":
- await backup_chat_history_to_turso(client, cid, hours)
+ for cid in chat_ids:
+ logger.info(f"Backup chat history: {cid}")
+ if "TURSO" in HISTORY.ENGINE.upper():
+ await backup_chat_history_to_turso(client, cid, hours, start_from=start_from)
+ if "D1" in HISTORY.ENGINE.upper():
+ await backup_chat_history_to_d1(client, cid, hours, start_from=start_from)
src/history/turso.py
@@ -4,6 +4,7 @@ import json
import os
from datetime import datetime, timedelta
from pathlib import Path
+from typing import Literal
from zoneinfo import ZoneInfo
from glom import Coalesce, flatten, glom
@@ -13,17 +14,10 @@ from pyrogram.types import Message
from config import DOWNLOAD_DIR, HISTORY, TZ, cache, cutter
from database.turso import insert_statement, turso_create_table, turso_exec, turso_parse_resp
-from history.utils import TURSO_KWARGS, check_save_history, fine_grained_check, get_chat
+from history.utils import CHAT_COLUMNS, MSG_COLUMNS, MSG_INDEXES, TURSO_KWARGS, USER_COLUMNS, USER_INDEXES, check_save_history, fine_grained_check, get_chat
from messages.parser import parse_chat, parse_msg
from utils import i_am_bot, nowdt, slim_cid, to_int, true
-CHAT_COLUMNS = "cid INTEGER PRIMARY KEY, ctype TEXT, ctitle TEXT, chandle TEXT, tablename TEXT, tags TEXT"
-USER_COLUMNS = "ctitle TEXT, full_name TEXT, handle TEXT, tags TEXT, name TEXT, uid INTEGER, cid INTEGER, id INTEGER PRIMARY KEY"
-USER_INDEXES = ["uid", "cid"]
-
-MSG_COLUMNS = "mid INTEGER PRIMARY KEY, mtype TEXT, time TEXT NOT NULL, fullname TEXT, content TEXT, filename TEXT, urls TEXT, reply INTEGER, mime TEXT, user TEXT, handle TEXT, uid INTEGER, gid INTEGER, segmented TEXT" # fmt: off
-MSG_INDEXES = ["time", "user", "uid", "handle"]
-
async def sync_history_to_turso(client: Client, message: Message) -> None:
"""Sync received messages to Turso database.
@@ -59,7 +53,18 @@ async def sync_history_to_turso(client: Client, message: Message) -> None:
await turso_exec([insert_statement(chatinfo["tablename"], records, update_on_conflict="mid")], silent=True, retry=2, **TURSO_KWARGS)
-async def backup_chat_history_to_turso(client: Client, chat_id: str | int, hours: float = HISTORY.BACKUP_CHATS_HOURS) -> None:
+async def backup_chat_history_to_turso(
+ client: Client,
+ chat_id: str | int,
+ hours: float = HISTORY.BACKUP_CHATS_HOURS,
+ *,
+ start_from: Literal["latest", "oldest"] = "latest",
+ max_sync: float = float("inf"),
+) -> None:
+ """Backup chat history to Turso database.
+
+ If start_from is "oldest", find the minimum message id of this chat, then use this mid as `offset_id` to retrieve messages.
+ """
if not HISTORY.TURSO_ENABLE:
return
if await i_am_bot(client):
@@ -73,20 +78,28 @@ async def backup_chat_history_to_turso(client: Client, chat_id: str | int, hours
if true(os.getenv(f"HISTORY_IGNORE_{chatinfo['cid']}")):
return
table_name = chatinfo["tablename"]
- # find message ids in this time range
now = nowdt(TZ)
begin_dt = now - timedelta(hours=hours)
begin_time = begin_dt.strftime("%Y-%m-%d %H:%M:%S")
- end_time = now.strftime("%Y-%m-%d %H:%M:%S")
- sql = f'SELECT mid FROM "{table_name}" WHERE time >= "{begin_time}" AND time <= "{end_time}";'
- resp = await turso_exec([{"type": "execute", "stmt": {"sql": sql}}], silent=True, **TURSO_KWARGS)
- saved_mids = flatten(glom(resp, "results.0.response.result.rows.*.*.value", default=[]))
- saved_mids = {int(x) for x in saved_mids}
+ if start_from == "oldest":
+ sql = f'SELECT mid FROM "{table_name}" ORDER BY mid ASC LIMIT 1'
+ resp = await turso_exec([{"type": "execute", "stmt": {"sql": sql}}], silent=True, **TURSO_KWARGS)
+ offset_id = int(glom(resp, "results.0.response.result.rows.0.0.value", default=1))
+ saved_mids = {offset_id}
+ else:
+ # find message ids in this time range
+ end_time = now.strftime("%Y-%m-%d %H:%M:%S")
+ sql = f'SELECT mid FROM "{table_name}" WHERE time >= "{begin_time}" AND time <= "{end_time}";'
+ resp = await turso_exec([{"type": "execute", "stmt": {"sql": sql}}], silent=True, **TURSO_KWARGS)
+ saved_mids = flatten(glom(resp, "results.0.response.result.rows.*.*.value", default=[]))
+ saved_mids = {int(x) for x in saved_mids}
+ offset_id = 0 # retrieve from latest message
logger.info(f"Found {len(saved_mids)} messages in Turso. Rows read: {glom(resp, 'results.0.response.result.rows_read', default=1)}")
concurrency = 1000
+ num_sync = 0
statements = []
real_cid = chatinfo["chandle"] or (int(chatinfo["cid"]) if chatinfo["ctype"] in ["BOT", "PRIVATE"] else int(f"-100{chatinfo['cid']}"))
- async for message in client.get_chat_history(real_cid): # type: ignore
+ async for message in client.get_chat_history(real_cid, offset_id=offset_id): # type: ignore
if not isinstance(message, Message) or message.empty or message.service or message.id in saved_mids:
continue
info = parse_msg(message, silent=True, use_cache=False)
@@ -94,6 +107,9 @@ async def backup_chat_history_to_turso(client: Client, chat_id: str | int, hours
continue
if info["time"] < begin_time:
break
+ if num_sync >= max_sync:
+ break
+ num_sync += 1
records = {
"mid": info["mid"],
"mtype": info["mtype"],
src/history/utils.py
@@ -22,6 +22,13 @@ TURSO_KWARGS: dict = {
"group_token": HISTORY.TURSO_GROUP_TOKEN or DB.TURSO_GROUP_TOKEN,
}
+CHAT_COLUMNS = "cid INTEGER PRIMARY KEY, ctype TEXT, ctitle TEXT, chandle TEXT, tablename TEXT, tags TEXT"
+USER_COLUMNS = "ctitle TEXT, full_name TEXT, handle TEXT, tags TEXT, name TEXT, uid INTEGER, cid INTEGER, id INTEGER PRIMARY KEY"
+USER_INDEXES = ["uid", "cid"]
+
+MSG_COLUMNS = "mid INTEGER PRIMARY KEY, mtype TEXT, time TEXT NOT NULL, fullname TEXT, content TEXT, filename TEXT, urls TEXT, reply INTEGER, mime TEXT, user TEXT, handle TEXT, uid INTEGER, gid INTEGER, segmented TEXT" # fmt: off
+MSG_INDEXES = ["time", "user", "uid", "handle"]
+
@cache.memoize(ttl=0)
def check_save_history(ctype: str, cid: int | str) -> bool: