Commit 4b22219

benny-dou <60535774+benny-dou@users.noreply.github.com>
2025-06-04 11:58:13
fix(history): persist table name based on chat id
1 parent 5d4729c
Changed files (1)
src
history
src/history/sync.py
@@ -6,14 +6,14 @@ from datetime import datetime
 from pathlib import Path
 from zoneinfo import ZoneInfo
 
-from glom import glom
+from glom import flatten, glom
 from loguru import logger
 from pyrogram.client import Client
 from pyrogram.errors import PeerIdInvalid
 from pyrogram.types import Chat, Message
 
 from config import DOWNLOAD_DIR, HISTORY, TZ, cache
-from database import create_d1_table, query_d1
+from database import create_d1_database, create_d1_table, query_d1
 from messages.parser import parse_msg
 from utils import i_am_bot
 
@@ -26,15 +26,7 @@ async def save_history_to_d1(client: Client, message: Message) -> None:
     if not HISTORY.D1_ENABLE:
         return
     info = parse_msg(message, silent=True)
-    cid = str(info["cid"]).removeprefix("-100")
-    chat_title = info["ctitle"] or info["full_name"]
-    if info["ctype"] in ["BOT", "PRIVATE"]:  # for private chats, we use peer side name as chat_title
-        chat = await chat_info(client, info["cid"])
-        first_name = chat.first_name or ""
-        last_name = chat.last_name or ""
-        chat_title = first_name + last_name
-    table_name = f"{cid}-{chat_title}".replace(" ", "")
-    await create_d1_table(table_name, DB_COLUMNS, HISTORY.D1_DATABASE, silent=True)
+    table_name = await get_table_name(client, info["cid"])
     records = {
         "mid": info["mid"],
         "day": info["time"].split(" ")[0],
@@ -60,10 +52,7 @@ async def sync_chat_history_to_d1(client: Client, chat_id: str | int) -> None:
     if await i_am_bot(client):
         return
     chat = await chat_info(client, int(chat_id))
-    cid = str(chat.id).removeprefix("-100")
-    chat_name = chat.title or chat.full_name
-    table_name = f"{cid}-{chat_name}".replace(" ", "")
-    await create_d1_table(table_name, DB_COLUMNS, HISTORY.D1_DATABASE, silent=True)
+    table_name = await get_table_name(client, chat_id)
     # find last message id
     sql = f'SELECT mid FROM "{table_name}" ORDER BY mid DESC LIMIT 1;'
     resp = await query_d1(sql, db_name=HISTORY.D1_DATABASE, silent=True)
@@ -90,7 +79,7 @@ async def sync_chat_history_to_d1(client: Client, chat_id: str | int) -> None:
             "filename": info["file_name"],
             "mime": info["mime_type"],
         }
-        logger.trace(f"Syncing message {chat_name}: {info['mid']}")
+        logger.trace(f"Syncing message {table_name}: {info['mid']}")
         keys = ", ".join(records)
         values = ", ".join(["?" for _ in range(len(records))])
         updates = ", ".join([f"{k} = EXCLUDED.{k}" for k in records if k != "mid"])
@@ -129,14 +118,19 @@ async def sync_export_history_to_d1(client: Client, path: str | Path | None = No
 
     with path.open("r") as f:  # noqa: ASYNC230
         data = json.load(f)
-    mtypes = {"text": "text", "animation": "animation", "audio_file": "audio", "sticker": "sticker", "voice_message": "voice", "video_file": "video"}
-
-    cid = str(data["id"]).removeprefix("-100")
-    chat_name = data["name"]
-    table_name = f"{cid}-{chat_name}".replace(" ", "")
-    await create_d1_table(table_name, DB_COLUMNS, HISTORY.D1_DATABASE, silent=True)
+    mtypes = {
+        "text": "text",
+        "photo": "photo",
+        "animation": "animation",
+        "audio_file": "audio",
+        "sticker": "sticker",
+        "voice_message": "voice",
+        "video_message": "video",
+        "video_file": "video",
+    }
+    table_name = await get_table_name(client, data["id"])
     # find all message_ids
-    sql = f'SELECT mid FROM "{table_name}";'
+    sql = f'SELECT mid FROM "{table_name}" ORDER BY mid;'
     resp = await query_d1(sql, db_name=HISTORY.D1_DATABASE, silent=True)
     saved_ids = glom(resp, "result.0.results.*.mid", default=[])
     concurrency = 200
@@ -148,6 +142,12 @@ async def sync_export_history_to_d1(client: Client, path: str | Path | None = No
             continue
         if info["date_unixtime"] == "0":
             continue
+        if "media_type" not in info:  # guess mtype
+            if "photo" in info:
+                info["media_type"] = "photo"
+            if "video/" in info.get("mime_type", ""):
+                info["media_type"] = "video_file"
+
         dt = datetime.fromtimestamp(int(info["date_unixtime"]), tz=ZoneInfo(TZ))
         uid = int(info["from_id"].removeprefix("user").removeprefix("channel"))
         user = info["from"]
@@ -168,7 +168,7 @@ async def sync_export_history_to_d1(client: Client, path: str | Path | None = No
             "filename": info.get("file_name", ""),
             "mime": info.get("mime_type", ""),
         }
-        logger.trace(f"Syncing message {chat_name}: {info['id']}")
+        logger.trace(f"Syncing message {table_name}: {info['id']}")
         keys = ", ".join(records)
         values = ", ".join(["?" for _ in range(len(records))])
         updates = ", ".join([f"{k} = EXCLUDED.{k}" for k in records if k != "mid"])
@@ -189,16 +189,45 @@ async def sync_export_history_to_d1(client: Client, path: str | Path | None = No
         tasks = []
 
 
-async def chat_info(client: Client, uid: int) -> Chat:
-    if cache.get(f"chat-info-{uid}"):
-        return cache.get(f"chat-info-{uid}")
-    if uid == 1:
+async def chat_info(client: Client, chat_id: int) -> Chat:
+    if cache.get(f"chat-info-{chat_id}"):
+        return cache.get(f"chat-info-{chat_id}")
+    if chat_id == 1:
         return Chat(id=1)
     try:
-        chat = await client.get_chat(int(uid))
+        logger.debug(f"Getting chat info for {chat_id}")
+        chat = await client.get_chat(int(chat_id))
     except PeerIdInvalid:
-        return await chat_info(client, int(f"-100{uid}"))
+        return await chat_info(client, int(f"-100{chat_id}"))
     except Exception:
         chat = Chat(id=1)
-    cache.set(f"chat-info-{uid}", chat, ttl=3600)  # cache for 1 hour
+    cache.set(f"chat-info-{chat_id}", chat, ttl=3600)  # cache for 1 hour
     return chat
+
+
+async def get_table_name(client: Client, chat_id: str | int) -> str:
+    """Get D1 table name by chat id."""
+    if cache.get(f"tablename-{chat_id}"):
+        return cache.get(f"tablename-{chat_id}")
+    # get a default table name
+    chat = await chat_info(client, int(chat_id))
+    first_name = chat.first_name or ""
+    last_name = chat.last_name or ""
+    full_name = first_name + last_name
+    chat_title = full_name or chat.title or ""
+    slim_cid = str(chat_id).removeprefix("-100")
+    default_name = f"{slim_cid}-{chat_title}".replace(" ", "")
+
+    # get D1 tables
+    database_id = await create_d1_database(name=HISTORY.D1_DATABASE)
+    if not database_id:
+        await create_d1_table(default_name, DB_COLUMNS, HISTORY.D1_DATABASE)
+        return default_name
+    sql = "SELECT name FROM sqlite_master WHERE type='table';"
+    resp = await query_d1(sql, database_id, silent=True)
+    table_names = flatten(glom(resp, "result.*.results.*.name", default=[]))
+    table_mapping = {x.split("-")[0]: x for x in table_names}
+    table_name = table_mapping.get(str(slim_cid), default_name)
+    cache.set(f"tablename-{chat_id}", table_name, ttl=0)
+    await create_d1_table(table_name, DB_COLUMNS, HISTORY.D1_DATABASE)
+    return table_name