Commit 8140220
Changed files (7)
src/history/d1.py
@@ -0,0 +1,305 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+import asyncio
+import json
+from datetime import datetime, timedelta
+from pathlib import Path
+from zoneinfo import ZoneInfo
+
+from glom import Coalesce, flatten, glom
+from loguru import logger
+from pyrogram.client import Client
+from pyrogram.errors import PeerIdInvalid
+from pyrogram.types import Chat, Message
+
+from config import DOWNLOAD_DIR, HISTORY, TZ, cache, cutter
+from database.d1 import create_d1_database, create_d1_table, query_d1
+from messages.parser import parse_msg
+from permission import check_save_history
+from utils import i_am_bot, nowdt
+
+# ruff: noqa: S608
+
+DB_COLUMNS = "mid INTEGER PRIMARY KEY, mtype TEXT, time TEXT NOT NULL, user TEXT, content TEXT, filename TEXT, urls TEXT, reply INTEGER, mime TEXT, uid INTEGER, segmented TEXT"
+INDEX_NAMES = ["time", "uid"]
+
+
+async def sync_history_to_d1(client: Client, message: Message) -> None:
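+    """Upsert a single (new or edited) message into the chat's D1 history table."""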
+ if not HISTORY.D1_ENABLE:
+ return
+ info = parse_msg(message, silent=True)
+ if not check_save_history(info["ctype"], info["cid"]):
+ return
+ table_name = await get_table_name(client, info["cid"])
+ records = {
+ "mid": info["mid"],
+ "time": info["time"],
+ "user": info["full_name"],
+ "content": message.content, # text or edited text
+ "mtype": info["mtype"],
+ "uid": info["uid"],
+ "filename": info["file_name"],
+ "mime": info["mime_type"],
+ "urls": "\n\n".join(info["entity_urls"]),
+ "reply": message.reply_to_message_id,
+ "segmented": " ".join(cutter.cutword(message.content)),
+ }
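+    # build an idempotent upsert keyed on mid: insert the row, or update every
+    # non-key column on conflict (EXCLUDED refers to the row that failed to insert)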
+ keys = ", ".join(records)
+ values = ", ".join(["?" for _ in range(len(records))])
+ updates = ", ".join([f"{k} = EXCLUDED.{k}" for k in records if k != "mid"])
+ sql = f'INSERT INTO "{table_name}" ({keys}) VALUES ({values}) ON CONFLICT (mid) DO UPDATE SET {updates};'
+ await query_d1(sql, db_name=HISTORY.D1_DATABASE, params=list(records.values()), silent=True)
+
+
+async def backup_chat_history_to_d1(client: Client, chat_id: str | int, hours: float = HISTORY.BACKUP_CHATS_HOURS) -> None:
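+    """Walk the chat's history for the last `hours` hours and upsert any messages missing from D1, in batches."""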
+ if not HISTORY.D1_ENABLE:
+ return
+ if await i_am_bot(client):
+ return
+ chat = await chat_info(client, int(chat_id))
+ table_name = await get_table_name(client, chat_id)
+ # find message ids in this time range
+ now = nowdt(TZ)
+ begin_dt = now - timedelta(hours=hours)
+ begin_time = begin_dt.strftime("%Y-%m-%d %H:%M:%S")
+ end_time = now.strftime("%Y-%m-%d %H:%M:%S")
+ sql = f'SELECT mid FROM "{table_name}" WHERE time >= "{begin_time}" AND time <= "{end_time}";'
+ resp = await query_d1(sql, db_name=HISTORY.D1_DATABASE, silent=True)
+ saved_mids = glom(resp, "result.0.results.*.mid", default=[])
+ logger.info(f"Found {len(saved_mids)} messages in D1. Rows read: {glom(resp, 'result.0.meta.rows_read', default=1)}")
+ concurrency = 200
+ tasks = []
+ async for message in client.get_chat_history(chat.id): # type: ignore
+ if not isinstance(message, Message) or message.empty:
+ continue
+ info = parse_msg(message, silent=True)
+ if info["mid"] in saved_mids:
+ continue
+ if info["time"] < begin_time:
+ break
+ records = {
+ "mid": info["mid"],
+ "time": info["time"],
+ "user": info["full_name"],
+ "content": info["text"],
+ "mtype": info["mtype"],
+ "uid": info["uid"],
+ "filename": info["file_name"],
+ "mime": info["mime_type"],
+ "urls": "\n\n".join(info["entity_urls"]),
+ "reply": message.reply_to_message_id,
+ "segmented": " ".join(cutter.cutword(info["text"])),
+ }
+ logger.trace(f"Syncing message {table_name}: {info['mid']}")
+ keys = ", ".join(records)
+ values = ", ".join(["?" for _ in range(len(records))])
+ updates = ", ".join([f"{k} = EXCLUDED.{k}" for k in records if k != "mid"])
+ sql = f'INSERT INTO "{table_name}" ({keys}) VALUES ({values}) ON CONFLICT (mid) DO UPDATE SET {updates};'
+ tasks.append(query_d1(sql, db_name=HISTORY.D1_DATABASE, params=list(records.values()), silent=True))
+ if len(tasks) == concurrency:
+ res = await asyncio.gather(*tasks, return_exceptions=True)
+ num_success = sum(glom(res, "*.success"))
+ if sync_ids := glom(res, "*.result.0.meta.last_row_id", default=0):
+ logger.success(f"Synced {num_success} messages to D1, {min(sync_ids)} -> {max(sync_ids)}")
+ tasks = []
+ if tasks:
+ res = await asyncio.gather(*tasks, return_exceptions=True)
+ num_success = sum(glom(res, "*.success"))
+ if sync_ids := glom(res, "*.result.0.meta.last_row_id", default=0):
+ logger.success(f"Synced {num_success} messages to D1, {min(sync_ids)} -> {max(sync_ids)}")
+ tasks = []
+
+
+async def upload_exported_history_to_d1(client: Client, path: str | Path | None = None) -> None:
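+    """Import a Telegram chat export (result.json) into D1, skipping message ids that are already saved."""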
+ if not HISTORY.D1_ENABLE:
+ return
+ if path is None:
+ path = Path(DOWNLOAD_DIR) / "result.json"
+ path = Path(path)
+ if not path.is_file():
+ return
+
+ def parse_text(texts: list) -> str:
+ if isinstance(texts, str):
+ return texts
+ text = ""
+ for x in texts:
+ text += x if isinstance(x, str) else x.get("text", "")
+ return text
+
+ def parse_urls(entities: list) -> str:
+ urls = [glom(x, Coalesce("href", "text")) for x in entities if x["type"] in {"link", "text_link"}]
+ return "\n\n".join(urls)
+
+ with path.open("r") as f: # noqa: ASYNC230
+ data = json.load(f)
+ mtypes = {
+ "text": "text",
+ "photo": "photo",
+ "animation": "animation",
+ "audio_file": "audio",
+ "sticker": "sticker",
+ "voice_message": "voice",
+ "video_message": "video",
+ "video_file": "video",
+ }
+ table_name = await get_table_name(client, data["id"])
+ # find all message_ids
+ sql = f'SELECT mid FROM "{table_name}" ORDER BY mid;'
+ resp = await query_d1(sql, db_name=HISTORY.D1_DATABASE, silent=True)
+ saved_ids = glom(resp, "result.0.results.*.mid", default=[])
+ concurrency = 200
+ tasks = []
+ for info in data["messages"]: # type: ignore
+ if info["id"] in saved_ids:
+ continue
+ if info["type"] != "message":
+ continue
+ if info["date_unixtime"] == "0":
+ continue
+ if "media_type" not in info: # guess mtype
+ if "photo" in info:
+ info["media_type"] = "photo"
+ if "video/" in info.get("mime_type", ""):
+ info["media_type"] = "video_file"
+
+ dt = datetime.fromtimestamp(int(info["date_unixtime"]), tz=ZoneInfo(TZ))
+ uid = int(info["from_id"].removeprefix("user").removeprefix("channel"))
+ user = info["from"]
+ if user == data["name"] and data["type"] in ["public_channel", "private_channel"]: # user is not shown
+ user = ""
+ uid = 1
+
+ content = parse_text(info.get("text", []))
+ records = {
+ "mid": info["id"],
+ "time": dt.strftime("%Y-%m-%d %H:%M:%S"),
+ "user": user,
+ "content": parse_text(info.get("text", [])),
+ "mtype": mtypes[info.get("media_type", "text")],
+ "uid": uid,
+ "filename": info.get("file_name", ""),
+ "mime": info.get("mime_type", ""),
+ "urls": parse_urls(info.get("text_entities", [])),
+ "reply": info.get("reply_to_message_id"),
+ "segmented": " ".join(cutter.cutword(content)),
+ }
+ logger.trace(f"Syncing message {table_name}: {info['id']}")
+ keys = ", ".join(records)
+ values = ", ".join(["?" for _ in range(len(records))])
+ updates = ", ".join([f"{k} = EXCLUDED.{k}" for k in records if k != "mid"])
+ sql = f'INSERT INTO "{table_name}" ({keys}) VALUES ({values}) ON CONFLICT (mid) DO UPDATE SET {updates};'
+ tasks.append(query_d1(sql, db_name=HISTORY.D1_DATABASE, params=list(records.values()), silent=True))
+ if len(tasks) == concurrency:
+ res = await asyncio.gather(*tasks, return_exceptions=True)
+ num_success = sum(glom(res, "*.success"))
+ if sync_ids := glom(res, "*.result.0.meta.last_row_id", default=0):
+ logger.success(f"Synced {num_success} messages to D1, {min(sync_ids)} -> {max(sync_ids)}")
+ tasks = []
+
+ if tasks:
+ res = await asyncio.gather(*tasks, return_exceptions=True)
+ num_success = sum(glom(res, "*.success"))
+ if sync_ids := glom(res, "*.result.0.meta.last_row_id", default=0):
+ logger.success(f"Synced {num_success} messages to D1, {min(sync_ids)} -> {max(sync_ids)}")
+ tasks = []
+
+
+async def chat_info(client: Client, chat_id: int) -> Chat:
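+    """Return chat metadata, cached for an hour; falls back to a placeholder Chat(id=1) if the peer can't be resolved."""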
+ if cache.get(f"chat-info-{chat_id}"):
+ return cache.get(f"chat-info-{chat_id}")
+ if chat_id == 1:
+ return Chat(id=1)
+ try:
+ logger.debug(f"Getting chat info for {chat_id}")
+ chat = await client.get_chat(int(chat_id))
+ except PeerIdInvalid:
+ return await chat_info(client, int(f"-100{chat_id}"))
+ except Exception:
+ chat = Chat(id=1)
+ cache.set(f"chat-info-{chat_id}", chat, ttl=3600) # cache for 1 hour
+ return chat
+
+
+async def get_table_name(client: Client, chat_id: str | int) -> str:
+ """Get D1 table name by chat id."""
+ if cache.get(f"tablename-{chat_id}"):
+ return cache.get(f"tablename-{chat_id}")
+ # get a default table name
+ chat = await chat_info(client, int(chat_id))
+ first_name = chat.first_name or ""
+ last_name = chat.last_name or ""
+ full_name = first_name + last_name
+ chat_title = full_name or chat.title or ""
+ slim_cid = str(chat_id).removeprefix("-100")
+ default_name = f"{slim_cid}-{chat_title}".replace(" ", "")
+
+ # get database id, and create database if not exists
+ database_id = await create_d1_database(name=HISTORY.D1_DATABASE)
+ if not database_id:
+ await create_d1_table(default_name, DB_COLUMNS, HISTORY.D1_DATABASE)
+ await create_table_index(slim_cid, database_id, default_name)
+ return default_name
+
+ # find the table name based on chat id
+ sql = "SELECT name FROM sqlite_master WHERE type='table';"
+ resp = await query_d1(sql, database_id)
+ table_names = flatten(glom(resp, "result.*.results.*.name", default=[]))
+ table_name = next((x for x in table_names if x.startswith(slim_cid + "-")), default_name)
+ cache.set(f"tablename-{chat_id}", table_name, ttl=0)
+
+ # create table and index
+ await create_d1_table(table_name, DB_COLUMNS, HISTORY.D1_DATABASE)
+ await create_table_index(slim_cid, database_id, table_name)
+ return table_name
+
+
+async def create_table_index(slim_cid: str, database_id: str, table_name: str) -> None:
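+    """Ensure the per-chat column indexes, the FTS5 full-text table, and the triggers that keep it in sync."""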
+ # get all index names
+ sql = "SELECT name FROM sqlite_master WHERE type='index';"
+ resp = await query_d1(sql, database_id)
+    indexes = flatten(glom(resp, "result.*.results.*.name", default=[]))
+
+    # create index if not exists
+    idx_names = [x for x in INDEX_NAMES if f"idx_{slim_cid}_{x}" not in indexes]
+ if not idx_names:
+ return
+
+    # create the standard column indexes
+ for idx_name in idx_names:
+ logger.debug(f"Creating index on {table_name} for {idx_name}")
+ sql = f'CREATE INDEX IF NOT EXISTS idx_{slim_cid}_{idx_name} ON "{table_name}"({idx_name})'
+ await query_d1(sql, database_id, silent=True)
+
+ """创建 FTS5 虚拟表
+ -- content=table_name 指明关联的原表
+ -- content_rowid=mid 指明原表的行 ID 列是 mid
+ -- segmented 是我们要索引的列
+ -- tokenize='unicode61' 使用 unicode61 分词器
+ """
+ sql = f"""CREATE VIRTUAL TABLE IF NOT EXISTS fts_{slim_cid} USING fts5(segmented, content="{table_name}", content_rowid=mid, tokenize="unicode61");"""
+ await query_d1(sql, database_id, silent=True)
+
+ """将现有数据从原表复制到 FTS 表
+ 注意, 我们在这里插入的是 rowid (它会对应到 content_rowid=mid 指定的列) 和 segmented
+ 从原表中选择 mid 和 segmented 列。mid 列的值会被插入到 FTS 表中对应原表 rowid (或 content_rowid) 的位置。
+ """
+ sql = f"INSERT INTO fts_{slim_cid} (rowid, segmented) SELECT mid, segmented FROM '{table_name}' WHERE mid NOT IN (SELECT rowid FROM fts_{slim_cid});"
+ await query_d1(sql, database_id, silent=True)
+
+ """维护 FTS 表
+ 为了让 FTS 表与原表保持同步, 需要在原表上创建触发器。
+ 在原表插入、删除、更新时, 同步更新 FTS 表
+ """
+ # 创建触发器, 在原表删除数据时, 同步从 FTS 表删除
+ sql = f"CREATE TRIGGER IF NOT EXISTS trigger_{slim_cid}_ai AFTER INSERT ON '{table_name}' BEGIN INSERT INTO fts_{slim_cid} (rowid, segmented) VALUES (NEW.mid, NEW.segmented); END;"
+ await query_d1(sql, database_id, silent=True)
+
+    # trigger: on delete from the source table, delete the row from the FTS table
+ sql = f"CREATE TRIGGER IF NOT EXISTS trigger_{slim_cid}_ad AFTER DELETE ON '{table_name}' BEGIN DELETE FROM fts_{slim_cid} WHERE rowid = OLD.mid; END;"
+ await query_d1(sql, database_id, silent=True)
+
+    # trigger: on update of the source table, update the FTS table
+    # an FTS5 update is typically a delete of the old row followed by an insert of the new one
+ sql = f"CREATE TRIGGER IF NOT EXISTS trigger_{slim_cid}_au AFTER UPDATE ON '{table_name}' BEGIN DELETE FROM fts_{slim_cid} WHERE rowid = OLD.mid AND OLD.segmented <> NEW.segmented; INSERT INTO fts_{slim_cid} (rowid, segmented) SELECT NEW.mid, NEW.segmented WHERE OLD.segmented <> NEW.segmented; END;"
+ await query_d1(sql, database_id, silent=True)
src/history/sync.py
@@ -1,304 +1,41 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
-import asyncio
-import json
-from datetime import datetime
-from pathlib import Path
-from zoneinfo import ZoneInfo
-
-from cutword import Cutter
-from glom import Coalesce, flatten, glom
from loguru import logger
from pyrogram.client import Client
-from pyrogram.errors import PeerIdInvalid
-from pyrogram.types import Chat, Message
-
-from config import DOWNLOAD_DIR, HISTORY, TZ, cache
-from database.d1 import create_d1_database, create_d1_table, query_d1
-from messages.parser import parse_msg
-from permission import check_save_d1
-from utils import i_am_bot
-
-# ruff: noqa: S608
-
-DB_COLUMNS = "mid INTEGER PRIMARY KEY, mtype TEXT, time TEXT NOT NULL, user TEXT, content TEXT, filename TEXT, urls TEXT, reply INTEGER, mime TEXT, uid INTEGER, segmented TEXT"
-INDEX_NAMES = ["time", "uid"]
-cutter = Cutter()
+from pyrogram.types import Message
+from config import HISTORY, cache
+from history.d1 import backup_chat_history_to_d1, sync_history_to_d1
+from history.turso import backup_chat_history_to_turso, sync_history_to_turso
-async def sync_history_to_d1(client: Client, message: Message) -> None:
- if not HISTORY.D1_ENABLE:
- return
- info = parse_msg(message, silent=True)
- if not check_save_d1(info["ctype"], info["cid"]):
- return
- table_name = await get_table_name(client, info["cid"])
- records = {
- "mid": info["mid"],
- "time": info["time"],
- "user": info["full_name"],
- "content": message.content, # text or edited text
- "mtype": info["mtype"],
- "uid": info["uid"],
- "filename": info["file_name"],
- "mime": info["mime_type"],
- "urls": "\n\n".join(info["entity_urls"]),
- "reply": message.reply_to_message_id,
- "segmented": " ".join(cutter.cutword(message.content)),
- }
- keys = ", ".join(records)
- values = ", ".join(["?" for _ in range(len(records))])
- updates = ", ".join([f"{k} = EXCLUDED.{k}" for k in records if k != "mid"])
- sql = f'INSERT INTO "{table_name}" ({keys}) VALUES ({values}) ON CONFLICT (mid) DO UPDATE SET {updates};'
- await query_d1(sql, db_name=HISTORY.D1_DATABASE, params=list(records.values()), silent=True)
-
-async def backup_chat_history(client: Client, chat_id: str | int) -> None:
- if not HISTORY.D1_ENABLE:
+async def sync_chat_history(client: Client, message: Message) -> None:
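+    """Route the message to the configured history backend (D1 or Turso)."""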
+ if not HISTORY.ENABLE:
return
- if await i_am_bot(client):
+ if HISTORY.ENGINE.upper() == "D1":
+ await sync_history_to_d1(client, message)
+ if HISTORY.ENGINE.upper() == "TURSO":
+ await sync_history_to_turso(client, message)
+
+
+async def backup_chat_history(
+ client: Client,
+ chats: str = HISTORY.PERIODICALLY_BACKUP_CHATS,
+ hours: float = HISTORY.BACKUP_CHATS_HOURS,
+) -> None:
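+    """Back up recent history for the configured chats; throttled via cache to once per 12 hours."""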
+ if not HISTORY.ENABLE:
return
- chat = await chat_info(client, int(chat_id))
- table_name = await get_table_name(client, chat_id)
- # find last message id
- sql = f'SELECT mid FROM "{table_name}" ORDER BY mid DESC LIMIT 1;'
- resp = await query_d1(sql, db_name=HISTORY.D1_DATABASE, silent=True)
- last_mid = glom(resp, "result.0.results.0.mid", default=1)
- concurrency = 200
- tasks = []
- mids = set() # to avoid duplicate
- async for message in client.get_chat_history(chat.id, offset_id=last_mid, reverse=True): # type: ignore
- if not isinstance(message, Message) or message.empty:
- continue
- info = parse_msg(message, silent=True)
- if info["mid"] in mids:
- continue
- mids.add(info["mid"])
- records = {
- "mid": info["mid"],
- "time": info["time"],
- "user": info["full_name"],
- "content": info["text"],
- "mtype": info["mtype"],
- "uid": info["uid"],
- "filename": info["file_name"],
- "mime": info["mime_type"],
- "urls": "\n\n".join(info["entity_urls"]),
- "reply": message.reply_to_message_id,
- "segmented": " ".join(cutter.cutword(info["text"])),
- }
- logger.trace(f"Syncing message {table_name}: {info['mid']}")
- keys = ", ".join(records)
- values = ", ".join(["?" for _ in range(len(records))])
- updates = ", ".join([f"{k} = EXCLUDED.{k}" for k in records if k != "mid"])
- sql = f'INSERT INTO "{table_name}" ({keys}) VALUES ({values}) ON CONFLICT (mid) DO UPDATE SET {updates};'
- tasks.append(query_d1(sql, db_name=HISTORY.D1_DATABASE, params=list(records.values()), silent=True))
- if len(tasks) == concurrency:
- res = await asyncio.gather(*tasks, return_exceptions=True)
- num_success = sum(glom(res, "*.success"))
- if sync_ids := glom(res, "*.result.0.meta.last_row_id", default=0):
- logger.success(f"Synced {num_success} messages to D1, {min(sync_ids)} -> {max(sync_ids)}")
- tasks = []
- mids = set()
- if tasks:
- res = await asyncio.gather(*tasks, return_exceptions=True)
- num_success = sum(glom(res, "*.success"))
- if sync_ids := glom(res, "*.result.0.meta.last_row_id", default=0):
- logger.success(f"Synced {num_success} messages to D1, {min(sync_ids)} -> {max(sync_ids)}")
- tasks = []
- mids = set()
-
-async def upload_exported_history(client: Client, path: str | Path | None = None) -> None:
- if not HISTORY.D1_ENABLE:
+ if cache.get("backup_chat_history"):
return
- if path is None:
- path = Path(DOWNLOAD_DIR) / "result.json"
- path = Path(path)
- if not path.is_file():
- return
-
- def parse_text(texts: list) -> str:
- if isinstance(texts, str):
- return texts
- text = ""
- for x in texts:
- text += x if isinstance(x, str) else x.get("text", "")
- return text
-
- def parse_urls(entities: list) -> str:
- urls = [glom(x, Coalesce("href", "text")) for x in entities if x["type"] in {"link", "text_link"}]
- return "\n\n".join(urls)
-
- with path.open("r") as f: # noqa: ASYNC230
- data = json.load(f)
- mtypes = {
- "text": "text",
- "photo": "photo",
- "animation": "animation",
- "audio_file": "audio",
- "sticker": "sticker",
- "voice_message": "voice",
- "video_message": "video",
- "video_file": "video",
- }
- table_name = await get_table_name(client, data["id"])
- # find all message_ids
- sql = f'SELECT mid FROM "{table_name}" ORDER BY mid;'
- resp = await query_d1(sql, db_name=HISTORY.D1_DATABASE, silent=True)
- saved_ids = glom(resp, "result.0.results.*.mid", default=[])
- concurrency = 200
- tasks = []
- for info in data["messages"]: # type: ignore
- if info["id"] in saved_ids:
- continue
- if info["type"] != "message":
- continue
- if info["date_unixtime"] == "0":
- continue
- if "media_type" not in info: # guess mtype
- if "photo" in info:
- info["media_type"] = "photo"
- if "video/" in info.get("mime_type", ""):
- info["media_type"] = "video_file"
-
- dt = datetime.fromtimestamp(int(info["date_unixtime"]), tz=ZoneInfo(TZ))
- uid = int(info["from_id"].removeprefix("user").removeprefix("channel"))
- user = info["from"]
- if user == data["name"] and data["type"] in ["public_channel", "private_channel"]: # user is not shown
- user = ""
- uid = 1
-
- content = parse_text(info.get("text", []))
- records = {
- "mid": info["id"],
- "time": dt.strftime("%Y-%m-%d %H:%M:%S"),
- "user": user,
- "content": parse_text(info.get("text", [])),
- "mtype": mtypes[info.get("media_type", "text")],
- "uid": uid,
- "filename": info.get("file_name", ""),
- "mime": info.get("mime_type", ""),
- "urls": parse_urls(info.get("text_entities", [])),
- "reply": info.get("reply_to_message_id"),
- "segmented": " ".join(cutter.cutword(content)),
- }
- logger.trace(f"Syncing message {table_name}: {info['id']}")
- keys = ", ".join(records)
- values = ", ".join(["?" for _ in range(len(records))])
- updates = ", ".join([f"{k} = EXCLUDED.{k}" for k in records if k != "mid"])
- sql = f'INSERT INTO "{table_name}" ({keys}) VALUES ({values}) ON CONFLICT (mid) DO UPDATE SET {updates};'
- tasks.append(query_d1(sql, db_name=HISTORY.D1_DATABASE, params=list(records.values()), silent=True))
- if len(tasks) == concurrency:
- res = await asyncio.gather(*tasks, return_exceptions=True)
- num_success = sum(glom(res, "*.success"))
- if sync_ids := glom(res, "*.result.0.meta.last_row_id", default=0):
- logger.success(f"Synced {num_success} messages to D1, {min(sync_ids)} -> {max(sync_ids)}")
- tasks = []
-
- if tasks:
- res = await asyncio.gather(*tasks, return_exceptions=True)
- num_success = sum(glom(res, "*.success"))
- if sync_ids := glom(res, "*.result.0.meta.last_row_id", default=0):
- logger.success(f"Synced {num_success} messages to D1, {min(sync_ids)} -> {max(sync_ids)}")
- tasks = []
+ cache.set("backup_chat_history", 1, ttl=12 * 3600) # backup every 12 hours
-
-async def chat_info(client: Client, chat_id: int) -> Chat:
- if cache.get(f"chat-info-{chat_id}"):
- return cache.get(f"chat-info-{chat_id}")
- if chat_id == 1:
- return Chat(id=1)
- try:
- logger.debug(f"Getting chat info for {chat_id}")
- chat = await client.get_chat(int(chat_id))
- except PeerIdInvalid:
- return await chat_info(client, int(f"-100{chat_id}"))
- except Exception:
- chat = Chat(id=1)
- cache.set(f"chat-info-{chat_id}", chat, ttl=3600) # cache for 1 hour
- return chat
-
-
-async def get_table_name(client: Client, chat_id: str | int) -> str:
- """Get D1 table name by chat id."""
- if cache.get(f"tablename-{chat_id}"):
- return cache.get(f"tablename-{chat_id}")
- # get a default table name
- chat = await chat_info(client, int(chat_id))
- first_name = chat.first_name or ""
- last_name = chat.last_name or ""
- full_name = first_name + last_name
- chat_title = full_name or chat.title or ""
- slim_cid = str(chat_id).removeprefix("-100")
- default_name = f"{slim_cid}-{chat_title}".replace(" ", "")
-
- # get database id, and create database if not exists
- database_id = await create_d1_database(name=HISTORY.D1_DATABASE)
- if not database_id:
- await create_d1_table(default_name, DB_COLUMNS, HISTORY.D1_DATABASE)
- await create_table_index(slim_cid, database_id, default_name)
- return default_name
-
- # find the table name based on chat id
- sql = "SELECT name FROM sqlite_master WHERE type='table';"
- resp = await query_d1(sql, database_id)
- table_names = flatten(glom(resp, "result.*.results.*.name", default=[]))
- table_name = next((x for x in table_names if x.startswith(slim_cid + "-")), default_name)
- cache.set(f"tablename-{chat_id}", table_name, ttl=0)
-
- # create table and index
- await create_d1_table(table_name, DB_COLUMNS, HISTORY.D1_DATABASE)
- await create_table_index(slim_cid, database_id, table_name)
- return table_name
-
-
-async def create_table_index(slim_cid: str, database_id: str, table_name: str) -> None:
- # get all index names
- sql = "SELECT name FROM sqlite_master WHERE type='index';"
- resp = await query_d1(sql, database_id)
- indexs = flatten(glom(resp, "result.*.results.*.name", default=[]))
-
- # create index if not exists
- idx_names = [x for x in INDEX_NAMES if f"idx_{slim_cid}_{x}" not in indexs]
- if not idx_names:
+ chat_ids = [x.strip() for x in chats.split(",") if x.strip()]
+ if not chat_ids:
return
-
-    # create the standard column indexes
- for idx_name in idx_names:
- logger.debug(f"Creating index on {table_name} for {idx_name}")
- sql = f'CREATE INDEX IF NOT EXISTS idx_{slim_cid}_{idx_name} ON "{table_name}"({idx_name})'
- await query_d1(sql, database_id, silent=True)
-
- """创建 FTS5 虚拟表
- -- content=table_name 指明关联的原表
- -- content_rowid=mid 指明原表的行 ID 列是 mid
- -- segmented 是我们要索引的列
- -- tokenize='unicode61' 使用 unicode61 分词器
- """
- sql = f"""CREATE VIRTUAL TABLE IF NOT EXISTS fts_{slim_cid} USING fts5(segmented, content="{table_name}", content_rowid=mid, tokenize="unicode61");"""
- await query_d1(sql, database_id, silent=True)
-
- """将现有数据从原表复制到 FTS 表
- 注意, 我们在这里插入的是 rowid (它会对应到 content_rowid=mid 指定的列) 和 segmented
- 从原表中选择 mid 和 segmented 列。mid 列的值会被插入到 FTS 表中对应原表 rowid (或 content_rowid) 的位置。
- """
- sql = f"INSERT INTO fts_{slim_cid} (rowid, segmented) SELECT mid, segmented FROM '{table_name}' WHERE mid NOT IN (SELECT rowid FROM fts_{slim_cid});"
- await query_d1(sql, database_id, silent=True)
-
- """维护 FTS 表
- 为了让 FTS 表与原表保持同步, 需要在原表上创建触发器。
- 在原表插入、删除、更新时, 同步更新 FTS 表
- """
- # 创建触发器, 在原表删除数据时, 同步从 FTS 表删除
- sql = f"CREATE TRIGGER IF NOT EXISTS trigger_{slim_cid}_ai AFTER INSERT ON '{table_name}' BEGIN INSERT INTO fts_{slim_cid} (rowid, segmented) VALUES (NEW.mid, NEW.segmented); END;"
- await query_d1(sql, database_id, silent=True)
-
-    # trigger: on delete from the source table, delete the row from the FTS table
- sql = f"CREATE TRIGGER IF NOT EXISTS trigger_{slim_cid}_ad AFTER DELETE ON '{table_name}' BEGIN DELETE FROM fts_{slim_cid} WHERE rowid = OLD.mid; END;"
- await query_d1(sql, database_id, silent=True)
-
-    # trigger: on update of the source table, update the FTS table
-    # an FTS5 update is typically a delete of the old row followed by an insert of the new one
- sql = f"CREATE TRIGGER IF NOT EXISTS trigger_{slim_cid}_au AFTER UPDATE ON '{table_name}' BEGIN DELETE FROM fts_{slim_cid} WHERE rowid = OLD.mid AND OLD.segmented <> NEW.segmented; INSERT INTO fts_{slim_cid} (rowid, segmented) SELECT NEW.mid, NEW.segmented WHERE OLD.segmented <> NEW.segmented; END;"
- await query_d1(sql, database_id, silent=True)
+ for cid in chat_ids:
+ logger.info(f"Backup chat history: {cid}")
+ if HISTORY.ENGINE.upper() == "D1":
+ await backup_chat_history_to_d1(client, cid, hours)
+ elif HISTORY.ENGINE.upper() == "TURSO":
+ await backup_chat_history_to_turso(client, cid, hours)
src/history/turso.py
@@ -0,0 +1,319 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+import json
+from datetime import datetime, timedelta
+from pathlib import Path
+from zoneinfo import ZoneInfo
+
+from glom import Coalesce, flatten, glom
+from loguru import logger
+from pyrogram.client import Client
+from pyrogram.errors import PeerIdInvalid
+from pyrogram.types import Chat, Message
+
+from config import DOWNLOAD_DIR, HISTORY, TZ, cache, cutter
+from database.turso import turso_create_table, turso_exec, turso_list_tables
+from messages.parser import parse_msg
+from permission import check_save_history
+from utils import i_am_bot, nowdt
+
+# ruff: noqa: S608
+
+DB_COLUMNS = "mid INTEGER PRIMARY KEY, mtype TEXT, time TEXT NOT NULL, user TEXT, content TEXT, filename TEXT, urls TEXT, reply INTEGER, mime TEXT, uid INTEGER, segmented TEXT"
+INDEX_NAMES = ["time", "uid"]
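+# map Python type names to the value types used for Turso statement args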
+SQL_TYPES = {"str": "text", "int": "integer", "float": "float", "nonetype": "null"}
+
+
+async def sync_history_to_turso(client: Client, message: Message) -> None:
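+    """Upsert a single (new or edited) message into the chat's Turso history table."""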
+ if not HISTORY.TURSO_ENABLE:
+ return
+
+ info = parse_msg(message, silent=True)
+ if not check_save_history(info["ctype"], info["cid"]):
+ return
+ table_name = await get_table_name(client, info["cid"])
+ records = {
+ "mid": info["mid"],
+ "mtype": info["mtype"],
+ "time": info["time"],
+ "user": info["full_name"],
+ "content": message.content, # text or edited text
+ "filename": info["file_name"],
+ "urls": "\n\n".join(info["entity_urls"]),
+ "reply": message.reply_to_message_id,
+ "mime": info["mime_type"],
+ "uid": info["uid"],
+ "segmented": " ".join(cutter.cutword(message.content)),
+ }
+ keys = ", ".join(records)
+ values = ", ".join(["?" for _ in range(len(records))])
+ updates = ", ".join([f"{k} = EXCLUDED.{k}" for k in records if k != "mid"])
+ args = [{"type": SQL_TYPES[type(x).__name__.lower()], "value": str(x) if isinstance(x, (int, float)) else x} for x in records.values()]
+ sql = f'INSERT INTO "{table_name}" ({keys}) VALUES ({values}) ON CONFLICT (mid) DO UPDATE SET {updates};'
+ await turso_exec([{"type": "execute", "stmt": {"sql": sql, "args": args}}], db_name=HISTORY.TURSO_DATABASE, silent=True, retry=2)
+
+
+async def backup_chat_history_to_turso(client: Client, chat_id: str | int, hours: float = HISTORY.BACKUP_CHATS_HOURS) -> None:
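+    """Walk the chat's history for the last `hours` hours and upsert any messages missing from Turso, in batched pipelines."""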
+ if not HISTORY.TURSO_ENABLE:
+ return
+ if await i_am_bot(client):
+ return
+ chat = await chat_info(client, int(chat_id))
+ table_name = await get_table_name(client, chat_id)
+
+ # find message ids in this time range
+ now = nowdt(TZ)
+ begin_dt = now - timedelta(hours=hours)
+ begin_time = begin_dt.strftime("%Y-%m-%d %H:%M:%S")
+ end_time = now.strftime("%Y-%m-%d %H:%M:%S")
+ sql = f'SELECT mid FROM "{table_name}" WHERE time >= "{begin_time}" AND time <= "{end_time}";'
+ resp = await turso_exec([{"type": "execute", "stmt": {"sql": sql}}], db_name=HISTORY.TURSO_DATABASE, silent=True)
+ saved_mids = flatten(glom(resp, "results.0.response.result.rows.*.*.value", default=[]))
+ saved_mids = [int(x) for x in saved_mids]
+ logger.info(f"Found {len(saved_mids)} messages in Turso. Rows read: {glom(resp, 'results.0.response.result.rows_read', default=1)}")
+ concurrency = 200
+ statements = []
+ async for message in client.get_chat_history(chat.id): # type: ignore
+ if not isinstance(message, Message) or message.empty:
+ continue
+ info = parse_msg(message, silent=True)
+ if info["mid"] in saved_mids:
+ continue
+ if info["time"] < begin_time:
+ break
+ records = {
+ "mid": info["mid"],
+ "mtype": info["mtype"],
+ "time": info["time"],
+ "user": info["full_name"],
+ "content": message.content,
+ "filename": info["file_name"],
+ "urls": "\n\n".join(info["entity_urls"]),
+ "reply": message.reply_to_message_id,
+ "mime": info["mime_type"],
+ "uid": info["uid"],
+ "segmented": " ".join(cutter.cutword(info["text"])),
+ }
+ logger.trace(f"Syncing {table_name}: {info['mid']}")
+ keys = ", ".join(records)
+ values = ", ".join(["?" for _ in range(len(records))])
+ updates = ", ".join([f"{k} = EXCLUDED.{k}" for k in records if k != "mid"])
+ args = [{"type": SQL_TYPES[type(x).__name__.lower()], "value": str(x) if isinstance(x, (int, float)) else x} for x in records.values()]
+ sql = f'INSERT INTO "{table_name}" ({keys}) VALUES ({values}) ON CONFLICT (mid) DO UPDATE SET {updates};'
+ statements.append({"type": "execute", "stmt": {"sql": sql, "args": args}})
+ if len(statements) == concurrency:
+ resp = await turso_exec(statements, db_name=HISTORY.TURSO_DATABASE, silent=True, retry=2)
+ num_success = sum([1 for x in glom(resp, "results.*.type", default=[]) if x == "ok"]) - 1
+ if sync_ids := glom(resp, "results.**.last_insert_rowid", default=[0]):
+ logger.success(f"Synced {num_success} messages to Turso, {min(sync_ids)} -> {max(sync_ids)}. {info['time']}")
+ statements = []
+
+ if statements:
+ resp = await turso_exec(statements, db_name=HISTORY.TURSO_DATABASE, silent=True, retry=2)
+ num_success = sum([1 for x in glom(resp, "results.*.type", default=[]) if x == "ok"]) - 1
+ if sync_ids := glom(resp, "results.**.last_insert_rowid", default=[0]):
+ logger.success(f"Synced {num_success} messages to Turso, {min(sync_ids)} -> {max(sync_ids)}. {info['time']}")
+
+
+async def upload_exported_history_to_turso(client: Client, path: str | Path | None = None) -> None:
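+    """Import a Telegram chat export (result.json) into Turso, skipping message ids that are already saved."""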
+ if not HISTORY.TURSO_ENABLE:
+ return
+ if path is None:
+ path = Path(DOWNLOAD_DIR) / "result.json"
+ path = Path(path)
+ if not path.is_file():
+ return
+
+ def parse_text(texts: list) -> str:
+ if isinstance(texts, str):
+ return texts
+ text = ""
+ for x in texts:
+ text += x if isinstance(x, str) else x.get("text", "")
+ return text
+
+ def parse_urls(entities: list) -> str:
+ urls = [glom(x, Coalesce("href", "text")) for x in entities if x["type"] in {"link", "text_link"}]
+ return "\n\n".join(urls)
+
+ with path.open("r") as f: # noqa: ASYNC230
+ data = json.load(f)
+ logger.info(f"Found {len(data['messages'])} messages in json file")
+ mtypes = {
+ "text": "text",
+ "photo": "photo",
+ "animation": "animation",
+ "audio_file": "audio",
+ "sticker": "sticker",
+ "voice_message": "voice",
+ "video_message": "video",
+ "video_file": "video",
+ }
+ table_name = await get_table_name(client, data["id"])
+ # find all message_ids
+ resp = await turso_exec(
+ [{"type": "execute", "stmt": {"sql": f'SELECT mid FROM "{table_name}";'}}],
+ db_name=HISTORY.TURSO_DATABASE,
+ silent=True,
+ )
+ saved_ids = flatten(glom(resp, "results.0.response.result.rows.*.*.value", default=[]))
+ saved_ids = [int(x) for x in saved_ids]
+ logger.info(f"Found {len(saved_ids)} messages in Turso. Rows read: {glom(resp, 'results.0.response.result.rows_read', default=1)}")
+ last_id = max(saved_ids, default=0)
+ logger.info(f"Found last message at {last_id}")
+ concurrency = 5000
+ statements = []
+ for info in data["messages"]: # type: ignore
+ if info["id"] in saved_ids:
+ continue
+ if info["type"] != "message":
+ continue
+ if info["date_unixtime"] == "0":
+ continue
+ if "media_type" not in info: # guess mtype
+ if "photo" in info:
+ info["media_type"] = "photo"
+ if "video/" in info.get("mime_type", ""):
+ info["media_type"] = "video_file"
+
+ dt = datetime.fromtimestamp(int(info["date_unixtime"]), tz=ZoneInfo(TZ))
+ uid = int(info["from_id"].removeprefix("user").removeprefix("channel"))
+ user = info["from"]
+ if user == data["name"] and data["type"] in ["public_channel", "private_channel"]: # user is not shown
+ user = ""
+ uid = 1
+
+ content = parse_text(info.get("text", []))
+ records = {
+ "mid": info["id"],
+ "mtype": mtypes[info.get("media_type", "text")],
+ "time": dt.strftime("%Y-%m-%d %H:%M:%S"),
+ "user": user,
+ "content": parse_text(info.get("text", [])),
+ "filename": info.get("file_name", ""),
+ "urls": parse_urls(info.get("text_entities", [])),
+ "reply": info.get("reply_to_message_id"),
+ "mime": info.get("mime_type", ""),
+ "uid": uid,
+ "segmented": " ".join(cutter.cutword(content)),
+ }
+ # logger.debug(f"Syncing message {table_name}: {info['id']}")
+ keys = ", ".join(records)
+ values = ", ".join(["?" for _ in range(len(records))])
+ updates = ", ".join([f"{k} = EXCLUDED.{k}" for k in records if k != "mid"])
+ args = [{"type": SQL_TYPES[type(x).__name__.lower()], "value": str(x) if isinstance(x, (int, float)) else x} for x in records.values()]
+ sql = f'INSERT INTO "{table_name}" ({keys}) VALUES ({values}) ON CONFLICT (mid) DO UPDATE SET {updates};'
+ statements.append({"type": "execute", "stmt": {"sql": sql, "args": args}})
+ if len(statements) == concurrency:
+ resp = await turso_exec(statements, db_name=HISTORY.TURSO_DATABASE, silent=True, retry=2)
+ num_success = sum([1 for x in glom(resp, "results.*.type", default=[]) if x == "ok"]) - 1
+ if sync_ids := glom(resp, "results.**.last_insert_rowid", default=[0]):
+ logger.success(f"Synced {num_success} messages to Turso, {min(sync_ids)} -> {max(sync_ids)}. {dt.strftime('%Y-%m-%d %H:%M:%S')}")
+ statements = []
+ if statements:
+ resp = await turso_exec(statements, db_name=HISTORY.TURSO_DATABASE, silent=True, retry=2)
+ num_success = sum([1 for x in glom(resp, "results.*.type", default=[]) if x == "ok"]) - 1
+ if sync_ids := glom(resp, "results.**.last_insert_rowid", default=[0]):
+ logger.success(f"Synced {num_success} messages to Turso, {min(sync_ids)} -> {max(sync_ids)}")
+
+
+async def chat_info(client: Client, chat_id: int) -> Chat:
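+    """Return chat metadata, cached for an hour; falls back to a placeholder Chat(id=1) if the peer can't be resolved."""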
+ if cache.get(f"chat-info-{chat_id}"):
+ return cache.get(f"chat-info-{chat_id}")
+ if chat_id == 1:
+ return Chat(id=1)
+ try:
+ logger.debug(f"Getting chat info for {chat_id}")
+ chat = await client.get_chat(int(chat_id))
+ except PeerIdInvalid:
+ return await chat_info(client, int(f"-100{chat_id}"))
+ except Exception:
+ chat = Chat(id=1)
+ cache.set(f"chat-info-{chat_id}", chat, ttl=3600) # cache for 1 hour
+ return chat
+
+
+async def get_table_name(client: Client, chat_id: str | int) -> str:
+ """Get table name by chat id."""
+ if cache.get(f"tablename-{chat_id}"):
+ return cache.get(f"tablename-{chat_id}")
+ # get a default table name
+ chat = await chat_info(client, int(chat_id))
+ first_name = chat.first_name or ""
+ last_name = chat.last_name or ""
+ full_name = first_name + last_name
+ chat_title = full_name or chat.title or ""
+ slim_cid = str(chat_id).removeprefix("-100")
+ default_name = f"{slim_cid}-{chat_title}".replace(" ", "")
+
+ # find the table name based on chat id
+ table_names = await turso_list_tables(HISTORY.TURSO_DATABASE, silent=True)
+ table_name = next((x for x in table_names if x.startswith(slim_cid + "-")), default_name)
+ cache.set(f"tablename-{chat_id}", table_name, ttl=0)
+
+    # create the table and index only if the table does not exist yet
+    if table_name in table_names:
+        return table_name
+    await turso_create_table(table_name, DB_COLUMNS, HISTORY.TURSO_DATABASE)
+ await create_table_index(slim_cid, table_name)
+ return table_name
+
+
+async def create_table_index(slim_cid: str, table_name: str) -> None:
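+    """Ensure the per-chat column indexes, the FTS5 full-text table, and the triggers that keep it in sync."""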
+ # get all index names
+ resp = await turso_exec(
+ [{"type": "execute", "stmt": {"sql": "SELECT name FROM sqlite_master WHERE type='index';"}}],
+ db_name=HISTORY.TURSO_DATABASE,
+ silent=True,
+ )
+    indexes = flatten(glom(resp, "results.0.response.result.rows.*.*.value", default=[]))
+
+    # create index if not exists
+    idx_names = [x for x in INDEX_NAMES if f"idx_{slim_cid}_{x}" not in indexes]
+ if not idx_names:
+ return
+
+    # create the standard column indexes
+ for idx_name in idx_names:
+ logger.debug(f"Creating index on {table_name} for {idx_name}")
+ resp = await turso_exec(
+ [{"type": "execute", "stmt": {"sql": f'CREATE INDEX IF NOT EXISTS idx_{table_name}_{idx_name} ON "{table_name}"({idx_name})'}}],
+ db_name=HISTORY.TURSO_DATABASE,
+ silent=True,
+ )
+
+ statements = []
+ """创建 FTS5 虚拟表
+ -- content=table_name 指明关联的原表
+ -- content_rowid=mid 指明原表的行 ID 列是 mid
+ -- segmented 是我们要索引的列
+ -- tokenize='unicode61' 使用 unicode61 分词器
+ """
+ sql = f"""CREATE VIRTUAL TABLE IF NOT EXISTS fts_{slim_cid} USING fts5(segmented, content="{table_name}", content_rowid=mid, tokenize="unicode61");"""
+ statements.append({"type": "execute", "stmt": {"sql": sql}})
+
+ """将现有数据从原表复制到 FTS 表
+ 注意, 我们在这里插入的是 rowid (它会对应到 content_rowid=mid 指定的列) 和 segmented
+ 从原表中选择 mid 和 segmented 列。mid 列的值会被插入到 FTS 表中对应原表 rowid (或 content_rowid) 的位置。
+ """
+ sql = f"INSERT INTO fts_{slim_cid} (rowid, segmented) SELECT mid, segmented FROM '{table_name}' WHERE mid NOT IN (SELECT rowid FROM fts_{slim_cid});"
+ statements.append({"type": "execute", "stmt": {"sql": sql}})
+
+ """维护 FTS 表
+ 为了让 FTS 表与原表保持同步, 需要在原表上创建触发器。
+ 在原表插入、删除、更新时, 同步更新 FTS 表
+ """
+ # 创建触发器, 在原表删除数据时, 同步从 FTS 表删除
+ sql = f"CREATE TRIGGER IF NOT EXISTS trigger_{slim_cid}_ai AFTER INSERT ON '{table_name}' BEGIN INSERT INTO fts_{slim_cid} (rowid, segmented) VALUES (NEW.mid, NEW.segmented); END;"
+ statements.append({"type": "execute", "stmt": {"sql": sql}})
+
+    # trigger: on delete from the source table, delete the row from the FTS table
+ sql = f"CREATE TRIGGER IF NOT EXISTS trigger_{slim_cid}_ad AFTER DELETE ON '{table_name}' BEGIN DELETE FROM fts_{slim_cid} WHERE rowid = OLD.mid; END;"
+ statements.append({"type": "execute", "stmt": {"sql": sql}})
+
+    # trigger: on update of the source table, update the FTS table
+    # an FTS5 update is typically a delete of the old row followed by an insert of the new one
+ sql = f"CREATE TRIGGER IF NOT EXISTS trigger_{slim_cid}_au AFTER UPDATE ON '{table_name}' BEGIN DELETE FROM fts_{slim_cid} WHERE rowid = OLD.mid AND OLD.segmented <> NEW.segmented; INSERT INTO fts_{slim_cid} (rowid, segmented) SELECT NEW.mid, NEW.segmented WHERE OLD.segmented <> NEW.segmented; END;"
+ statements.append({"type": "execute", "stmt": {"sql": sql}})
+ await turso_exec(statements, db_name=HISTORY.TURSO_DATABASE, silent=True)
src/config.py
@@ -6,10 +6,13 @@ from pathlib import Path
from typing import ClassVar
from cacheout import Cache
+from cutword import Cutter
# ruff: noqa: RUF001
+# init some global instances
cache = Cache(ttl=0, maxsize=2048)
semaphore = asyncio.Semaphore(8) # max 8 concurrent downloads
+cutter = Cutter()
DOWNLOAD_DIR = os.getenv("DOWNLOAD_DIR", Path(__file__).parent.joinpath("downloads").as_posix())
FILE_SERVER = os.getenv("FILE_SERVER", "") # expose the download dir to internet (optional). for example: https://server.com/dir
@@ -201,6 +204,12 @@ class DB:
class HISTORY:
+ ENABLE = os.getenv("HISTORY_ENABLE", "1").lower() in ["1", "y", "yes", "t", "true", "on"]
+ ENGINE = os.getenv("HISTORY_ENGINE", "turso") # turso or D1
+ TURSO_ENABLE = os.getenv("HISTORY_TURSO_ENABLE", "1").lower() in ["1", "y", "yes", "t", "true", "on"]
+ TURSO_DATABASE = os.getenv("HISTORY_TURSO_DATABASE", "bennybot-history")
+ PERIODICALLY_BACKUP_CHATS = os.getenv("HISTORY_PERIODICALLY_BACKUP_CHATS", "") # comma separated chat ids to include (without `-100` prefix)
+ BACKUP_CHATS_HOURS = float(os.getenv("HISTORY_BACKUP_CHATS_HOURS", "24")) # hours to backup chats
D1_ENABLE = os.getenv("HISTORY_D1_ENABLE", "1").lower() in ["1", "y", "yes", "t", "true", "on"]
D1_DATABASE = os.getenv("HISTORY_D1_DATABASE", "bennybot-history")
INCLUDE_CHATS = os.getenv("HISTORY_INCLUDE_CHATS", "") # "all" or comma separated chat ids to include (without `-100` prefix)
src/handler.py
@@ -11,7 +11,7 @@ from bridge.ocr import send_to_ocr_bridge
from config import ENABLE, PREFIX, PROXY
from danmu.entrypoint import query_danmu
from database.database import del_db
-from history.sync import sync_history_to_d1
+from history.sync import sync_chat_history
from llm.gpt import gpt_response
from llm.summary import ai_summary
from messages.parser import parse_msg
@@ -48,7 +48,7 @@ async def handle_utilities(
asr: bool = True,
audio: bool = True,
danmu: bool = True,
- save_d1: bool = True,
+ save_history: bool = True,
google: bool = True,
ocr: bool = True,
price: bool = True,
@@ -74,7 +74,7 @@ async def handle_utilities(
asr (bool, optional): Enable ASR. Defaults to True.
audio (bool, optional): Enable Video -> Audio. Defaults to True.
danmu (bool, optional): Enable Query Danmu database. Defaults to True.
- save_d1 (bool, optional): Enable Save message to D1. Defaults to True.
+ save_history (bool, optional): Enable save chat history. Defaults to True.
google (bool, optional): Enable Google Search. Defaults to True.
ytb (bool, optional): Enable YouTube Search. Defaults to True.
subtitle (bool, optional): Enable YouTube subtitle. Defaults to True.
@@ -86,8 +86,8 @@ async def handle_utilities(
show_progress (bool, optional): Show a progress message on Telegram. Defaults to True.
        detail_progress (bool, optional): Show detailed progress (only if show_progress is set to True). Defaults to False.
"""
- if kwargs.get("only_d1"):
- await sync_history_to_d1(client, message)
+ if kwargs.get("only_history"):
+ await sync_chat_history(client, message)
return
if kwargs.get("disabled"):
return
@@ -119,8 +119,8 @@ async def handle_utilities(
await query_danmu(client, message, **kwargs) # /danmu
if raw_img:
await convert_raw_img_file(client, message, **kwargs)
- if save_d1:
- await sync_history_to_d1(client, message)
+ if save_history:
+ await sync_chat_history(client, message)
async def handle_social_media(
@@ -161,7 +161,7 @@ async def handle_social_media(
show_progress (bool, optional): Show a progress message on Telegram. Defaults to True.
        detail_progress (bool, optional): Show detailed progress (only if show_progress is set to True). Defaults to False.
"""
- if kwargs.get("only_d1") or kwargs.get("disabled"):
+ if kwargs.get("only_history") or kwargs.get("disabled"):
return None
kwargs |= {"target_chat": target_chat, "reply_msg_id": reply_msg_id, "show_progress": show_progress, "detail_progress": detail_progress}
if not ENABLE.SEND_AS_REPLY:
src/main.py
@@ -24,7 +24,7 @@ from bridge.social import forward_social_media_results
from config import DAILY_MESSAGES, DEVICE_NAME, ENABLE, PROXY, TOKEN, TZ, cache
from danmu.sync import sync_server_to_r2
from handler import handle_social_media, handle_utilities
-from history.sync import sync_history_to_d1
+from history.sync import backup_chat_history, sync_chat_history
from llm.summary import daily_summary
from llm.utils import clean_gemini_files
from messages.parser import parse_msg
@@ -89,7 +89,7 @@ async def main():
@app.on_edited_message()
async def edited(client: Client, message: Message):
- await sync_history_to_d1(client, message)
+ await sync_chat_history(client, message)
if ENABLE.CRONTAB:
scheduler = AsyncIOScheduler()
@@ -126,6 +126,7 @@ async def scheduling(client: Client):
await sync_server_to_r2(qtype="发言")
await sync_server_to_r2(qtype="弹幕")
await clean_gemini_files()
+ await backup_chat_history(client)
if __name__ == "__main__":
src/permission.py
@@ -27,10 +27,10 @@ async def check_permission(client: Client, message: Message) -> dict:
    Sometimes we only need to save the message to D1 but disable other tools.
"""
# check if we should save this message to D1
- save_d1 = check_save_d1(ctype=ctype, cid=message.chat.id)
- if permission["disabled"] and save_d1: # only save msg to D1, disable others
- permission["only_d1"] = True
- permission["save_d1"] = save_d1
+ save_history = check_save_history(ctype=ctype, cid=message.chat.id)
+ if permission["disabled"] and save_history: # only save msg to D1, disable others
+ permission["only_history"] = True
+ permission["save_history"] = save_history
return permission
@@ -194,7 +194,7 @@ def check_service(cid: int | str, ctype: str) -> dict:
return permission
-def check_save_d1(ctype: str, cid: int | str) -> bool:
+def check_save_history(ctype: str, cid: int | str) -> bool:
# ruff: noqa: SIM103
cid = slim_cid(cid)
if true(os.getenv(f"HISTORY_IGNORE_{cid}")):