Commit c52a618

benny-dou <60535774+benny-dou@users.noreply.github.com>
2025-05-28 10:14:06
feat(database): add Cloudflare D1 support
1 parent cc9db51
Changed files (4)
src/others/danmu.py
@@ -1,5 +1,6 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
+import asyncio
 import re
 from collections import defaultdict
 from datetime import UTC, datetime
@@ -7,17 +8,19 @@ from decimal import Decimal
 from io import BytesIO
 from zoneinfo import ZoneInfo
 
+from glom import flatten, glom
 from loguru import logger
 from pyrogram.client import Client
 from pyrogram.types import Message
 
-from config import DANMU, PREFIX, TEXT_LENGTH, TZ
+from config import DANMU, PREFIX, TEXT_LENGTH, TZ, cache
+from database import create_cf_d1_table, query_cf_d1
 from messages.parser import parse_msg
 from messages.progress import modify_progress
 from messages.sender import send2tg
 from messages.utils import blockquote, equal_prefix, startswith_prefix
 from networking import hx_req
-from utils import number
+from utils import nowdt, number
 
 HELP = f"""📖**查询弹幕记录**
 使用说明:
@@ -99,7 +102,7 @@ async def query_danmu(client: Client, message: Message, *, full_history: bool =
     logger.debug(f"Query: {payload}")
     status_msg = (await send2tg(client, message, texts=caption, **kwargs))[0]
     kwargs["progress"] = status_msg
-    resp = await hx_req(DANMU.BASR_URL, "POST", data=payload, check_kv={"code": 0}, check_keys=["count", "data"], silent=True)
+    resp = await hx_req(DANMU.BASR_URL, "POST", data=payload, proxy=DANMU.PROXY, check_kv={"code": 0}, check_keys=["count", "data"], silent=True)
     count = resp["count"]
     if count == 0:
         await modify_progress(message=status_msg, text=caption + "\n⚠️未找到匹配弹幕", force_update=True, **kwargs)
@@ -116,7 +119,7 @@ async def query_danmu(client: Client, message: Message, *, full_history: bool =
             page += 1
             payload["page"] = page
             logger.debug(f"Query: {payload}")
-            resp = await hx_req(DANMU.BASR_URL, "POST", data=payload, check_kv={"code": 0}, check_keys=["count", "data"], silent=True)
+            resp = await hx_req(DANMU.BASR_URL, "POST", data=payload, proxy=DANMU.PROXY, check_kv={"code": 0}, check_keys=["count", "data"], silent=True)
             parsed = parse_danmu(resp["data"], super_chats, show_name=show_name)
             danmu += parsed["danmu"]
             processed += parsed["num_messages"]
@@ -148,3 +151,55 @@ def parse_danmu(data: list[dict], super_chats: defaultdict, *, show_name: bool =
                 super_chats[currency] += Decimal(matched.group(2))
         msg += f"\n{dt:%m-%d %H:%M:%S}|{danmu['authorName']}{sc_amount}: {danmu['message']}" if show_name else f"\n{dt:%m-%d %H:%M:%S}{sc_amount}: {danmu['message']}"
     return {"danmu": msg.strip(), "num_messages": len(data)}
+
+
+@cache.memoize(ttl=3600)
+async def sync_danmu_to_d1() -> None:
+    """Deprecated, D1 only allow 5M reads, 100K writes."""
+    # ruff: noqa: S608
+    concurrency = 200
+    now = nowdt(TZ)
+
+    async def batch_sync(danmu_list: list[dict], saved_items: list[str]) -> int:
+        sc_pattern = re.compile(r"^([A-Z]{2,}) (\d+(?:\.\d+)?)$")
+        db_columns = "id INTEGER PRIMARY KEY, time TEXT, uid INTEGER, user TEXT, text TEXT, sc_amt REAL NULL, sc_ccy TEXT NULL"
+        tasks = []
+        for danmu in danmu_list:
+            dt = datetime.fromtimestamp(danmu["timestamp"] / 1000000, tz=UTC).astimezone(ZoneInfo(TZ))
+            table_name = dt.year
+            if f"{dt:%Y-%m-%d %H:%M:%S}{danmu['authorId']}" in saved_items:
+                continue
+            await create_cf_d1_table(table_name, db_columns, DANMU.D1_DATABASE)
+            columns = ["time", "uid", "user", "text"]
+            params = [f"{dt:%Y-%m-%d %H:%M:%S}", danmu["authorId"], danmu["authorName"], danmu["message"]]
+            if (amt := danmu.get("scAmount")) and (matched := sc_pattern.fullmatch(amt)):
+                sc_ccy = matched.group(1)
+                sc_amt = matched.group(2)
+                columns.extend(["sc_amt", "sc_ccy"])
+                params.extend([float(sc_amt), sc_ccy])
+            sql = f'INSERT INTO "{table_name}" ({", ".join(columns)}) VALUES ({", ".join(["?" for _ in range(len(columns))])});'
+            tasks.append(query_cf_d1(sql, db_name=DANMU.D1_DATABASE, params=params))
+        chunks = [tasks[i : i + concurrency] for i in range(0, len(tasks), concurrency)]
+        processed = 0
+        for chunk in chunks:
+            async with asyncio.Semaphore(concurrency):
+                results = await asyncio.gather(*chunk, return_exceptions=True)
+                processed += len(results)
+                if not any(glom(results, "*.success")):
+                    logger.error(f"Sync danmu to d1 failed: {results}")
+                    break
+        return processed
+
+    sql = f'SELECT time, uid FROM "{now.year}";'
+    resp = await query_cf_d1(sql, db_name=DANMU.D1_DATABASE)
+    saved_items = [f"{x['time']}{x['uid']}" for x in flatten(glom(resp, "result.*.results", default=[]))]
+    page = 1
+    payload = {"page": page, "limit": DANMU.NUM_PER_QUERY, "liveDate": f"{now:%Y-%m-%d}", "message": "", "authorName": ""}
+    resp = await hx_req(DANMU.BASR_URL, "POST", data=payload, proxy=DANMU.PROXY, check_kv={"code": 0}, check_keys=["count", "data"], silent=True)
+    count = resp["count"]
+    processed = await batch_sync(resp["data"], saved_items)
+    while processed < count:
+        page += 1
+        payload["page"] = page
+        resp = await hx_req(DANMU.BASR_URL, "POST", data=payload, proxy=DANMU.PROXY, check_kv={"code": 0}, check_keys=["count", "data"], silent=True)
+        processed += await batch_sync(resp["data"], saved_items)
src/config.py
@@ -100,7 +100,9 @@ class API:
 
 class DANMU:
     BASR_URL = os.getenv("DANMU_BASR_URL", "")  # Custom API, No docs
+    PROXY = os.getenv("DANMU_PROXY", None)  # socks5://127.0.0.1:7890
     NUM_PER_QUERY = int(os.getenv("DANMU_NUM_PER_QUERY", "100"))  # Number of items per query
+    D1_DATABASE = os.getenv("DANMU_D1_DATABASE", "bennybot-danmu")
 
 
 class PROVIDER:  # default API provider
@@ -175,6 +177,7 @@ class DB:
     CF_R2_ACCESS_KEY_ID = os.getenv("CF_R2_ACCESS_KEY_ID", "")
     CF_R2_SECRET_ACCESS_KEY = os.getenv("CF_R2_SECRET_ACCESS_KEY", "")
     CF_R2_PUBLIC_URL = os.getenv("CF_R2_PUBLIC_URL", "")
+    CF_D1_ENABLED = os.getenv("CF_D1_ENABLED", "1").lower() in ["1", "y", "yes", "t", "true", "on"]
     ALIST_ENABLED = os.getenv("ALIST_ENABLED", "1").lower() in ["1", "y", "yes", "t", "true", "on"]
     ALIST_USERNAME = os.getenv("ALIST_USERNAME", "guest")
     ALIST_PASSWORD = os.getenv("ALIST_PASSWORD", "guest")
src/database.py
@@ -294,6 +294,52 @@ async def del_cf_r2(key: str):
     return
 
 
+@cache.memoize(ttl=0)
+async def create_cf_d1_database(name: str = "bennybot", primary_location_hint: str = "") -> str:
+    """Create D1 database and return DatabaseID."""
+    if not all([DB.CF_D1_ENABLED, DB.CF_ACCOUNT_ID, DB.CF_API_TOKEN]):
+        return ""
+    api = f"https://api.cloudflare.com/client/v4/accounts/{DB.CF_ACCOUNT_ID}/d1/database"
+    headers = {"authorization": f"Bearer {DB.CF_API_TOKEN}", "content-type": "application/json"}
+    payload = {"name": name}
+    # check if database exists
+    resp = await hx_req(api, method="GET", params=payload, headers=headers, check_kv={"success": True}, max_retry=0, silent=True)
+    if database_id := glom(resp, "result.0.uuid", default=""):
+        return database_id
+    if primary_location_hint:
+        payload |= {"primary_location_hint": primary_location_hint}
+    resp = await hx_req(api, method="POST", post_json=payload, headers=headers, check_kv={"success": True}, max_retry=0, silent=True)
+    return glom(resp, "result.uuid", default="")
+
+
+@cache.memoize(ttl=0)
+async def create_cf_d1_table(table_name: str | float, columns: str, db_name: str = "bennybot") -> None:
+    """Create D1 database and return DatabaseID."""
+    if not all([DB.CF_D1_ENABLED, DB.CF_ACCOUNT_ID, DB.CF_API_TOKEN]):
+        return
+    database_id = await create_cf_d1_database(db_name)
+    if not database_id:
+        return
+    sql = f'CREATE TABLE IF NOT EXISTS "{table_name}" ({columns});'
+    await query_cf_d1(sql, database_id)
+    logger.success(f"Create Table {table_name} in D1 database {db_name}")
+
+
+async def query_cf_d1(sql: str, db_id: str | None = None, db_name: str = "bennybot", params: list[str] | None = None) -> dict:
+    """Query D1."""
+    if not all([DB.CF_D1_ENABLED, DB.CF_ACCOUNT_ID, DB.CF_API_TOKEN]):
+        return {}
+    if db_id is None:
+        db_id = await create_cf_d1_database(db_name)
+    api = f"https://api.cloudflare.com/client/v4/accounts/{DB.CF_ACCOUNT_ID}/d1/database/{db_id}/query"
+    headers = {"authorization": f"Bearer {DB.CF_API_TOKEN}", "content-type": "application/json"}
+    payload = {"sql": sql}
+    if params is not None:
+        payload |= {"params": params}
+    logger.trace(f"Query CF-D1: {payload}")
+    return await hx_req(api, "POST", post_json=payload, headers=headers, check_kv={"success": True}, max_retry=0, silent=True)
+
+
 async def list_alist() -> list[dict]:
     """List from Alist."""
     if not DB.ALIST_ENABLED:
@@ -456,5 +502,12 @@ if __name__ == "__main__":
     # asyncio.run(delete_alist("测试.mp3"))
     # asyncio.run(download_alist("test.py"))
     # asyncio.run(set_cf_r2("test2", metadata={"finished": "1"}))
-    asyncio.run(set_cf_r2("test2", data={"finished": "1"}, ttl=60))
-    asyncio.run(get_cf_r2("test2"))
+    # asyncio.run(set_cf_r2("test2", data={"finished": "1"}, ttl=60))
+    # asyncio.run(get_cf_r2("test2"))
+    columns = "id INTEGER PRIMARY KEY, time TEXT, uid INTEGER, user TEXT, text TEXT, sc_amt REAL NULL, sc_ccy TEXT NULL"
+    asyncio.run(create_cf_d1_table("2025", columns))
+
+    sql = 'INSERT INTO "2025" (id, time, uid, user, text, sc_amt, sc_ccy) VALUES (?, ?, ?, ?, ?, ?, ?);'
+    params = [123, "2025-01-01", 456, "username", "hello", 15.5, "USD"]
+    resp = asyncio.run(query_cf_d1(sql, params=params))
+    print(resp)
src/networking.py
@@ -114,12 +114,12 @@ async def hx_req(
             return res
     except Exception as e:
         error = f"{type(e).__name__}[{retry + 1}/{max_retry + 1}]: Failed to request {url}, {e}"
+        with contextlib.suppress(Exception):
+            error += f"\n{response.json()}"
         if "res" in locals():
             error += f"\n{res}"
         elif "data" in locals():
             error += f"\n{data}"
-        elif "response" in locals():
-            error += f"\n{response}"
         logger.error(error)
         return await hx_req(url, method, headers=headers, cookies=cookies, params=params, data=data, post_json=post_json, proxy=proxy, follow_redirects=follow_redirects, check_keys=check_keys, check_kv=check_kv, timeout=timeout, retry=retry + 1, max_retry=max_retry, silent=silent, rformat=rformat, last_error=error)  # fmt: off