main
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3import asyncio
4import json
5import os
6from datetime import datetime, timedelta
7from pathlib import Path
8from typing import Literal
9from zoneinfo import ZoneInfo
10
11from glom import Coalesce, glom
12from loguru import logger
13from pyrogram.client import Client
14from pyrogram.types import Message
15
16from config import DOWNLOAD_DIR, HISTORY, TZ, cache, cutter
17from database.d1 import create_d1_table, insert_d1, query_d1
18from history.utils import CHAT_COLUMNS, MSG_COLUMNS, MSG_INDEXES, USER_COLUMNS, USER_INDEXES, can_delete_history, check_save_history, fine_grained_check, get_chat
19from messages.parser import parse_chat, parse_msg
20from utils import i_am_bot, nowdt, slim_cid, to_int, true
21
22
async def sync_history_to_d1(client: Client, message: Message) -> None:
    """Sync a received message to the D1 database.

    1. save the user info to table `userinfo`
    2. save the chat info to table `chatinfo`
    3. save the message to table `{cid}-{ctitle}`

    Nothing is written when D1 history is disabled, when the chat or message is
    rejected by `check_save_history` / `fine_grained_check`, or when the message
    is a service message. A list passed as `message` is treated as a batch of
    deleted messages and forwarded to `delete_messages`.

    Args:
        client: client handed to the user/chat info savers.
        message: the received message, or a list of deleted messages.
    """
    if not HISTORY.D1_ENABLE:
        return
    if isinstance(message, list):  # this is deleted messages
        await delete_messages(message)
        return
    info = parse_msg(message, silent=True, use_cache=False)
    if not check_save_history(info["ctype"], info["cid"]) or not fine_grained_check(info) or message.service:
        return
    await save_userinfo_to_d1(client, info)
    chatinfo = await save_chatinfo_to_d1(client, info)
    if not chatinfo:
        # save_chatinfo_to_d1 returns {} when the chat id is 0, so there is no table to write to.
        return
    content = message.content  # text or edited text
    records = {
        "mid": info["mid"],
        "mtype": info["mtype"],
        "time": info["time"],
        "fullname": info["full_name"],
        "content": content,
        "filename": info["file_name"],
        "urls": "\n\n".join(info["entity_urls"]),
        "reply": message.reply_to_message_id,
        "mime": info["mime_type"],
        "user": info["full_name"].replace(" ", ""),
        "handle": info["handle"],
        "uid": info["uid"],
        "gid": info["media_group_id"],
        "segmented": " ".join(cutter.cutword(content)),
    }
    # Upsert keyed on `mid`, so saving the same message id again updates the stored row.
    await query_d1(**insert_d1(chatinfo["tablename"], records, update_on_conflict="mid"), db_name=HISTORY.D1_DATABASE, silent=True)
57
58
async def delete_messages(messages: Message | list[Message]) -> None:
    """Delete messages from D1 database.

    Each message is handled independently: it is skipped when its chat is not
    saved, when it is a service message, when the chat has no `chatinfo` row,
    or when no stored row with a uid exists for it. A stored row is removed
    only if `can_delete_history(cid, uid)` allows it.

    Args:
        messages: a single deleted message or a list of them.
    """
    if not isinstance(messages, list):
        messages = [messages]
    for message in messages:
        cid = glom(message, "chat.id", default=0) or 0
        mid = glom(message, "id", default=0) or 0
        ctype = glom(message, "chat.type.name", default="") or ""
        if not check_save_history(ctype, cid) or message.service:
            # Skip only this message; the remaining ones in the batch must still be processed.
            continue
        chatinfo = await get_d1_chatinfo(cid)
        if not chatinfo:
            continue
        tablename = chatinfo["tablename"]
        # Look up the stored row first: the author uid decides whether deletion is allowed.
        resp = await query_d1(f"SELECT * FROM '{tablename}' WHERE mid={mid};", db_name=HISTORY.D1_DATABASE, silent=True)
        uid = glom(resp, "result.0.results.0.uid", default=0) or 0
        if not uid:
            continue
        if can_delete_history(cid, uid):
            logger.warning(f"Delete message Chat={cid}, ID={mid}: {glom(resp, 'result.0.results.0')}")
            await query_d1(f"DELETE FROM '{tablename}' WHERE mid={mid};", db_name=HISTORY.D1_DATABASE, silent=True)
80
81
async def backup_chat_history_to_d1(
    client: Client,
    chat_id: str | int,
    hours: float = HISTORY.BACKUP_CHATS_HOURS,
    *,
    start_from: Literal["latest", "oldest"] = "latest",
    max_sync: float = float("inf"),
) -> None:
    """Backup chat history to D1 database.

    Iterates the chat history and queues an insert for every message that is not
    already stored. Iteration stops at the first not-yet-stored message whose time
    is earlier than `hours` ago, or once `max_sync` messages have been queued.

    If start_from is "oldest", find the minimum message id of this chat, then use this mid as `offset_id` to retrieve messages.

    Nothing is done when D1 history is disabled, when the client is a bot, when the
    chat is not accessible, or when the env var `HISTORY_IGNORE_{cid}` is truthy.

    Args:
        client: client used to read the chat history.
        chat_id: chat id or handle; converted with `to_int` before lookup.
        hours: size of the look-back window, in hours before now.
        start_from: "latest" reads from the newest message; "oldest" continues
            from the smallest message id already stored.
        max_sync: maximum number of messages to queue in this run.
    """
    if not HISTORY.D1_ENABLE:
        return
    if await i_am_bot(client):
        return
    chat = await get_chat(client, to_int(chat_id))
    if chat.id == 0:  # chat is not accessible
        return
    chatinfo = await save_chatinfo_to_d1(client, parse_chat(chat))
    if not chatinfo:  # save_chatinfo_to_d1 returns {} when the chat id is 0
        return
    if true(os.getenv(f"HISTORY_IGNORE_{chatinfo['cid']}")):
        return
    table_name = chatinfo["tablename"]
    now = nowdt(TZ)
    begin_dt = now - timedelta(hours=hours)
    begin_time = begin_dt.strftime("%Y-%m-%d %H:%M:%S")
    if start_from == "oldest":
        sql = f'SELECT mid FROM "{table_name}" ORDER BY mid ASC LIMIT 1'
        resp = await query_d1(sql, db_name=HISTORY.D1_DATABASE, silent=True)
        offset_id = glom(resp, "result.0.results.0.mid", default=1)
        saved_mids = {int(offset_id)}
    else:
        # find message ids in this time range
        end_time = now.strftime("%Y-%m-%d %H:%M:%S")
        sql = f'SELECT mid FROM "{table_name}" WHERE time >= "{begin_time}" AND time <= "{end_time}"'
        resp = await query_d1(sql, db_name=HISTORY.D1_DATABASE, silent=True)
        saved_mids = glom(resp, "result.0.results.*.mid", default=[])
        saved_mids = {int(x) for x in saved_mids}
        offset_id = 0  # retrieve from latest message
    logger.info(f"Found {len(saved_mids)} messages of {table_name} in D1, Time >= {begin_time}")

    async def flush(batch: list) -> None:
        """Run the queued inserts concurrently and log how many succeeded."""
        results = await asyncio.gather(*batch, return_exceptions=True)
        # With return_exceptions=True a failed query comes back as an exception object.
        # Reading `success` with a default counts it as a failure instead of raising.
        num_success = sum(1 for res in results if glom(res, "success", default=False))
        logger.success(f"Synced {num_success} messages to D1")

    concurrency = 100
    num_sync = 0
    tasks = []
    # Use the chat handle when there is one. Otherwise use the raw id for BOT/PRIVATE chats,
    # and put the "-100" prefix back in front of the slimmed id for every other chat type.
    real_cid = chatinfo["chandle"] or (int(chatinfo["cid"]) if chatinfo["ctype"] in ["BOT", "PRIVATE"] else int(f"-100{chatinfo['cid']}"))
    async for message in client.get_chat_history(real_cid, max_id=offset_id):  # type: ignore
        if not isinstance(message, Message) or message.empty or message.service or message.id in saved_mids:
            continue
        info = parse_msg(message, silent=True, use_cache=False)
        if info["time"] < begin_time:  # both are "%Y-%m-%d %H:%M:%S" strings, so they compare chronologically
            break
        if num_sync >= max_sync:
            break
        if not fine_grained_check(info):
            continue
        num_sync += 1
        records = {
            "mid": info["mid"],
            "mtype": info["mtype"],
            "time": info["time"],
            "fullname": info["full_name"],
            "content": message.content,
            "filename": info["file_name"],
            "urls": "\n\n".join(info["entity_urls"]),
            "reply": message.reply_to_message_id,
            "mime": info["mime_type"],
            "user": info["full_name"].replace(" ", ""),
            "handle": info["handle"],
            "uid": info["uid"],
            "gid": info["media_group_id"],
            "segmented": " ".join(cutter.cutword(info["text"])),
        }
        logger.trace(f"Syncing {table_name}: {info['mid']} - {info['time']}")
        tasks.append(query_d1(**insert_d1(table_name, records, update_on_conflict="mid"), db_name=HISTORY.D1_DATABASE, silent=True))
        if len(tasks) == concurrency:
            await flush(tasks)
            tasks = []

    if tasks:
        await flush(tasks)
165
166
async def upload_exported_history_to_d1(client: Client, path: str | Path | None = None) -> None:
    """Upload messages from an exported chat history JSON file to the D1 database.

    The file is read as UTF-8 JSON. Media group ids are derived for consecutive
    media messages, and every entry of type "message" that is not already stored
    and passes `fine_grained_check` is inserted into the chat's table.

    Nothing is done when D1 history is disabled, when the file does not exist,
    or when no chat info can be resolved for the exported chat.

    Args:
        client: client used to resolve the chat when it has never been synced.
        path: path of the exported JSON file; defaults to `DOWNLOAD_DIR/result.json`.
    """
    if not HISTORY.D1_ENABLE:
        return
    if path is None:
        path = Path(DOWNLOAD_DIR) / "result.json"
    path = Path(path)
    if not path.is_file():
        return

    def parse_text(texts: list | str) -> str:
        """Flatten the export's `text` field (a string, or a list of strings/dicts) to plain text."""
        if isinstance(texts, str):
            return texts
        return "".join(x if isinstance(x, str) else x.get("text", "") for x in texts)

    def parse_urls(entities: list) -> str:
        """Join the URLs of all link entities, preferring `href` over `text`."""
        urls = [glom(x, Coalesce("href", "text")) for x in entities if x["type"] in {"link", "text_link"}]
        return "\n\n".join(urls)

    async def flush(batch: list) -> None:
        """Run the queued inserts concurrently and log how many succeeded."""
        results = await asyncio.gather(*batch, return_exceptions=True)
        # With return_exceptions=True a failed query comes back as an exception object.
        # Reading `success` with a default counts it as a failure instead of raising.
        num_success = sum(1 for res in results if glom(res, "success", default=False))
        logger.success(f"Synced {num_success} messages to D1")

    # JSON text is UTF-8; be explicit so decoding does not depend on the platform locale.
    with path.open("r", encoding="utf-8") as f:  # noqa: ASYNC230
        data = json.load(f)
    messages = data["messages"]
    logger.info(f"Found {len(messages)} messages in json file")
    # The exported history does not have media_group_id, so we first process all
    # messages and add one. If two consecutive messages have the same `from_id` and
    # `date_unixtime`, and the current one has a `photo` or `thumbnail` key, they
    # are considered one media group identified by the first message's id.
    last_msg = {}
    for idx, msg in enumerate(messages):
        same_sender_and_time = all(msg.get(key) == last_msg.get(key) for key in ["from_id", "date_unixtime"])
        has_media = any(key in msg for key in ["photo", "thumbnail"])
        # idx > 0: the first message has no predecessor (idx - 1 would wrap to the last one).
        if idx > 0 and same_sender_and_time and has_media:
            group_id = glom(messages[idx - 1], Coalesce("media_group_id", "id"))
            messages[idx - 1]["media_group_id"] = group_id
            msg["media_group_id"] = group_id
        last_msg = msg

    mtypes = {
        "audio_file": "audio",
        "voice_message": "voice",
        "video_message": "video",
        "video_file": "video",
    }
    chat_id = data["id"]
    chatinfo = await get_d1_chatinfo(chat_id)
    if not chatinfo:  # this chat is never synced
        chat = await get_chat(client, int(chat_id))
        chatinfo = await save_chatinfo_to_d1(client, parse_chat(chat))
    if not chatinfo:  # save_chatinfo_to_d1 returns {} when the chat id is 0
        return
    table_name = chatinfo["tablename"]
    # find all message_ids
    resp = await query_d1(f'SELECT mid FROM "{table_name}";', db_name=HISTORY.D1_DATABASE, silent=True)
    saved_ids = glom(resp, "result.0.results.*.mid", default=[])
    saved_ids = {int(x) for x in saved_ids}
    concurrency = 100
    tasks = []
    for info in messages:
        if info["id"] in saved_ids:
            continue
        if info["type"] != "message":
            continue
        if info["date_unixtime"] == "0":
            continue
        if "media_type" not in info:  # guess mtype
            if "photo" in info:
                info["media_type"] = "photo"
            if "video/" in info.get("mime_type", ""):
                info["media_type"] = "video_file"
        mtype = info.get("media_type", "text")
        content = parse_text(info.get("text", []))
        urls = parse_urls(info.get("text_entities", []))
        # fine-grained check requires key: ["cid", "mtype", "text", "entity_urls"]
        if not fine_grained_check({"cid": data["id"], "mtype": mtype, "text": content, "entity_urls": urls}):
            continue
        dt = datetime.fromtimestamp(int(info["date_unixtime"]), tz=ZoneInfo(TZ))
        sender_id = info["from_id"].removeprefix("user").removeprefix("channel")
        uid = int(sender_id)
        user = info["from"] or sender_id
        if user == data["name"] and data["type"] in ["public_channel", "private_channel"]:  # user is not shown
            user = ""
            uid = 1

        records = {
            "mid": info["id"],
            "mtype": mtypes.get(mtype, mtype),
            "time": dt.strftime("%Y-%m-%d %H:%M:%S"),
            "fullname": user,
            "content": content,
            "filename": info.get("file_name", ""),
            "urls": urls,
            "reply": info.get("reply_to_message_id"),
            "mime": info.get("mime_type", ""),
            "user": user.replace(" ", ""),
            "handle": "",  # TODO: parse handle
            "uid": uid,
            "gid": info.get("media_group_id", 0),
            "segmented": " ".join(cutter.cutword(content)),
        }
        tasks.append(query_d1(**insert_d1(table_name, records, update_on_conflict="mid"), db_name=HISTORY.D1_DATABASE, silent=True))
        if len(tasks) == concurrency:
            await flush(tasks)
            tasks = []

    if tasks:
        await flush(tasks)
272
273
async def get_d1_userinfo(uid: int, cid: int) -> dict:
    """Fetch one user's row from the `userinfo` table.

    `create_d1_table` is called for `userinfo` before the lookup, then the
    table is queried by both `uid` and `cid`.

    Returns:
        The first matching row (presumably with uid, full_name, handle among
        its columns), or an empty dict when the response contains no row.
    """
    database = HISTORY.D1_DATABASE
    # create table
    await create_d1_table(
        "userinfo",
        USER_COLUMNS,
        db_name=database,
        idx_cols=USER_INDEXES,
        idx_prefix="idx_userinfo_",
        silent=True,
    )
    statement = f"SELECT * FROM userinfo WHERE uid={uid} AND cid={cid};"
    response = await query_d1(statement, db_name=database, silent=True)
    first_row = glom(response, "result.0.results.0", default={})
    return first_row
291
292
async def save_userinfo_to_d1(client: Client, minfo: dict) -> dict[str, str]:
    """Upsert the sender of a parsed message into the `userinfo` table.

    The stored row is read from the cache (falling back to D1) and compared
    with the freshly built one; the database and cache are written only when
    the two differ.

    Args:
        client: client used to resolve the chat for BOT/PRIVATE chats.
        minfo (dict): parsed message info.

    Returns:
        The user record (ctitle, full_name, handle, tags, name, uid, cid, id),
        or an empty dict for the default user (uid == 1).
    """
    uid = int(minfo["uid"])
    cid = int(slim_cid(minfo["cid"]))
    if uid == 1:  # default user (user is unknown)
        return {}

    # Load the stored row: cache first, D1 on a miss (the result is cached either way).
    cache_key = f"d1-user-{uid}-{cid}"
    stored = cache.get(cache_key)
    if not stored:
        stored = await get_d1_userinfo(uid, cid)
        cache.set(cache_key, stored, ttl=0)

    chat_title = minfo["ctitle"] or minfo["full_name"]
    # if in private chats, we use the opponent's name as chat title
    if minfo["ctype"] in ("BOT", "PRIVATE"):
        chat = await get_chat(client, minfo["cid"])
        if chat.id != 0:
            chat_title = parse_chat(chat)["ctitle"]

    # Row id: the uid itself when user and chat ids coincide, else their absolute difference.
    if uid == cid:
        primary_key = uid
    else:
        primary_key = abs(uid - cid)

    records = {
        "ctitle": chat_title,
        "full_name": minfo["full_name"],
        "handle": minfo["handle"],
        "tags": stored.get("tags", ""),
        "name": minfo["full_name"].replace(" ", ""),
        "uid": uid,
        "cid": cid,
        "id": primary_key,
    }
    if stored == records:
        return records
    logger.info(f"Save user info: {records}")
    cache.set(cache_key, records, ttl=0)
    await query_d1(**insert_d1("userinfo", records, update_on_conflict="id"), db_name=HISTORY.D1_DATABASE, silent=True)
    return records
334
335
async def get_d1_chatinfo(cid: str | int) -> dict:
    """Fetch one chat's row from the `chatinfo` table.

    `create_d1_table` is called for `chatinfo` before the lookup. A row matches
    when its `cid` equals the slimmed chat id or its `chandle` equals `cid`.

    Returns:
        The first matching row (presumably with cid, ctype, ctitle, chandle
        among its columns), or an empty dict when the response contains no row.
    """
    database = HISTORY.D1_DATABASE
    # create table
    await create_d1_table("chatinfo", CHAT_COLUMNS, db_name=database, silent=True)
    statement = f"SELECT * FROM chatinfo WHERE cid='{slim_cid(cid)}' OR chandle='{cid}';"
    response = await query_d1(statement, db_name=database, silent=True)
    first_row = glom(response, "result.0.results.0", default={})
    return first_row
346
347
async def save_chatinfo_to_d1(client: Client, minfo: dict) -> dict[str, str]:
    """Upsert the chat of a parsed message into the `chatinfo` table.

    `create_d1_table` is also called for the chat's own message table. The
    stored row is read from the cache (falling back to D1) and compared with
    the freshly built one; the database and cache are written only when the
    two differ.

    Args:
        client: client used to resolve the chat for BOT/PRIVATE chats.
        minfo (dict): parsed message info.

    Returns:
        The chat record (cid, ctype, ctitle, chandle, tablename, tags), or an
        empty dict when the slimmed chat id is "0".
    """
    cid = slim_cid(minfo["cid"])
    if str(cid) == "0":
        return {}

    # Load the stored row: cache first, D1 on a miss (the result is cached either way).
    cache_key = f"d1-chat-{cid}"
    stored = cache.get(cache_key)
    if not stored:
        stored = await get_d1_chatinfo(cid)
        cache.set(cache_key, stored, ttl=0)

    chat_title = minfo["ctitle"] or minfo["full_name"]
    # if in private chats, we use the opponent's name as chat title
    if minfo["ctype"] in ("BOT", "PRIVATE"):
        chat = await get_chat(client, minfo["cid"])
        if chat.id != 0:
            chat_title = parse_chat(chat)["ctitle"]

    records = {
        "cid": int(cid),
        "ctype": minfo["ctype"],
        "ctitle": chat_title,
        "chandle": minfo["chandle"],
        # Keep an already stored table name; only an unknown chat gets "{cid}-{ctitle}".
        "tablename": stored.get("tablename", "") or f"{cid}-{chat_title}",
        "tags": stored.get("tags", ""),
    }
    # create table for this chat
    await create_d1_table(
        records["tablename"],
        MSG_COLUMNS,
        idx_cols=MSG_INDEXES,
        idx_prefix=f"idx_{cid}_",
        fts_on_col="mid",
        fts_name=cid,
        db_name=HISTORY.D1_DATABASE,
        silent=True,
    )
    if stored == records:
        return records
    logger.info(f"Save chat info: {records}")
    cache.set(cache_key, records, ttl=0)
    await query_d1(**insert_d1("chatinfo", records, update_on_conflict="cid"), db_name=HISTORY.D1_DATABASE, silent=True)
    return records