main
  1#!/usr/bin/env python
  2# -*- coding: utf-8 -*-
  3import asyncio
  4import json
  5import os
  6from datetime import datetime, timedelta
  7from pathlib import Path
  8from typing import Literal
  9from zoneinfo import ZoneInfo
 10
 11from glom import Coalesce, glom
 12from loguru import logger
 13from pyrogram.client import Client
 14from pyrogram.types import Message
 15
 16from config import DOWNLOAD_DIR, HISTORY, TZ, cache, cutter
 17from database.d1 import create_d1_table, insert_d1, query_d1
 18from history.utils import CHAT_COLUMNS, MSG_COLUMNS, MSG_INDEXES, USER_COLUMNS, USER_INDEXES, can_delete_history, check_save_history, fine_grained_check, get_chat
 19from messages.parser import parse_chat, parse_msg
 20from utils import i_am_bot, nowdt, slim_cid, to_int, true
 21
 22
 23async def sync_history_to_d1(client: Client, message: Message) -> None:
 24    """Sync received messages to D1 database.
 25
 26    1. save the user info to table `userinfo`
 27    2. save the chat info to table `chatinfo`
 28    3. save the message to table `{cid}-{ctitle}`
 29    """
 30    if not HISTORY.D1_ENABLE:
 31        return
 32    if isinstance(message, list):  # this is deleted messages
 33        await delete_messages(message)
 34        return
 35    info = parse_msg(message, silent=True, use_cache=False)
 36    if not check_save_history(info["ctype"], info["cid"]) or not fine_grained_check(info) or message.service:
 37        return
 38    await save_userinfo_to_d1(client, info)
 39    chatinfo = await save_chatinfo_to_d1(client, info)
 40    records = {
 41        "mid": info["mid"],
 42        "mtype": info["mtype"],
 43        "time": info["time"],
 44        "fullname": info["full_name"],
 45        "content": message.content,  # text or edited text
 46        "filename": info["file_name"],
 47        "urls": "\n\n".join(info["entity_urls"]),
 48        "reply": message.reply_to_message_id,
 49        "mime": info["mime_type"],
 50        "user": info["full_name"].replace(" ", ""),
 51        "handle": info["handle"],
 52        "uid": info["uid"],
 53        "gid": info["media_group_id"],
 54        "segmented": " ".join(cutter.cutword(message.content)),
 55    }
 56    await query_d1(**insert_d1(chatinfo["tablename"], records, update_on_conflict="mid"), db_name=HISTORY.D1_DATABASE, silent=True)
 57
 58
 59async def delete_messages(messages: Message | list[Message]) -> None:
 60    """Delete messages from D1 database."""
 61    if not isinstance(messages, list):
 62        messages = [messages]
 63    for message in messages:
 64        cid = glom(message, "chat.id", default=0) or 0
 65        mid = glom(message, "id", default=0) or 0
 66        ctype = glom(message, "chat.type.name", default="") or ""
 67        if not check_save_history(ctype, cid) or message.service:
 68            return
 69        chatinfo = await get_d1_chatinfo(cid)
 70        if not chatinfo:
 71            continue
 72        tablename = chatinfo["tablename"]
 73        resp = await query_d1(f"SELECT * FROM '{tablename}' WHERE mid={mid};", db_name=HISTORY.D1_DATABASE, silent=True)
 74        uid = glom(resp, "result.0.results.0.uid", default=0) or 0
 75        if not uid:
 76            continue
 77        if can_delete_history(cid, uid):
 78            logger.warning(f"Delete message Chat={cid}, ID={mid}: {glom(resp, 'result.0.results.0')}")
 79            await query_d1(f"DELETE FROM '{tablename}' WHERE mid={mid};", db_name=HISTORY.D1_DATABASE, silent=True)
 80
 81
 82async def backup_chat_history_to_d1(
 83    client: Client,
 84    chat_id: str | int,
 85    hours: float = HISTORY.BACKUP_CHATS_HOURS,
 86    *,
 87    start_from: Literal["latest", "oldest"] = "latest",
 88    max_sync: float = float("inf"),
 89) -> None:
 90    """Backup chat history to D1 database.
 91
 92    If start_from is "oldest", find the minimum message id of this chat, then use this mid as `offset_id` to retrieve messages.
 93    """
 94    if not HISTORY.D1_ENABLE:
 95        return
 96    if await i_am_bot(client):
 97        return
 98    chat = await get_chat(client, to_int(chat_id))
 99    if chat.id == 0:  # chat is not accessible
100        return
101    chatinfo = await save_chatinfo_to_d1(client, parse_chat(chat))
102    if true(os.getenv(f"HISTORY_IGNORE_{chatinfo['cid']}")):
103        return
104    table_name = chatinfo["tablename"]
105    now = nowdt(TZ)
106    begin_dt = now - timedelta(hours=hours)
107    begin_time = begin_dt.strftime("%Y-%m-%d %H:%M:%S")
108    if start_from == "oldest":
109        sql = f'SELECT mid FROM "{table_name}" ORDER BY mid ASC LIMIT 1'
110        resp = await query_d1(sql, db_name=HISTORY.D1_DATABASE, silent=True)
111        offset_id = glom(resp, "result.0.results.0.mid", default=1)
112        saved_mids = {int(offset_id)}
113    else:
114        # find message ids in this time range
115        end_time = now.strftime("%Y-%m-%d %H:%M:%S")
116        sql = f'SELECT mid FROM "{table_name}" WHERE time >= "{begin_time}" AND time <= "{end_time}"'
117        resp = await query_d1(sql, db_name=HISTORY.D1_DATABASE, silent=True)
118        saved_mids = glom(resp, "result.0.results.*.mid", default=[])
119        saved_mids = {int(x) for x in saved_mids}
120        offset_id = 0  # retrieve from latest message
121    logger.info(f"Found {len(saved_mids)} messages of {table_name} in D1, Time >= {begin_time}")
122    concurrency = 100
123    num_sync = 0
124    tasks = []
125    real_cid = chatinfo["chandle"] or (int(chatinfo["cid"]) if chatinfo["ctype"] in ["BOT", "PRIVATE"] else int(f"-100{chatinfo['cid']}"))
126    async for message in client.get_chat_history(real_cid, max_id=offset_id):  # type: ignore
127        if not isinstance(message, Message) or message.empty or message.service or message.id in saved_mids:
128            continue
129        info = parse_msg(message, silent=True, use_cache=False)
130        if info["time"] < begin_time:
131            break
132        if num_sync >= max_sync:
133            break
134        if not fine_grained_check(info):
135            continue
136        num_sync += 1
137        records = {
138            "mid": info["mid"],
139            "mtype": info["mtype"],
140            "time": info["time"],
141            "fullname": info["full_name"],
142            "content": message.content,
143            "filename": info["file_name"],
144            "urls": "\n\n".join(info["entity_urls"]),
145            "reply": message.reply_to_message_id,
146            "mime": info["mime_type"],
147            "user": info["full_name"].replace(" ", ""),
148            "handle": info["handle"],
149            "uid": info["uid"],
150            "gid": info["media_group_id"],
151            "segmented": " ".join(cutter.cutword(info["text"])),
152        }
153        logger.trace(f"Syncing {table_name}: {info['mid']} - {info['time']}")
154        tasks.append(query_d1(**insert_d1(table_name, records, update_on_conflict="mid"), db_name=HISTORY.D1_DATABASE, silent=True))
155        if len(tasks) == concurrency:
156            res = await asyncio.gather(*tasks, return_exceptions=True)
157            num_success = sum(glom(res, "*.success"))
158            logger.success(f"Synced {num_success} messages to D1")
159            tasks = []
160
161    if tasks:
162        res = await asyncio.gather(*tasks, return_exceptions=True)
163        num_success = sum(glom(res, "*.success"))
164        logger.success(f"Synced {num_success} messages to D1")
165
166
async def upload_exported_history_to_d1(client: Client, path: str | Path | None = None) -> None:
    """Upload a Telegram JSON chat export to the D1 database.

    Messages whose id is already stored in the chat's table are skipped, as are
    non-`message` entries and entries with a zero timestamp.

    Args:
        client: pyrogram client, used to resolve the chat when it was never synced before.
        path: exported JSON file; defaults to ``DOWNLOAD_DIR/result.json``. A missing file is ignored.
    """
    if not HISTORY.D1_ENABLE:
        return
    path = Path(DOWNLOAD_DIR) / "result.json" if path is None else Path(path)
    if not path.is_file():
        return

    def parse_text(texts: str | list) -> str:
        """Flatten an exported `text` field: a plain string, or a list of strings / entity dicts."""
        if isinstance(texts, str):
            return texts
        return "".join(x if isinstance(x, str) else x.get("text", "") for x in texts)

    def parse_urls(entities: list) -> list[str]:
        """Collect URLs of `link` / `text_link` entities (the href when present, else the text)."""
        return [glom(x, Coalesce("href", "text")) for x in entities if x["type"] in {"link", "text_link"}]

    async def flush(batch: list) -> None:
        """Run the queued insert queries concurrently and log how many succeeded."""
        res = await asyncio.gather(*batch, return_exceptions=True)
        # failed queries come back as exception objects (return_exceptions=True); count them as failures
        num_success = sum(bool(glom(r, "success", default=False)) for r in res)
        logger.success(f"Synced {num_success} messages to D1")

    # the export is UTF-8 JSON; don't depend on the platform's default encoding
    with path.open("r", encoding="utf-8") as f:  # noqa: ASYNC230
        data = json.load(f)
    messages = data["messages"]
    logger.info(f"Found {len(messages)} messages in json file")

    # The exported history has no media_group_id, so infer one: if two consecutive messages
    # share `from_id` and `date_unixtime` and the later one is a photo/video (has a `photo`
    # or `thumbnail`), both are treated as one media group, identified by the first member's id.
    last_msg: dict = {}
    for msg in messages:
        same_origin = all(msg.get(key) == last_msg.get(key) for key in ["from_id", "date_unixtime"])
        if last_msg and same_origin and any(key in msg for key in ["photo", "thumbnail"]):
            group_id = glom(last_msg, Coalesce("media_group_id", "id"))
            last_msg["media_group_id"] = group_id
            msg["media_group_id"] = group_id
        last_msg = msg

    mtypes = {
        "audio_file": "audio",
        "voice_message": "voice",
        "video_message": "video",
        "video_file": "video",
    }
    chat_id = data["id"]
    chatinfo = await get_d1_chatinfo(chat_id)
    if not chatinfo:  # this chat is never synced
        chat = await get_chat(client, int(chat_id))
        chatinfo = await save_chatinfo_to_d1(client, parse_chat(chat))
    if not chatinfo:  # chat is not accessible: save_chatinfo_to_d1 returned {} for chat id 0
        return
    table_name = chatinfo["tablename"]
    # table names embed the chat title: quote as an identifier, escaping embedded double quotes
    quoted_table = '"' + table_name.replace('"', '""') + '"'
    # find all message_ids
    resp = await query_d1(f"SELECT mid FROM {quoted_table};", db_name=HISTORY.D1_DATABASE, silent=True)
    saved_ids = {int(x) for x in glom(resp, "result.0.results.*.mid", default=[])}
    is_channel = data["type"] in ["public_channel", "private_channel"]
    tz = ZoneInfo(TZ)
    concurrency = 100
    tasks = []
    for info in messages:
        if info["id"] in saved_ids or info["type"] != "message" or info["date_unixtime"] == "0":
            continue
        if "media_type" not in info:  # guess mtype
            if "photo" in info:
                info["media_type"] = "photo"
            if "video/" in (info.get("mime_type") or ""):
                info["media_type"] = "video_file"
        mtype = info.get("media_type", "text")
        content = parse_text(info.get("text", []))
        urls = parse_urls(info.get("text_entities", []))
        # fine-grained check requires keys ["cid", "mtype", "text", "entity_urls"];
        # entity_urls is a list of URLs, the same shape parse_msg produces
        if not fine_grained_check({"cid": chat_id, "mtype": mtype, "text": content, "entity_urls": urls}):
            continue
        dt = datetime.fromtimestamp(int(info["date_unixtime"]), tz=tz)
        sender_id = info["from_id"].removeprefix("user").removeprefix("channel")
        uid = int(sender_id)
        user = info.get("from") or sender_id
        if user == data["name"] and is_channel:  # user is not shown
            user = ""
            uid = 1

        records = {
            "mid": info["id"],
            "mtype": mtypes.get(mtype, mtype),
            "time": dt.strftime("%Y-%m-%d %H:%M:%S"),
            "fullname": user,
            "content": content,
            "filename": info.get("file_name", ""),
            "urls": "\n\n".join(urls),
            "reply": info.get("reply_to_message_id"),
            "mime": info.get("mime_type", ""),
            "user": user.replace(" ", ""),
            "handle": "",  # TODO: parse handle
            "uid": uid,
            "gid": info.get("media_group_id", 0),
            "segmented": " ".join(cutter.cutword(content)),
        }
        tasks.append(query_d1(**insert_d1(table_name, records, update_on_conflict="mid"), db_name=HISTORY.D1_DATABASE, silent=True))
        if len(tasks) >= concurrency:
            await flush(tasks)
            tasks = []

    if tasks:
        await flush(tasks)
272
273
async def get_d1_userinfo(uid: int, cid: int) -> dict:
    """Fetch the stored `userinfo` row for user ``uid`` in chat ``cid``.

    `create_d1_table` is called for `userinfo` (with its indexes) before querying.

    Returns:
        The first matching row (uid, full_name, handle, ...), or ``{}`` when there is none.
    """
    db = HISTORY.D1_DATABASE
    await create_d1_table("userinfo", USER_COLUMNS, db_name=db, idx_cols=USER_INDEXES, idx_prefix="idx_userinfo_", silent=True)
    sql = f"SELECT * FROM userinfo WHERE uid={uid} AND cid={cid};"
    resp = await query_d1(sql, db_name=db, silent=True)
    return glom(resp, "result.0.results.0", default={})
291
292
async def save_userinfo_to_d1(client: Client, minfo: dict) -> dict[str, str]:
    """Upsert the sender of a parsed message into table `userinfo`.

    The previously stored row is looked up (cache first, then D1). The row is
    rewritten, and the cache refreshed, only when the new record differs from it.

    Args:
        minfo (dict): parsed message info.

    Returns:
        uid, full_name, handle, tags (plus ctitle, name, cid, id) — or ``{}`` for the unknown user (uid 1).
    """
    uid = int(minfo["uid"])
    cid = int(slim_cid(minfo["cid"]))
    if uid == 1:  # default user (user is unknown)
        return {}
    cache_key = f"d1-user-{uid}-{cid}"
    cached = cache.get(cache_key)
    if not cached:
        cached = await get_d1_userinfo(uid, cid)
        cache.set(cache_key, cached, ttl=0)

    ctitle = minfo["ctitle"] or minfo["full_name"]
    # in private chats the opponent's name is used as the chat title
    if minfo["ctype"] in ["BOT", "PRIVATE"]:
        peer = await get_chat(client, minfo["cid"])
        if peer.id != 0:
            ctitle = parse_chat(peer)["ctitle"]

    full_name = minfo["full_name"]
    # NOTE(review): abs(uid - cid) is not unique across (uid, cid) pairs
    # (e.g. uid=5 and uid=15 in cid=10 both map to 5), so distinct users may share a row id.
    row_id = uid if uid == cid else abs(uid - cid)
    records = {
        "ctitle": ctitle,
        "full_name": full_name,
        "handle": minfo["handle"],
        "tags": cached.get("tags", ""),  # tags are only carried over from the stored row
        "name": full_name.replace(" ", ""),
        "uid": uid,
        "cid": cid,
        "id": row_id,
    }
    if cached == records:
        return records
    logger.info(f"Save user info: {records}")
    cache.set(cache_key, records, ttl=0)
    await query_d1(**insert_d1("userinfo", records, update_on_conflict="id"), db_name=HISTORY.D1_DATABASE, silent=True)
    return records
334
335
async def get_d1_chatinfo(cid: str | int) -> dict:
    """Look up a chat in table `chatinfo` by chat id or handle.

    `create_d1_table` is called for `chatinfo` before querying. The row matches
    when its `cid` equals ``slim_cid(cid)`` or its `chandle` equals ``cid``.

    Args:
        cid: chat id, or chat handle.

    Returns:
        The first matching row (cid, ctype, ctitle, chandle, tablename, tags), or ``{}`` when unknown.
    """
    # create table
    await create_d1_table("chatinfo", CHAT_COLUMNS, db_name=HISTORY.D1_DATABASE, silent=True)

    def literal(value: object) -> str:
        # SQL string literal: double any embedded single quote so an odd handle cannot break the query
        return "'" + str(value).replace("'", "''") + "'"

    sql = f"SELECT * FROM chatinfo WHERE cid={literal(slim_cid(cid))} OR chandle={literal(cid)};"
    resp = await query_d1(sql, db_name=HISTORY.D1_DATABASE, silent=True)
    return glom(resp, "result.0.results.0", default={})
346
347
async def save_chatinfo_to_d1(client: Client, minfo: dict) -> dict[str, str]:
    """Upsert a chat into table `chatinfo` and create its per-chat message table.

    The previously stored row is looked up (cache first, then D1). The row is
    rewritten, and the cache refreshed, only when the new record differs from it.

    Args:
        minfo (dict): parsed message info (or parsed chat info); reads cid, ctitle,
            full_name, ctype and chandle.

    Returns:
        cid, ctype, ctitle, chandle, tablename, tags — or ``{}`` when the chat id is 0.
    """
    cid = slim_cid(minfo["cid"])
    if str(cid) == "0":
        return {}
    cache_key = f"d1-chat-{cid}"
    cached = cache.get(cache_key)
    if not cached:
        cached = await get_d1_chatinfo(cid)
        cache.set(cache_key, cached, ttl=0)

    ctitle = minfo["ctitle"] or minfo["full_name"]
    # in private chats the opponent's name is used as the chat title
    if minfo["ctype"] in ["BOT", "PRIVATE"]:
        peer = await get_chat(client, minfo["cid"])
        if peer.id != 0:
            ctitle = parse_chat(peer)["ctitle"]

    # an already-stored table name wins over the current title, so a renamed chat keeps its table
    tablename = cached.get("tablename", "") or f"{cid}-{ctitle}"
    records = {
        "cid": int(cid),
        "ctype": minfo["ctype"],
        "ctitle": ctitle,
        "chandle": minfo["chandle"],
        "tablename": tablename,
        "tags": cached.get("tags", ""),
    }
    # create table for this chat
    await create_d1_table(
        tablename,
        MSG_COLUMNS,
        idx_cols=MSG_INDEXES,
        idx_prefix=f"idx_{cid}_",
        fts_on_col="mid",
        fts_name=cid,
        db_name=HISTORY.D1_DATABASE,
        silent=True,
    )
    if cached == records:
        return records
    logger.info(f"Save chat info: {records}")
    cache.set(cache_key, records, ttl=0)
    await query_d1(**insert_d1("chatinfo", records, update_on_conflict="cid"), db_name=HISTORY.D1_DATABASE, silent=True)
    return records
395    return records