Commit 4b22219
Changed files (1)
src
history
src/history/sync.py
@@ -6,14 +6,14 @@ from datetime import datetime
from pathlib import Path
from zoneinfo import ZoneInfo
-from glom import glom
+from glom import flatten, glom
from loguru import logger
from pyrogram.client import Client
from pyrogram.errors import PeerIdInvalid
from pyrogram.types import Chat, Message
from config import DOWNLOAD_DIR, HISTORY, TZ, cache
-from database import create_d1_table, query_d1
+from database import create_d1_database, create_d1_table, query_d1
from messages.parser import parse_msg
from utils import i_am_bot
@@ -26,15 +26,7 @@ async def save_history_to_d1(client: Client, message: Message) -> None:
if not HISTORY.D1_ENABLE:
return
info = parse_msg(message, silent=True)
- cid = str(info["cid"]).removeprefix("-100")
- chat_title = info["ctitle"] or info["full_name"]
- if info["ctype"] in ["BOT", "PRIVATE"]: # for private chats, we use peer side name as chat_title
- chat = await chat_info(client, info["cid"])
- first_name = chat.first_name or ""
- last_name = chat.last_name or ""
- chat_title = first_name + last_name
- table_name = f"{cid}-{chat_title}".replace(" ", "")
- await create_d1_table(table_name, DB_COLUMNS, HISTORY.D1_DATABASE, silent=True)
+ table_name = await get_table_name(client, info["cid"])
records = {
"mid": info["mid"],
"day": info["time"].split(" ")[0],
@@ -60,10 +52,7 @@ async def sync_chat_history_to_d1(client: Client, chat_id: str | int) -> None:
if await i_am_bot(client):
return
chat = await chat_info(client, int(chat_id))
- cid = str(chat.id).removeprefix("-100")
- chat_name = chat.title or chat.full_name
- table_name = f"{cid}-{chat_name}".replace(" ", "")
- await create_d1_table(table_name, DB_COLUMNS, HISTORY.D1_DATABASE, silent=True)
+ table_name = await get_table_name(client, chat_id)
# find last message id
sql = f'SELECT mid FROM "{table_name}" ORDER BY mid DESC LIMIT 1;'
resp = await query_d1(sql, db_name=HISTORY.D1_DATABASE, silent=True)
@@ -90,7 +79,7 @@ async def sync_chat_history_to_d1(client: Client, chat_id: str | int) -> None:
"filename": info["file_name"],
"mime": info["mime_type"],
}
- logger.trace(f"Syncing message {chat_name}: {info['mid']}")
+ logger.trace(f"Syncing message {table_name}: {info['mid']}")
keys = ", ".join(records)
values = ", ".join(["?" for _ in range(len(records))])
updates = ", ".join([f"{k} = EXCLUDED.{k}" for k in records if k != "mid"])
@@ -129,14 +118,19 @@ async def sync_export_history_to_d1(client: Client, path: str | Path | None = No
with path.open("r") as f: # noqa: ASYNC230
data = json.load(f)
- mtypes = {"text": "text", "animation": "animation", "audio_file": "audio", "sticker": "sticker", "voice_message": "voice", "video_file": "video"}
-
- cid = str(data["id"]).removeprefix("-100")
- chat_name = data["name"]
- table_name = f"{cid}-{chat_name}".replace(" ", "")
- await create_d1_table(table_name, DB_COLUMNS, HISTORY.D1_DATABASE, silent=True)
+ mtypes = {
+ "text": "text",
+ "photo": "photo",
+ "animation": "animation",
+ "audio_file": "audio",
+ "sticker": "sticker",
+ "voice_message": "voice",
+ "video_message": "video",
+ "video_file": "video",
+ }
+ table_name = await get_table_name(client, data["id"])
# find all message_ids
- sql = f'SELECT mid FROM "{table_name}";'
+ sql = f'SELECT mid FROM "{table_name}" ORDER BY mid;'
resp = await query_d1(sql, db_name=HISTORY.D1_DATABASE, silent=True)
saved_ids = glom(resp, "result.0.results.*.mid", default=[])
concurrency = 200
@@ -148,6 +142,12 @@ async def sync_export_history_to_d1(client: Client, path: str | Path | None = No
continue
if info["date_unixtime"] == "0":
continue
+ if "media_type" not in info: # guess mtype
+ if "photo" in info:
+ info["media_type"] = "photo"
+ if "video/" in info.get("mime_type", ""):
+ info["media_type"] = "video_file"
+
dt = datetime.fromtimestamp(int(info["date_unixtime"]), tz=ZoneInfo(TZ))
uid = int(info["from_id"].removeprefix("user").removeprefix("channel"))
user = info["from"]
@@ -168,7 +168,7 @@ async def sync_export_history_to_d1(client: Client, path: str | Path | None = No
"filename": info.get("file_name", ""),
"mime": info.get("mime_type", ""),
}
- logger.trace(f"Syncing message {chat_name}: {info['id']}")
+ logger.trace(f"Syncing message {table_name}: {info['id']}")
keys = ", ".join(records)
values = ", ".join(["?" for _ in range(len(records))])
updates = ", ".join([f"{k} = EXCLUDED.{k}" for k in records if k != "mid"])
@@ -189,16 +189,45 @@ async def sync_export_history_to_d1(client: Client, path: str | Path | None = No
tasks = []
-async def chat_info(client: Client, uid: int) -> Chat:
- if cache.get(f"chat-info-{uid}"):
- return cache.get(f"chat-info-{uid}")
- if uid == 1:
+async def chat_info(client: Client, chat_id: int) -> Chat:
+ if cache.get(f"chat-info-{chat_id}"):
+ return cache.get(f"chat-info-{chat_id}")
+ if chat_id == 1:
return Chat(id=1)
try:
- chat = await client.get_chat(int(uid))
+ logger.debug(f"Getting chat info for {chat_id}")
+ chat = await client.get_chat(int(chat_id))
except PeerIdInvalid:
- return await chat_info(client, int(f"-100{uid}"))
+ return await chat_info(client, int(f"-100{chat_id}"))
except Exception:
chat = Chat(id=1)
- cache.set(f"chat-info-{uid}", chat, ttl=3600) # cache for 1 hour
+ cache.set(f"chat-info-{chat_id}", chat, ttl=3600) # cache for 1 hour
return chat
+
+
+async def get_table_name(client: Client, chat_id: str | int) -> str:
+ """Get D1 table name by chat id."""
+ if cache.get(f"tablename-{chat_id}"):
+ return cache.get(f"tablename-{chat_id}")
+ # get a default table name
+ chat = await chat_info(client, int(chat_id))
+ first_name = chat.first_name or ""
+ last_name = chat.last_name or ""
+ full_name = first_name + last_name
+ chat_title = full_name or chat.title or ""
+ slim_cid = str(chat_id).removeprefix("-100")
+ default_name = f"{slim_cid}-{chat_title}".replace(" ", "")
+
+ # get D1 tables
+ database_id = await create_d1_database(name=HISTORY.D1_DATABASE)
+ if not database_id:
+ await create_d1_table(default_name, DB_COLUMNS, HISTORY.D1_DATABASE)
+ return default_name
+ sql = "SELECT name FROM sqlite_master WHERE type='table';"
+ resp = await query_d1(sql, database_id, silent=True)
+ table_names = flatten(glom(resp, "result.*.results.*.name", default=[]))
+ table_mapping = {x.split("-")[0]: x for x in table_names}
+ table_name = table_mapping.get(str(slim_cid), default_name)
+ cache.set(f"tablename-{chat_id}", table_name, ttl=0)
+ await create_d1_table(table_name, DB_COLUMNS, HISTORY.D1_DATABASE)
+ return table_name