Commit ffe943d

benny-dou <60535774+benny-dou@users.noreply.github.com>
2026-06-03 07:23:40
chore(database): use R2 to check cache
1 parent f89cc34
src/database/database.py
@@ -1,43 +0,0 @@
-#!/usr/bin/env python
-# -*- coding: utf-8 -*-
-
-from config import DB
-from database.kv import del_cf_kv, get_cf_kv, set_cf_kv
-from database.memory import del_memory_cache, get_memory_cache, set_memory_cache
-from database.r2 import del_cf_r2, get_cf_r2, set_cf_r2
-
-
-async def get_db(key: str) -> dict:
-    """Get data from database."""
-    if not key:
-        return {}
-    if kv := get_memory_cache(key):
-        return kv
-    if DB.ENGINE == "Cloudflare-KV":
-        return await get_cf_kv(key)
-    if DB.ENGINE == "Cloudflare-R2":
-        return await get_cf_r2(key)
-    return {}
-
-
-async def set_db(key: str, data: dict, ttl: int | None = None, metadata: dict | None = None) -> bool:
-    """Set data to database."""
-    success = False
-    if DB.ENGINE == "Cloudflare-KV":
-        success = await set_cf_kv(key, data, ttl=ttl)
-    if DB.ENGINE == "Cloudflare-R2":
-        success = await set_cf_r2(key, data, metadata, ttl=ttl)
-    if success:
-        set_memory_cache(key, data, ttl)
-    return success
-
-
-async def del_db(key: str):
-    """Delete data from database."""
-    if not key:
-        return
-    del_memory_cache(key)
-    if DB.ENGINE == "Cloudflare-KV":
-        await del_cf_kv(key)
-    if DB.ENGINE == "Cloudflare-R2":
-        await del_cf_r2(key)
src/database/memory.py
@@ -1,35 +0,0 @@
-#!/usr/bin/env python
-# -*- coding: utf-8 -*-
-from urllib.parse import quote_plus, unquote_plus
-
-from loguru import logger
-
-from config import cache
-
-
-def get_memory_cache(key: str, *, silent: bool = False) -> dict:
-    """Get from memory cache."""
-    key = quote_plus(unquote_plus(key))
-    if kv := cache.get(key):
-        if not silent:
-            logger.trace(f"GET DB from memory cache for {key}: {kv}")
-        return kv
-    return {}
-
-
-def set_memory_cache(key: str, data: dict | list | str, ttl: int | None = None, *, silent: bool = False) -> None:
-    """Set to memory cache."""
-    if ttl is None:
-        ttl = 600
-    key = quote_plus(unquote_plus(key))
-    cache.set(key, data, ttl=ttl)
-    if not silent:
-        logger.trace(f"SET DB to memory cache for {key}: {data}")
-
-
-def del_memory_cache(key: str, *, silent: bool = False):
-    """Delete from memory cache."""
-    key = quote_plus(unquote_plus(key))
-    cache.delete(key)
-    if not silent:
-        logger.trace(f"DEL DB from memory cache for {key}")
src/database/README.md
@@ -1,15 +0,0 @@
-# Databases
-
-All methods:
-
-```py
-from database.database import get_db, set_db, del_db
-from database.kv import get_cf_kv, set_cf_kv, del_cf_kv
-from database.r2 import list_cf_r2, get_cf_r2, set_cf_r2, del_cf_r2
-from database.memory import get_memory_cache, set_memory_cache, del_memory_cache
-from database.d1 import create_d1_database, create_d1_table, list_d1_tables, query_d1
-from database.alist import list_alist, download_alist, upload_alist, delete_alist
-from database.pastbin import upload_pastbin, delete_pastbin
-from database.uguu import upload_uguu
-from database.turso import turso_db_url, turso_create_table, turso_list_tables, turso_exec
-```
src/messages/database.py
@@ -8,8 +8,7 @@ from loguru import logger
 from pyrogram.client import Client
 from pyrogram.types import Message, ReplyParameters
 
-from config import DB
-from database.database import del_db, get_db, set_db
+from database.r2 import del_cf_r2, get_cf_r2, set_cf_r2
 from messages.parser import get_thread_id, parse_msg
 from messages.progress import modify_progress
 from messages.utils import sender_markdown_to_html
@@ -35,11 +34,11 @@ async def save_messages(messages: list[Message | None], key: str, metadata: dict
     if not metadata:
         metadata = {}
     if not messages:
-        logger.error(f"Skip save messages to {DB.ENGINE} due to empty message list")
+        logger.error("Skip save messages to R2 due to empty message list")
         return False
     valid_messages = [x for x in messages if isinstance(x, Message)]
     if len(valid_messages) != len(messages):
-        logger.warning(f"Skip save messages to {DB.ENGINE} due to invalid message type")
+        logger.warning("Skip save messages to R2 due to invalid message type")
         return False
     time_str = valid_messages[0].date.isoformat()
     metadata["time"] = time_str
@@ -79,9 +78,9 @@ async def save_messages(messages: list[Message | None], key: str, metadata: dict
             logger.trace(f"Saving document message {msg.id}")
             data.append({"type": "document"} | msg_extra)
             continue
-        logger.warning(f"Skip save message {msg.id} to {DB.ENGINE} due to unknown type: {msg}")
+        logger.warning(f"Skip save message {msg.id} to R2 due to unknown type: {msg}")
     if data:
-        return await set_db(key, metadata=metadata, data={"data": data})
+        return await set_cf_r2(key, metadata=metadata, data={"data": data})
     return False
 
 
@@ -136,14 +135,14 @@ async def copy_messages_from_db(
     reply_parameters = ReplyParameters(message_id=target_mid)
     tid = get_thread_id(message)
     if kv is None:
-        kv = await get_db(key)
+        kv = await get_cf_r2(key)
     if not kv.get("data"):
-        logger.error(f"Wrong {DB.ENGINE} data for key={key}: {kv}")
+        logger.error(f"Wrong R2 data for key={key}: {kv}")
         return []
     data: list[dict] = kv.get("data", [])
     if isinstance(data, str):
         data = json.loads(data)
-    logger.debug(f"Sending {len(data)} messages from {DB.ENGINE}: {data}")
+    logger.debug(f"Sending {len(data)} messages from R2: {data}")
     results: list[Message] = []
     try:
         for idx, item in enumerate(sorted(data, key=custom_sort)):
@@ -190,12 +189,12 @@ async def copy_messages_from_db(
             else:
                 logger.warning(f"Unknown message type: {item}")
     except Exception as e:
-        logger.error(f"Failed to copy messages for key={key} from {DB.ENGINE}: {e}")
-        await del_db(key)
+        logger.error(f"Failed to copy messages for key={key} from R2: {e}")
+        await del_cf_r2(key)
         return []
     if all(isinstance(x, Message) for x in results):
-        logger.success(f"Successfully copied {len(results)} messages for key={key} from {DB.ENGINE}")
+        logger.success(f"Successfully copied {len(results)} messages for key={key} from R2")
         await modify_progress(del_status=True, **kwargs)
         return results
-    await del_db(key)
+    await del_cf_r2(key)
     return []
src/messages/main.py
@@ -13,7 +13,7 @@ from asr.voice_recognition import voice_to_text
 from bridge.ocr import send_to_ocr_bridge
 from config import FAVORITE, PREFIX, PROXY
 from danmu.entrypoint import query_danmu
-from database.database import del_db
+from database.r2 import del_cf_r2
 from history.query import query_chat_history
 from messages.details import show_msg_info
 from messages.help import social_media_help
@@ -256,7 +256,7 @@ async def preview_social_media(
             logger.success(f"Matched: {matched}")
         kwargs |= matched
         if startswith_prefix(this_texts, prefix="/retry"):
-            await del_db(matched["db_key"])
+            await del_cf_r2(matched["db_key"])
 
         if douyin and matched["platform"] == "douyin":
             return await preview_douyin(client, message, **kwargs)
src/preview/bilibili.py
@@ -17,9 +17,9 @@ from loguru import logger
 from pyrogram.client import Client
 from pyrogram.types import Message
 
-from config import DB, READING_SPEED, TZ, cache
+from config import READING_SPEED, TZ, cache
 from cookies import bilibili_cookie_dict
-from database.database import get_db
+from database.r2 import get_cf_r2
 from messages.database import copy_messages_from_db, save_messages
 from messages.progress import modify_progress
 from messages.sender import send2tg
@@ -47,11 +47,11 @@ async def preview_bilibili(
         db_key (str, optional): The cache key.
         post_id (str, optional): bilibili post ID
     """
-    if kv := await get_db(db_key):
-        logger.debug(f"Bilibili preview {DB.ENGINE} cache hit for key={url}")
+    if kv := await get_cf_r2(db_key):
+        logger.debug(f"Bilibili preview cache hit for key={url}")
         if await copy_messages_from_db(client, message, key=url, kv=kv, **kwargs):
             return
-        logger.warning(f"❌从{DB.ENGINE}缓存中转发失败, 尝试重新解析...")
+        logger.warning("❌从缓存中转发失败, 尝试重新解析...")
     if kwargs.get("show_progress") and "progress" not in kwargs:
         res = await send2tg(client, message, texts=f"🔗正在解析B站链接\n{url}", **kwargs)
         kwargs["progress"] = res[0]
src/preview/douyin.py
@@ -14,8 +14,8 @@ from pyrogram.client import Client
 from pyrogram.types import Message
 
 from bridge.social import send_to_social_media_bridge
-from config import API, DB, DOWNLOAD_DIR, PROVIDER, PROXY, TOKEN, TZ
-from database.database import get_db
+from config import API, DOWNLOAD_DIR, PROVIDER, PROXY, TOKEN, TZ
+from database.r2 import get_cf_r2
 from messages.database import copy_messages_from_db, save_messages
 from messages.progress import modify_progress
 from messages.sender import send2tg
@@ -51,11 +51,11 @@ async def preview_douyin(
         douyin_provider (str, optional): The douyin extractor: "direct", "free", "tikhub", "bridge", or combined strings.
         douyin_comments_provider (str, optional): The douyin comments extractor: "free", "tikhub" or "free-tikhub".
     """
-    if kv := await get_db(db_key):
-        logger.debug(f"{platform} preview {DB.ENGINE} cache hit for key={db_key}")
+    if kv := await get_cf_r2(db_key):
+        logger.debug(f"{platform} preview cache hit for key={db_key}")
         if await copy_messages_from_db(client, message, key=db_key, kv=kv, **kwargs):
             return
-        logger.warning(f"❌从{DB.ENGINE}缓存中转发失败, 尝试重新解析...")
+        logger.warning("❌从缓存中转发失败, 尝试重新解析...")
     if kwargs.get("show_progress") and "progress" not in kwargs:
         res = await send2tg(client, message, texts=f"🔗正在解析抖音链接\n{url}", **kwargs)
         kwargs["progress"] = res[0]
src/preview/instagram.py
@@ -9,8 +9,8 @@ from pyrogram.client import Client
 from pyrogram.types import Message
 
 from bridge.social import send_to_social_media_bridge
-from config import API, DB, DOWNLOAD_DIR, PROVIDER, PROXY, TELEGRAM_UA, TOKEN
-from database.database import get_db
+from config import API, DOWNLOAD_DIR, PROVIDER, PROXY, TELEGRAM_UA, TOKEN
+from database.r2 import get_cf_r2
 from messages.database import copy_messages_from_db, save_messages
 from messages.progress import modify_progress
 from messages.sender import send2tg
@@ -47,11 +47,11 @@ async def preview_instagram(
         instagram_provider (str, optional): The instagram extractor: tikhub, ddinstagram, bridge
         instagram_comments (bool, optional): Add instagram comments. Defaults to True.
     """
-    if kv := await get_db(db_key):
-        logger.debug(f"Instagram preview {DB.ENGINE} cache hit for key={db_key}")
+    if kv := await get_cf_r2(db_key):
+        logger.debug(f"Instagram preview cache hit for key={db_key}")
         if await copy_messages_from_db(client, message, key=db_key, kv=kv, **kwargs):
             return
-        logger.warning(f"❌从{DB.ENGINE}缓存中转发失败, 尝试重新解析...")
+        logger.warning("❌从缓存中转发失败, 尝试重新解析...")
 
     if kwargs.get("show_progress") and "progress" not in kwargs:
         res = await send2tg(client, message, texts=f"🔗正在解析Instagram链接\n{url}", **kwargs)
src/preview/reddit.py
@@ -10,8 +10,8 @@ from loguru import logger
 from pyrogram.client import Client
 from pyrogram.types import Message
 
-from config import DB, PROXY, TZ
-from database.database import get_db
+from config import PROXY, TZ
+from database.r2 import get_cf_r2
 from messages.database import copy_messages_from_db, save_messages
 from messages.progress import modify_progress
 from messages.sender import send2tg
@@ -30,11 +30,11 @@ async def preview_reddit(client: Client, message: Message, url: str = "", db_key
         url (str, optional): Reddit link
         db_key (str, optional): The cache key.
     """
-    if kv := await get_db(db_key):
-        logger.debug(f"Reddit preview {DB.ENGINE} cache hit for key={db_key}")
+    if kv := await get_cf_r2(db_key):
+        logger.debug(f"Reddit preview cache hit for key={db_key}")
         if await copy_messages_from_db(client, message, key=db_key, kv=kv, **kwargs):
             return
-        logger.warning(f"❌从{DB.ENGINE}缓存中转发失败, 尝试重新解析...")
+        logger.warning("❌从缓存中转发失败, 尝试重新解析...")
     if kwargs.get("show_progress") and "progress" not in kwargs:
         res = await send2tg(client, message, texts=f"🔗正在解析Reddit链接\n{url}", **kwargs)
         kwargs["progress"] = res[0]
src/preview/twitter.py
@@ -11,8 +11,8 @@ from pyrogram.client import Client
 from pyrogram.types import Message
 
 from bridge.social import send_to_social_media_bridge
-from config import API, DB, PROVIDER, PROXY, TELEGRAM_UA, TOKEN, TZ, cache
-from database.database import get_db
+from config import API, PROVIDER, PROXY, TELEGRAM_UA, TOKEN, TZ, cache
+from database.r2 import get_cf_r2
 from messages.database import copy_messages_from_db, save_messages
 from messages.progress import modify_progress
 from messages.sender import send2tg
@@ -51,11 +51,11 @@ async def preview_twitter(
         twitter_provider (str): The extractor to use: fxtwitter or tikhub.
         twitter_comments (bool, optional): Add twitter comments. Defaults to True
     """
-    if kv := await get_db(db_key):
-        logger.debug(f"Twitter preview {DB.ENGINE} cache hit for key={db_key}")
+    if kv := await get_cf_r2(db_key):
+        logger.debug(f"Twitter preview cache hit for key={db_key}")
         if await copy_messages_from_db(client, message, key=db_key, kv=kv, **kwargs):
             return
-        logger.warning(f"❌从{DB.ENGINE}缓存中转发失败, 尝试重新解析...")
+        logger.warning("❌从缓存中转发失败, 尝试重新解析...")
     if kwargs.get("show_progress") and "progress" not in kwargs:
         res = await send2tg(client, message, texts=f"🔗正在解析推特链接\n{url}", **kwargs)
         kwargs["progress"] = res[0]
src/preview/wechat.py
@@ -8,8 +8,8 @@ from loguru import logger
 from pyrogram.client import Client
 from pyrogram.types import Message
 
-from config import API, CAPTION_LENGTH, DB, DOWNLOAD_DIR, PROXY, TEXT_LENGTH, TOKEN
-from database.database import get_db
+from config import API, CAPTION_LENGTH, DOWNLOAD_DIR, PROXY, TEXT_LENGTH, TOKEN
+from database.r2 import get_cf_r2
 from messages.database import copy_messages_from_db, save_messages
 from messages.progress import modify_progress
 from messages.sender import send2tg
@@ -28,11 +28,11 @@ async def preview_wechat(client: Client, message: Message, url: str = "", db_key
         url (str, optional): wechat link
         db_key (str, optional): The cache key.
     """
-    if kv := await get_db(db_key):
-        logger.debug(f"WeChat preview {DB.ENGINE} cache hit for key={db_key}")
+    if kv := await get_cf_r2(db_key):
+        logger.debug(f"WeChat preview cache hit for key={db_key}")
         if await copy_messages_from_db(client, message, key=db_key, kv=kv, **kwargs):
             return
-        logger.warning(f"❌从{DB.ENGINE}缓存中转发失败, 尝试重新解析...")
+        logger.warning("❌从缓存中转发失败, 尝试重新解析...")
     if kwargs.get("show_progress") and "progress" not in kwargs:
         res = await send2tg(client, message, texts=f"🔗正在解析微信链接\n{url}", **kwargs)
         kwargs["progress"] = res[0]
src/preview/weibo.py
@@ -14,9 +14,9 @@ from pyrogram.client import Client
 from pyrogram.types import Message
 
 from bridge.social import send_to_social_media_bridge
-from config import API, DB, DOWNLOAD_DIR, PROVIDER, PROXY, TELEGRAM_UA, TOKEN, TZ, cache
+from config import API, DOWNLOAD_DIR, PROVIDER, PROXY, TELEGRAM_UA, TOKEN, TZ, cache
 from cookies import get_weibo_cookies
-from database.database import get_db
+from database.r2 import get_cf_r2
 from messages.database import copy_messages_from_db, save_messages
 from messages.progress import modify_progress
 from messages.sender import send2tg
@@ -58,11 +58,11 @@ async def preview_weibo(
 
     real_post_id = real_weibo_post_id(post_id)
     db_key = db_key.replace(post_id, real_post_id)
-    if kv := await get_db(db_key):
-        logger.debug(f"Weibo preview {DB.ENGINE} cache hit for key={url}")
+    if kv := await get_cf_r2(db_key):
+        logger.debug(f"Weibo preview cache hit for key={url}")
         if await copy_messages_from_db(client, message, key=url, kv=kv, **kwargs):
             return
-        logger.warning(f"❌从{DB.ENGINE}缓存中转发失败, 尝试重新解析...")
+        logger.warning("❌从缓存中转发失败, 尝试重新解析...")
     if kwargs.get("show_progress") and "progress" not in kwargs:
         res = await send2tg(client, message, texts=f"🔗正在解析微博链接\n{url}", **kwargs)
         kwargs["progress"] = res[0]
src/preview/xiaohongshu.py
@@ -11,8 +11,8 @@ from pyrogram.client import Client
 from pyrogram.types import Message
 
 from bridge.social import send_to_social_media_bridge
-from config import DB, PROVIDER, PROXY, TZ
-from database.database import get_db
+from config import PROVIDER, PROXY, TZ
+from database.r2 import get_cf_r2
 from messages.database import copy_messages_from_db, save_messages
 from messages.progress import modify_progress
 from messages.sender import send2tg
@@ -50,11 +50,11 @@ async def preview_xhs(
         is_xhs_link (bool, optional): Whether the link is a share link from APP.
         xhs_provider (str, optional): The xiaohongshu provider.
     """
-    if kv := await get_db(db_key):
-        logger.debug(f"Xiaohongshu preview {DB.ENGINE} cache hit for key={db_key}")
+    if kv := await get_cf_r2(db_key):
+        logger.debug(f"Xiaohongshu preview cache hit for key={db_key}")
         if await copy_messages_from_db(client, message, key=db_key, kv=kv, **kwargs):
             return
-        logger.warning(f"❌从{DB.ENGINE}缓存中转发失败, 尝试重新解析...")
+        logger.warning("❌从缓存中转发失败, 尝试重新解析...")
     if kwargs.get("show_progress") and "progress" not in kwargs:
         res = await send2tg(client, message, texts=f"🔗正在解析小红书链接\n{url}", **kwargs)
         kwargs["progress"] = res[0]
src/ytdlp/main.py
@@ -11,8 +11,8 @@ from pyrogram.client import Client
 from pyrogram.types import Message
 
 from ai.summary import summarize
-from config import AI, ASR, CAPTION_LENGTH, DB, MAX_FILE_BYTES, YTDLP_RE_ENCODING_MAX_FILE_BYTES
-from database.database import get_db
+from config import AI, ASR, CAPTION_LENGTH, MAX_FILE_BYTES, YTDLP_RE_ENCODING_MAX_FILE_BYTES
+from database.r2 import get_cf_r2
 from messages.database import copy_messages_from_db, save_messages
 from messages.preprocess import preprocess_media
 from messages.progress import modify_progress, telegram_uploading
@@ -83,12 +83,12 @@ async def preview_ytdlp(
     logger.trace(f"{url=} {kwargs=}")
     # try cache
     db_key = url
-    if true(use_db) and (kv := await get_db(db_key)):
-        logger.debug(f"YT-DLP preview {DB.ENGINE} cache hit for key={db_key}")
+    if true(use_db) and (kv := await get_cf_r2(db_key)):
+        logger.debug(f"YT-DLP preview cache hit for key={db_key}")
         kwargs |= {"copy_video_msg": kwargs.get("copy_video_msg", ytdlp_send_video), "copy_audio_msg": kwargs.get("copy_audio_msg", ytdlp_send_audio)}
         if db_msgs := await copy_messages_from_db(client, message, key=db_key, kv=kv, **kwargs):
             return db_msgs
-        logger.warning(f"❌从{DB.ENGINE}缓存中转发失败, 尝试重新解析...")
+        logger.warning("❌从缓存中转发失败, 尝试重新解析...")
 
     if kwargs.get("show_progress") and not kwargs.get("progress"):
         res = await send2tg(client, message, texts=f"🔗正在解析链接\n{url}", **kwargs)
src/config.py
@@ -200,7 +200,6 @@ class TID:  # see more TID usecase in `src/permission.py`
 
 
 class DB:
-    ENGINE = os.getenv("DB_ENGINE", "Cloudflare-R2")
     CF_KV_ENABLED = os.getenv("CF_KV_ENABLED", "1").lower() in ["1", "y", "yes", "t", "true", "on"]
     CF_ACCOUNT_ID = os.getenv("CF_ACCOUNT_ID", "")
     CF_API_TOKEN = os.getenv("CF_API_TOKEN", "")