Commit c576d27
Changed files (32)
src
history
messages
subtitles
src/asr/ali_asr.py
@@ -14,7 +14,8 @@ from loguru import logger
from asr.utils import downsampe_audio
from config import ASR, DB, FILE_SERVER
-from database import delete_alist, upload_alist, upload_uguu
+from database.alist import delete_alist, upload_alist
+from database.uguu import upload_uguu
from networking import hx_req
src/asr/tecent_asr.py
@@ -12,7 +12,8 @@ from loguru import logger
from asr.utils import downsampe_audio, is_english_word
from config import ASR, FILE_SERVER
-from database import delete_alist, upload_alist, upload_uguu
+from database.alist import delete_alist, upload_alist
+from database.uguu import upload_uguu
from networking import hx_req
from utils import nowdt
src/danmu/r2.py
@@ -10,7 +10,7 @@ from loguru import logger
from config import DANMU, DOWNLOAD_DIR, TZ
from danmu.utils import live_date
-from database import get_cf_r2
+from database.r2 import get_cf_r2
from messages.progress import modify_progress
from others.emoji import CURRENCY
from utils import number
src/danmu/sync.py
@@ -11,7 +11,8 @@ from loguru import logger
from config import DANMU, TZ, cache
from danmu.utils import merge_json, simplify_json
-from database import create_d1_table, get_cf_r2, list_cf_r2, query_d1, set_cf_r2
+from database.d1 import create_d1_table, query_d1
+from database.r2 import get_cf_r2, list_cf_r2, set_cf_r2
from networking import hx_req
from utils import nowdt
src/database/alist.py
@@ -0,0 +1,87 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+import base64
+import io
+import shutil
+from pathlib import Path
+
+import anyio
+from glom import glom
+from loguru import logger
+
+from config import DB, DOWNLOAD_DIR
+from networking import hx_req
+
+
+async def list_alist() -> list[dict]:
+ """List from Alist."""
+ if not DB.ALIST_ENABLED:
+ return [{}]
+ api = DB.ALIST_SERVER.removesuffix("/") + "/api/fs/list"
+ logger.info(f"List Alist from: {api}")
+ res = await hx_req(api, method="POST", post_json={"path": "/"}, check_kv={"code": 200})
+ return glom(res, "data.content", default=[]) or []
+
+
+async def download_alist(fname: str, save_path: str | Path | None = None) -> str:
+ """Download file from Alist."""
+ if not DB.ALIST_ENABLED:
+ return ""
+ save_path = Path(DOWNLOAD_DIR) / fname if save_path is None else Path(save_path)
+ ext = Path(fname).suffix
+ # Headers DO NOT support Unicode characters
+ if any(ord(c) > 127 for c in fname): # has Non-ASCII
+ b64str = base64.urlsafe_b64encode(fname.encode("utf-8")).decode("ascii").rstrip("=")
+ fname = b64str + ext
+ url = DB.ALIST_SERVER.removesuffix("/") + "/d/" + DB.ALIST_BASR_PATH.strip("/") + f"/{fname.lstrip('/')}"
+ res = await hx_req(url, rformat="content", check_keys=["content"])
+ logger.info(f"Download {fname} to {save_path.name} from: {url}")
+ async with await anyio.open_file(save_path, "wb") as f:
+ await f.write(res["content"])
+ return save_path.as_posix()
+
+
+async def upload_alist(path: str | Path) -> str:
+ """Upload to Alist."""
+ if not DB.ALIST_ENABLED:
+ return ""
+ path = Path(path).expanduser().resolve()
+ if not path.is_file():
+ return ""
+ api = DB.ALIST_SERVER.removesuffix("/") + "/api/fs/form"
+ # Headers DO NOT support Unicode characters
+ if any(ord(c) > 127 for c in path.name): # has Non-ASCII
+ new_name = base64.urlsafe_b64encode(path.name.encode("utf-8")).decode("ascii").rstrip("=")
+ new_path = path.with_stem(new_name)
+ shutil.copy(path, new_path)
+ path = new_path
+ headers = {"File-Path": f"/{path.name}"}
+ logger.info(f"Upload {path.name} to: {api}")
+ async with await anyio.open_file(path, "rb") as f:
+ content = await f.read()
+ await hx_req(api, method="PUT", headers=headers, files={"file": (path.name, io.BytesIO(content))}, check_kv={"code": 200})
+ if "new_path" in locals():
+ new_path.unlink(missing_ok=True)
+ return DB.ALIST_SERVER.removesuffix("/") + "/d/" + DB.ALIST_BASR_PATH.strip("/") + f"/{path.name}"
+
+
+async def delete_alist(fname: str, *, ensure_ascii: bool = True) -> None:
+ """Delete from Alist."""
+ if not DB.ALIST_ENABLED:
+ return
+ # Get JWT Token
+ payload = {"username": DB.ALIST_USERNAME, "password": DB.ALIST_PASSWORD}
+ auth_api = DB.ALIST_SERVER.removesuffix("/") + "/api/auth/login"
+ res = await hx_req(auth_api, method="POST", post_json=payload, check_keys=["data.token"])
+ token = res["data"]["token"]
+
+ # Delete
+ api = DB.ALIST_SERVER.removesuffix("/") + "/api/fs/remove"
+ headers = {"Content-Type": "application/json", "Authorization": token}
+ if ensure_ascii and any(ord(c) > 127 for c in fname): # has Non-ASCII
+ b64str = base64.urlsafe_b64encode(fname.encode("utf-8")).decode("ascii").rstrip("=")
+ payload = {"names": [b64str + Path(fname).suffix]}
+ else:
+ payload = {"names": [fname]}
+ logger.info(f"Delete {fname} from: {api}")
+ res = await hx_req(api, method="POST", headers=headers, post_json=payload, check_kv={"code": 200})
src/database/d1.py
@@ -0,0 +1,86 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+from glom import flatten, glom
+from loguru import logger
+
+from config import DB, PROXY, cache
+from networking import hx_req
+
+
+@cache.memoize(ttl=0)
+async def create_d1_database(
+ name: str = "bennybot",
+ primary_location_hint: str = "",
+ account_id: str = DB.CF_ACCOUNT_ID,
+ api_token: str = DB.CF_API_TOKEN,
+ *,
+ enabled: bool = DB.CF_D1_ENABLED,
+ silent: bool = False,
+) -> str:
+ """Create D1 database and return DatabaseID."""
+ if not all([enabled, account_id, api_token]):
+ return ""
+ api = f"https://api.cloudflare.com/client/v4/accounts/{account_id}/d1/database"
+ headers = {"authorization": f"Bearer {api_token}", "content-type": "application/json"}
+ payload = {"name": name}
+ # check if database exists
+ resp = await hx_req(api, method="GET", params=payload, headers=headers, check_kv={"success": True}, proxy=PROXY.D1, max_retry=0, silent=silent)
+ if database_id := glom(resp, "result.0.uuid", default=""):
+ return database_id
+ if primary_location_hint:
+ payload |= {"primary_location_hint": primary_location_hint}
+ resp = await hx_req(api, method="POST", post_json=payload, headers=headers, check_kv={"success": True}, proxy=PROXY.D1, max_retry=0, silent=silent)
+ return glom(resp, "result.uuid", default="")
+
+
+@cache.memoize(ttl=0)
+async def create_d1_table(table_name: str | float, columns: str, db_name: str = "bennybot", *, silent: bool = False) -> None:
+    """Create a table in the D1 database if it does not already exist."""
+ database_id = await create_d1_database(db_name, silent=silent)
+ if not database_id:
+ return
+ tables = await list_d1_tables(db_name, silent=silent)
+ if table_name in tables:
+ return
+
+ sql = f'CREATE TABLE IF NOT EXISTS "{table_name}" ({columns});'
+ await query_d1(sql, database_id, silent=silent)
+ if not silent:
+ logger.success(f"Create Table {table_name} in D1 database {db_name}")
+
+
+@cache.memoize(ttl=600)
+async def list_d1_tables(db_name: str = "bennybot", *, silent: bool = False) -> list[str]:
+ """List D1 tables in a database."""
+ database_id = await create_d1_database(db_name, silent=silent)
+ if not database_id:
+ return []
+ sql = "SELECT name FROM sqlite_master WHERE type='table';"
+ resp = await query_d1(sql, database_id, silent=silent)
+ return flatten(glom(resp, "result.*.results.*.name", default=[]))
+
+
+async def query_d1(
+ sql: str,
+ db_id: str | None = None,
+ db_name: str = "bennybot",
+ params: list[str] | None = None,
+ account_id: str = DB.CF_ACCOUNT_ID,
+ api_token: str = DB.CF_API_TOKEN,
+ *,
+ enabled: bool = DB.CF_D1_ENABLED,
+ silent: bool = False,
+) -> dict:
+ """Query D1."""
+ if not all([enabled, account_id, api_token]):
+ return {}
+ if db_id is None:
+ db_id = await create_d1_database(db_name, silent=silent)
+ api = f"https://api.cloudflare.com/client/v4/accounts/{account_id}/d1/database/{db_id}/query"
+ headers = {"authorization": f"Bearer {api_token}", "content-type": "application/json"}
+ payload = {"sql": sql}
+ if params is not None:
+ payload |= {"params": params}
+ if not silent:
+ logger.trace(f"Query CF-D1: {payload}")
+ return await hx_req(api, "POST", post_json=payload, headers=headers, check_kv={"success": True}, proxy=PROXY.D1, max_retry=0, silent=silent)
src/database/database.py
@@ -0,0 +1,43 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+from config import DB
+from database.kv import del_cf_kv, get_cf_kv, set_cf_kv
+from database.memory import del_memory_cache, get_memory_cache, set_memory_cache
+from database.r2 import del_cf_r2, get_cf_r2, set_cf_r2
+
+
+async def get_db(key: str) -> dict:
+ """Get data from database."""
+ if not key:
+ return {}
+ if kv := get_memory_cache(key):
+ return kv
+ if DB.ENGINE == "Cloudflare-KV":
+ return await get_cf_kv(key)
+ if DB.ENGINE == "Cloudflare-R2":
+ return await get_cf_r2(key)
+ return {}
+
+
+async def set_db(key: str, data: dict, ttl: int | None = None, metadata: dict | None = None) -> bool:
+ """Set data to database."""
+ success = False
+ if DB.ENGINE == "Cloudflare-KV":
+ success = await set_cf_kv(key, data, ttl=ttl)
+ if DB.ENGINE == "Cloudflare-R2":
+ success = await set_cf_r2(key, data, metadata, ttl=ttl)
+ if success:
+ set_memory_cache(key, data, ttl)
+ return success
+
+
+async def del_db(key: str):
+ """Delete data from database."""
+ if not key:
+ return
+ del_memory_cache(key)
+ if DB.ENGINE == "Cloudflare-KV":
+ await del_cf_kv(key)
+ if DB.ENGINE == "Cloudflare-R2":
+ await del_cf_r2(key)
src/database/kv.py
@@ -0,0 +1,96 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+from urllib.parse import quote_plus, unquote_plus
+
+from httpx import AsyncClient, AsyncHTTPTransport
+from loguru import logger
+
+from config import DB
+from networking import hx_req
+
+
+async def get_cf_kv(
+ key: str,
+ account_id: str = DB.CF_ACCOUNT_ID,
+ api_token: str = DB.CF_API_TOKEN,
+ namespace_id: str = DB.CF_KV_NAMESPACE_ID,
+ *,
+ enabled: bool = DB.CF_D1_ENABLED,
+ silent: bool = False,
+) -> dict:
+ """Get from Cloudflare KV."""
+ if not all([enabled, account_id, namespace_id, api_token]):
+ return {}
+ key = quote_plus(unquote_plus(key))
+ api = f"https://api.cloudflare.com/client/v4/accounts/{account_id}/storage/kv/namespaces/{namespace_id}/values/{key}"
+ headers = {"authorization": f"Bearer {api_token}", "content-type": "application/json"}
+ resp = await hx_req(api, headers=headers, timeout=30, silent=True)
+ async with AsyncClient(http2=True, follow_redirects=True, transport=AsyncHTTPTransport(retries=3, http2=True)) as hx:
+ try:
+ resp = await hx.get(api, headers=headers, timeout=30)
+ if resp.status_code == 404 and not silent:
+ logger.trace(f"404 Not Found for CF-KV key={key}")
+ return {}
+ resp.raise_for_status()
+ if data := resp.json():
+ if not silent:
+ logger.success(f"GET CF-KV for {key}: {data}")
+ return data
+ except Exception as e:
+ logger.warning(f"GET CF-KV failed for {key}: {e}")
+ return {}
+
+
+async def set_cf_kv(
+ key: str,
+ data: dict,
+ ttl: int | None = None,
+ account_id: str = DB.CF_ACCOUNT_ID,
+ api_token: str = DB.CF_API_TOKEN,
+ namespace_id: str = DB.CF_KV_NAMESPACE_ID,
+ *,
+ enabled: bool = DB.CF_D1_ENABLED,
+ silent: bool = False,
+) -> bool:
+ """Set to Cloudflare KV."""
+ if not all([enabled, account_id, namespace_id, api_token]):
+ return False
+ key = quote_plus(unquote_plus(key))
+ api = f"https://api.cloudflare.com/client/v4/accounts/{account_id}/storage/kv/namespaces/{namespace_id}/values/{key}"
+ if ttl is not None:
+ api = f"{api}?expiration_ttl={ttl}"
+ headers = {"authorization": f"Bearer {api_token}", "content-type": "*/*"}
+ resp = await hx_req(api, method="PUT", headers=headers, data=data, timeout=30, silent=True, check_kv={"success": True})
+ if error := resp.get("hx_error"):
+ logger.warning(f"SET CF-KV failed for key={key}: {error}")
+ return False
+ if not silent:
+ logger.success(f"Successfully SET CF-KV for key={key}: {data}")
+ return True
+
+
+async def del_cf_kv(
+ key: str,
+ account_id: str = DB.CF_ACCOUNT_ID,
+ api_token: str = DB.CF_API_TOKEN,
+ namespace_id: str = DB.CF_KV_NAMESPACE_ID,
+ *,
+ enabled: bool = DB.CF_D1_ENABLED,
+ silent: bool = False,
+):
+ """Delete from Cloudflare KV."""
+ key = quote_plus(unquote_plus(key))
+ if not all([enabled, account_id, namespace_id, api_token]):
+ return
+ api = f"https://api.cloudflare.com/client/v4/accounts/{account_id}/storage/kv/namespaces/{namespace_id}/values/{key}"
+ headers = {"authorization": f"Bearer {api_token}", "content-type": "application/json"}
+ async with AsyncClient(http2=True, follow_redirects=True, transport=AsyncHTTPTransport(retries=3, http2=True)) as hx:
+ try:
+ resp = await hx.delete(api, headers=headers, timeout=30)
+ resp.raise_for_status()
+ except Exception as e:
+ logger.warning(f"DEL CF-KV failed for key={key}: {e}")
+ return
+ if resp.json().get("success") and not silent:
+ logger.success(f"DEL CF-KV for key={key}")
+ return
src/database/memory.py
@@ -0,0 +1,35 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+from urllib.parse import quote_plus, unquote_plus
+
+from loguru import logger
+
+from config import cache
+
+
+def get_memory_cache(key: str, *, silent: bool = False) -> dict:
+ """Get from memory cache."""
+ key = quote_plus(unquote_plus(key))
+ if kv := cache.get(key):
+ if not silent:
+ logger.trace(f"GET DB from memory cache for {key}: {kv}")
+ return kv
+ return {}
+
+
+def set_memory_cache(key: str, data: dict | list | str, ttl: int | None = None, *, silent: bool = False) -> None:
+ """Set to memory cache."""
+ if ttl is None:
+ ttl = 600
+ key = quote_plus(unquote_plus(key))
+ cache.set(key, data, ttl=ttl)
+ if not silent:
+ logger.trace(f"SET DB to memory cache for {key}: {data}")
+
+
+def del_memory_cache(key: str, *, silent: bool = False):
+ """Delete from memory cache."""
+ key = quote_plus(unquote_plus(key))
+ cache.delete(key)
+ if not silent:
+ logger.trace(f"DEL DB from memory cache for {key}")
src/database/pastbin.py
@@ -0,0 +1,56 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+import io
+from pathlib import Path
+
+import anyio
+from httpx import AsyncClient, AsyncHTTPTransport
+from loguru import logger
+
+from config import DB
+from networking import hx_req
+
+
+async def upload_pastbin(path: str | Path, ttl: int | str | None = None) -> tuple[str, str]:
+    """Upload to a Pastebin Worker instance.
+
+ Returns:
+ download url, manage url
+
+ https://github.com/SharzyL/pastebin-worker
+ """
+ path = Path(path).expanduser().resolve()
+ if not path.is_file():
+ return "", ""
+ if path.stat().st_size > DB.PASTBIN_MAX_BYTES:
+ logger.warning(f"File size exceeds {DB.PASTBIN_MAX_BYTES} bytes, skipping: {path.name}")
+ return "", ""
+ logger.debug(f"Uploading {path.name} to: {DB.PASTBIN_SERVER}")
+
+ async with await anyio.open_file(path, "rb") as f:
+ content = await f.read()
+ payload = {"c": (path.name, io.BytesIO(content)), "p": "1"}
+ if ttl is not None:
+ payload |= {"e": str(ttl)}
+ res = await hx_req(DB.PASTBIN_SERVER, method="POST", files=payload, check_keys=["url", "manageUrl"])
+ url = res["url"]
+ logger.success(f"Uploaded {path.name} to {url}")
+ return url, res["manageUrl"]
+
+
+async def delete_pastbin(url: str):
+    """Delete a file from a Pastebin Worker instance.
+
+ https://github.com/SharzyL/pastebin-worker
+ """
+ if not url:
+ return
+ async with AsyncClient(http2=True, follow_redirects=True, transport=AsyncHTTPTransport(retries=3, http2=True)) as hx:
+ try:
+ resp = await hx.delete(url, timeout=30)
+ resp.raise_for_status()
+ except Exception as e:
+ logger.warning(f"DEL Pastbin failed for url={url}: {e}")
+ return
+ else:
+ logger.success(f"DEL Pastbin for url={url}")
src/database/r2.py
@@ -0,0 +1,191 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+import json
+import os
+import warnings
+from datetime import timedelta
+from urllib.parse import unquote_plus
+
+import brotli
+from aioboto3 import Session
+from botocore.exceptions import ClientError
+from loguru import logger
+
+from config import DB
+from utils import bare_url, nowdt, stringfy
+
+# hot fix: https://developers.cloudflare.com/r2/examples/aws/boto3/
+os.environ["AWS_REQUEST_CHECKSUM_CALCULATION"] = "when_required"
+os.environ["AWS_RESPONSE_CHECKSUM_VALIDATION"] = "when_required"
+
+warnings.filterwarnings("ignore", category=DeprecationWarning)
+
+
+async def list_cf_r2(
+ prefix: str = "",
+ continuation_token: str | None = None,
+ bucket_name: str = DB.CF_R2_BUCKET_NAME,
+ account_id: str = DB.CF_ACCOUNT_ID,
+ aws_access_key_id: str = DB.CF_R2_ACCESS_KEY_ID,
+ aws_secret_access_key: str = DB.CF_R2_SECRET_ACCESS_KEY,
+ *,
+ enabled: bool = DB.CF_R2_ENABLED,
+) -> dict:
+    """List objects from Cloudflare R2."""
+ if not all([enabled, bucket_name, account_id, aws_access_key_id, aws_secret_access_key]):
+ return {}
+ async with Session().client(
+ service_name="s3",
+ endpoint_url=f"https://{account_id}.r2.cloudflarestorage.com",
+ aws_access_key_id=aws_access_key_id,
+ aws_secret_access_key=aws_secret_access_key,
+ region_name="auto",
+ ) as s3: # type: ignore
+ payload = {"Bucket": bucket_name, "MaxKeys": 1000}
+ if continuation_token:
+ payload["ContinuationToken"] = continuation_token
+ if prefix:
+ payload["Prefix"] = prefix
+ contents = []
+ try:
+ results = await s3.list_objects_v2(**payload)
+ if not results.get("IsTruncated"):
+ return results
+ contents.extend(results.get("Contents", []))
+ while results.get("NextContinuationToken"):
+ payload["ContinuationToken"] = results["NextContinuationToken"]
+ results = await s3.list_objects_v2(**payload)
+ contents.extend(results.get("Contents", []))
+ results["Contents"] = contents
+ except Exception as e:
+ logger.warning(f"List CF-R2 failed for {prefix=}: {e}")
+ return {}
+ return results
+
+
+async def get_cf_r2(
+ key: str,
+ bucket_name: str = DB.CF_R2_BUCKET_NAME,
+ account_id: str = DB.CF_ACCOUNT_ID,
+ aws_access_key_id: str = DB.CF_R2_ACCESS_KEY_ID,
+ aws_secret_access_key: str = DB.CF_R2_SECRET_ACCESS_KEY,
+ *,
+ enabled: bool = DB.CF_R2_ENABLED,
+ silent: bool = False,
+) -> dict:
+ """Get from Cloudflare R2."""
+ if not all([enabled, bucket_name, account_id, aws_access_key_id, aws_secret_access_key]):
+ return {}
+ key = bare_url(unquote_plus(key)) # remove http(s):// prefix
+ async with Session().client(
+ service_name="s3",
+ endpoint_url=f"https://{account_id}.r2.cloudflarestorage.com",
+ aws_access_key_id=aws_access_key_id,
+ aws_secret_access_key=aws_secret_access_key,
+ region_name="auto",
+ ) as s3: # type: ignore
+ try:
+ obj = await s3.get_object(Bucket=bucket_name, Key=key)
+ if obj.get("Body"):
+ data = await obj["Body"].read()
+ data = json.loads(data)
+ if not silent:
+ logger.success(f"GET CF-R2 for {key}: {data}")
+ return data
+ except ClientError as e:
+ if e.response["Error"]["Code"] != "NoSuchKey":
+ logger.warning(f"GET CF-R2 failed for {key}: {e}")
+ except Exception as e:
+ logger.warning(f"GET CF-R2 failed for {key}: {e}")
+ return {}
+
+
+async def set_cf_r2(
+ key: str,
+ data: dict | list | str | None = None,
+ metadata: dict | None = None,
+ ttl: int | None = None,
+ bucket_name: str = DB.CF_R2_BUCKET_NAME,
+ account_id: str = DB.CF_ACCOUNT_ID,
+ aws_access_key_id: str = DB.CF_R2_ACCESS_KEY_ID,
+ aws_secret_access_key: str = DB.CF_R2_SECRET_ACCESS_KEY,
+ *,
+ compress: bool = False,
+ quality: int = 4,
+ mime_type: str = "application/json",
+ enabled: bool = DB.CF_R2_ENABLED,
+ silent: bool = False,
+) -> bool:
+ """Set to Cloudflare R2."""
+ if not data and not metadata:
+ return False
+ if not all([enabled, bucket_name, account_id, aws_access_key_id, aws_secret_access_key]):
+ return False
+ key = bare_url(unquote_plus(key)) # remove http(s):// prefix
+ payload = {
+ "CacheControl": "no-cache",
+ "Bucket": bucket_name,
+ "Key": key,
+ "ContentType": mime_type,
+ }
+ if data:
+ if isinstance(data, (dict, list)):
+ upload = json.dumps(data, ensure_ascii=False).encode("utf-8")
+ elif isinstance(data, str):
+ upload = data.encode("utf-8")
+ payload["ContentType"] = mime_type if mime_type.startswith("text") else "text/plain"
+ payload |= {"Body": upload}
+
+ if compress:
+ payload["Body"] = brotli.compress(upload, quality=min(quality, 11))
+ payload["ContentEncoding"] = "br"
+
+ metadata = metadata or {}
+ if metadata:
+ payload |= {"Metadata": stringfy(metadata)}
+ if ttl is not None:
+ payload |= {"Expires": nowdt() + timedelta(seconds=ttl)}
+ async with Session().client(
+ service_name="s3",
+ endpoint_url=f"https://{account_id}.r2.cloudflarestorage.com",
+ aws_access_key_id=aws_access_key_id,
+ aws_secret_access_key=aws_secret_access_key,
+ region_name="auto",
+ ) as s3: # type: ignore
+ try:
+ await s3.put_object(**payload)
+ if not silent:
+ logger.success(f"Successfully SET CF-R2 for {key}: {data=}, {metadata=}")
+ except Exception as e:
+ logger.warning(f"SET CF-R2 failed for {key}: {e}")
+ return False
+ return True
+
+
+async def del_cf_r2(
+ key: str,
+ bucket_name: str = DB.CF_R2_BUCKET_NAME,
+ account_id: str = DB.CF_ACCOUNT_ID,
+ aws_access_key_id: str = DB.CF_R2_ACCESS_KEY_ID,
+ aws_secret_access_key: str = DB.CF_R2_SECRET_ACCESS_KEY,
+ *,
+ enabled: bool = DB.CF_R2_ENABLED,
+ silent: bool = False,
+):
+ """Delete from Cloudflare R2."""
+ if not all([enabled, bucket_name, account_id, aws_access_key_id, aws_secret_access_key]):
+ return
+ key = bare_url(unquote_plus(key)) # remove http(s):// prefix
+ async with Session().client(
+ service_name="s3",
+ endpoint_url=f"https://{account_id}.r2.cloudflarestorage.com",
+ aws_access_key_id=aws_access_key_id,
+ aws_secret_access_key=aws_secret_access_key,
+ region_name="auto",
+ ) as s3: # type: ignore
+ try:
+ await s3.delete_object(Bucket=bucket_name, Key=key)
+ if not silent:
+ logger.success(f"DEL CF-R2 for key={key}")
+ except Exception as e:
+ logger.warning(f"DEL CF-R2 failed for key={key}: {e}")
src/database/README.md
@@ -0,0 +1,14 @@
+# Databases
+
+All methods:
+
+```py
+from database.database import get_db, set_db, del_db
+from database.kv import get_cf_kv, set_cf_kv, del_cf_kv
+from database.r2 import list_cf_r2, get_cf_r2, set_cf_r2, del_cf_r2
+from database.memory import get_memory_cache, set_memory_cache, del_memory_cache
+from database.d1 import create_d1_database, create_d1_table, list_d1_tables, query_d1
+from database.alist import list_alist, download_alist, upload_alist, delete_alist
+from database.pastbin import upload_pastbin, delete_pastbin
+from database.uguu import upload_uguu
+```
src/database/uguu.py
@@ -0,0 +1,25 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+import io
+from pathlib import Path
+
+import anyio
+from loguru import logger
+
+from networking import hx_req
+from utils import guess_mime
+
+
+async def upload_uguu(path: str | Path) -> str:
+ """Upload to https://uguu.se, return the download url."""
+ path = Path(path).expanduser().resolve()
+ if not path.is_file():
+ return ""
+ api = "https://uguu.se/upload"
+ logger.debug(f"Uploading {path.name} to: https://Uguu.se")
+ async with await anyio.open_file(path, "rb") as f:
+ content = await f.read()
+ res = await hx_req(api, method="POST", files={"files[]": (path.name, io.BytesIO(content), guess_mime(path))}, check_kv={"success": True}, check_keys=["files.0.url"])
+ url = res["files"][0]["url"]
+ logger.success(f"Uploaded {path.name} to {url}")
+ return url
src/history/sync.py
@@ -14,7 +14,7 @@ from pyrogram.errors import PeerIdInvalid
from pyrogram.types import Chat, Message
from config import DOWNLOAD_DIR, HISTORY, TZ, cache
-from database import create_d1_database, create_d1_table, query_d1
+from database.d1 import create_d1_database, create_d1_table, query_d1
from messages.parser import parse_msg
from permission import check_save_d1
from utils import i_am_bot
src/messages/database.py
@@ -9,7 +9,7 @@ from pyrogram.client import Client
from pyrogram.types import Message, ReplyParameters
from config import DB
-from database import del_db, get_db, set_db
+from database.database import del_db, get_db, set_db
from messages.parser import parse_msg
from messages.progress import modify_progress
from messages.utils import sender_markdown_to_html
src/others/download_external.py
@@ -7,7 +7,6 @@ from pyrogram.client import Client
from pyrogram.types import Message
from config import MAX_FILE_BYTES, PREFIX
-from database import guess_mime
from llm.utils import convert_md
from messages.parser import parse_msg
from messages.progress import modify_progress
@@ -16,7 +15,7 @@ from messages.utils import equal_prefix, get_reply_to, startswith_prefix
from multimedia import is_valid_video_or_audio, validate_img
from networking import download_file
from publish import publish_telegraph
-from utils import find_url, readable_size, to_int
+from utils import find_url, guess_mime, readable_size, to_int
HELP = f"""
⏬**下载文件**
src/others/podcast.py
@@ -20,7 +20,8 @@ from pyrogram.types.messages_and_media.message import Str
from asr.voice_recognition import asr_file
from config import DB, DOWNLOAD_DIR, PODCAST, PREFIX, READING_SPEED, TZ, cache
-from database import get_cf_r2, set_cf_r2, upload_alist
+from database.alist import upload_alist
+from database.r2 import get_cf_r2, set_cf_r2
from llm.gpt import gpt_response
from llm.utils import convert_html, convert_md, remove_consecutive_newlines
from messages.sender import send2tg
src/preview/bilibili.py
@@ -8,7 +8,7 @@ from pyrogram.types import Message
from config import DB, PROXY, TELEGRAM_UA, cache
from cookies import cookie_cloud_bilibili
-from database import get_db
+from database.database import get_db
from messages.database import copy_messages_from_db, save_messages
from messages.progress import modify_progress
from messages.sender import send2tg
src/preview/douyin.py
@@ -12,7 +12,7 @@ from pyrogram.types import Message
from bridge.social import send_to_social_media_bridge
from config import API, DB, PROVIDER, PROXY, TOKEN, TZ
-from database import get_db
+from database.database import get_db
from messages.database import copy_messages_from_db, save_messages
from messages.progress import modify_progress
from messages.sender import send2tg
src/preview/instagram.py
@@ -13,7 +13,7 @@ from pyrogram.types import Message
from bridge.social import send_to_social_media_bridge
from config import API, DB, DOWNLOAD_DIR, PROVIDER, PROXY, TELEGRAM_UA, TOKEN, TZ
-from database import get_db
+from database.database import get_db
from messages.database import copy_messages_from_db, save_messages
from messages.progress import modify_progress
from messages.sender import send2tg
src/preview/reddit.py
@@ -12,7 +12,7 @@ from pyrogram.parser.markdown import BLOCKQUOTE_EXPANDABLE_DELIM
from pyrogram.types import Message
from config import DB, PROXY, TZ
-from database import get_db
+from database.database import get_db
from messages.database import copy_messages_from_db, save_messages
from messages.progress import modify_progress
from messages.sender import send2tg
src/preview/twitter.py
@@ -13,7 +13,7 @@ from pyrogram.types import Message
from bridge.social import send_to_social_media_bridge
from config import API, DB, PROVIDER, PROXY, TELEGRAM_UA, TOKEN, TZ, cache
-from database import get_db
+from database.database import get_db
from messages.database import copy_messages_from_db, save_messages
from messages.progress import modify_progress
from messages.sender import send2tg
src/preview/wechat.py
@@ -10,7 +10,7 @@ from pyrogram.parser.markdown import BLOCKQUOTE_EXPANDABLE_DELIM, BLOCKQUOTE_EXP
from pyrogram.types import Message
from config import API, CAPTION_LENGTH, DB, DOWNLOAD_DIR, PROXY, TEXT_LENGTH, TOKEN
-from database import get_db
+from database.database import get_db
from messages.database import copy_messages_from_db, save_messages
from messages.progress import modify_progress
from messages.sender import send2tg
src/preview/weibo.py
@@ -17,7 +17,7 @@ from pyrogram.types import Message
from bridge.social import send_to_social_media_bridge
from config import API, DB, DOWNLOAD_DIR, PROVIDER, PROXY, TELEGRAM_UA, TOKEN, TZ, cache
from cookies import get_weibo_cookies
-from database import get_db
+from database.database import get_db
from messages.database import copy_messages_from_db, save_messages
from messages.progress import modify_progress
from messages.sender import send2tg
src/preview/xiaohongshu.py
@@ -11,7 +11,7 @@ from pyrogram.types import Message
from bridge.social import send_to_social_media_bridge
from config import DB, PROXY, TZ
-from database import get_db
+from database.database import get_db
from messages.database import copy_messages_from_db, save_messages
from messages.progress import modify_progress
from messages.sender import send2tg
src/preview/ytdlp.py
@@ -35,7 +35,7 @@ from config import (
YTDLP_RE_ENCODING_MAX_FILE_BYTES,
cache,
)
-from database import get_db
+from database.database import get_db
from messages.database import copy_messages_from_db, save_messages
from messages.preprocess import preprocess_media
from messages.progress import modify_progress, telegram_uploading
src/subtitles/subtitle.py
@@ -11,8 +11,7 @@ from pyrogram.types import Message
from pyrogram.types.messages_and_media.message import Str
from asr.voice_recognition import asr_file
-from config import PREFIX, PROVIDER, READING_SPEED, TEXT_LENGTH
-from database import cache
+from config import PREFIX, PROVIDER, READING_SPEED, TEXT_LENGTH, cache
from llm.gpt import gpt_response
from messages.parser import parse_msg
from messages.progress import modify_progress
src/database.py
@@ -1,562 +0,0 @@
-#!/usr/bin/env python
-# -*- coding: utf-8 -*-
-"""Currently, Memory cache, Cloudflare KV and Cloudflare R2 are supported.
-
-Note: Memory Cache is always enabled.
-"""
-
-import base64
-import contextlib
-import io
-import json
-import os
-import shutil
-import warnings
-from datetime import timedelta
-from pathlib import Path
-from urllib.parse import quote_plus, unquote_plus
-
-import anyio
-import brotli
-import puremagic
-from aioboto3 import Session
-from botocore.exceptions import ClientError
-from glom import flatten, glom
-from httpx import AsyncClient, AsyncHTTPTransport
-from loguru import logger
-
-from config import DB, DOWNLOAD_DIR, PROXY, cache
-from networking import hx_req
-from utils import bare_url, nowdt, stringfy
-
-# hot fix: https://developers.cloudflare.com/r2/examples/aws/boto3/
-os.environ["AWS_REQUEST_CHECKSUM_CALCULATION"] = "when_required"
-os.environ["AWS_RESPONSE_CHECKSUM_VALIDATION"] = "when_required"
-
-warnings.filterwarnings("ignore", category=DeprecationWarning)
-
-
-async def get_db(key: str) -> dict:
- """Get KV."""
- if not key:
- return {}
- key = quote_plus(key)
- if kv := get_memory_kv(key):
- return kv
- if DB.ENGINE == "Cloudflare-KV":
- return await get_cf_kv(key)
- if DB.ENGINE == "Cloudflare-R2":
- return await get_cf_r2(key)
- return {}
-
-
-def get_memory_kv(key: str) -> dict:
- """Get from memory cache."""
- if kv := cache.get(key):
- logger.trace(f"GET KV from memory cache for {key}: {kv}")
- return kv
- return {}
-
-
-async def get_cf_kv(key: str, *, log_success: bool = True) -> dict:
- """Get from Cloudflare KV."""
- if not DB.CF_KV_ENABLED:
- logger.warning("SKIP GET CF-KV: Cloudflare KV disabled")
- return {}
- key = quote_plus(key)
- api = f"https://api.cloudflare.com/client/v4/accounts/{DB.CF_ACCOUNT_ID}/storage/kv/namespaces/{DB.CF_KV_NAMESPACE_ID}/values/{key}"
- headers = {"authorization": f"Bearer {DB.CF_API_TOKEN}", "content-type": "application/json"}
- async with AsyncClient(http2=True, follow_redirects=True, transport=AsyncHTTPTransport(retries=3, http2=True)) as hx:
- try:
- resp = await hx.get(api, headers=headers, timeout=30)
- if resp.status_code == 404:
- logger.trace(f"404 Not Found for CF-KV key={key}")
- return {}
- resp.raise_for_status()
- if data := resp.json():
- if log_success:
- logger.success(f"GET CF-KV for {key}: {data}")
- return data
- except Exception as e:
- logger.warning(f"GET CF-KV failed for {key}: {e}")
- return {}
-
-
-async def list_cf_r2(prefix: str = "", continuation_token: str | None = None) -> dict:
- """Get from Cloudflare R2."""
- if not DB.CF_R2_ENABLED:
- logger.warning("SKIP LIST CF-R2: Cloudflare R2 disabled")
- return {}
- async with Session().client(
- service_name="s3",
- endpoint_url=f"https://{DB.CF_ACCOUNT_ID}.r2.cloudflarestorage.com",
- aws_access_key_id=DB.CF_R2_ACCESS_KEY_ID,
- aws_secret_access_key=DB.CF_R2_SECRET_ACCESS_KEY,
- region_name="auto",
- ) as s3: # type: ignore
- payload = {"Bucket": DB.CF_R2_BUCKET_NAME, "MaxKeys": 1000}
- if continuation_token:
- payload["ContinuationToken"] = continuation_token
- if prefix:
- payload["Prefix"] = prefix
- contents = []
- try:
- results = await s3.list_objects_v2(**payload)
- if not results.get("IsTruncated"):
- return results
- contents.extend(results.get("Contents", []))
- while results.get("NextContinuationToken"):
- payload["ContinuationToken"] = results["NextContinuationToken"]
- results = await s3.list_objects_v2(**payload)
- contents.extend(results.get("Contents", []))
- results["Contents"] = contents
- except Exception as e:
- logger.warning(f"List CF-R2 failed for {prefix=}: {e}")
- return {}
- return results
-
-
-async def get_cf_r2(key: str, *, silent: bool = False) -> dict:
- """Get from Cloudflare R2."""
- if not DB.CF_R2_ENABLED:
- logger.warning("SKIP GET CF-R2: Cloudflare R2 disabled")
- return {}
-
- key = bare_url(unquote_plus(key)) # remove http(s):// prefix
- async with Session().client(
- service_name="s3",
- endpoint_url=f"https://{DB.CF_ACCOUNT_ID}.r2.cloudflarestorage.com",
- aws_access_key_id=DB.CF_R2_ACCESS_KEY_ID,
- aws_secret_access_key=DB.CF_R2_SECRET_ACCESS_KEY,
- region_name="auto",
- ) as s3: # type: ignore
- try:
- obj = await s3.get_object(Bucket=DB.CF_R2_BUCKET_NAME, Key=key)
- if obj.get("Body"):
- data = await obj["Body"].read()
- data = json.loads(data)
- if not silent:
- logger.success(f"GET CF-R2 for {key}: {data}")
- return data
- except ClientError as e:
- if e.response["Error"]["Code"] != "NoSuchKey":
- logger.warning(f"GET CF-R2 failed for {key}: {e}")
- except Exception as e:
- logger.warning(f"GET CF-R2 failed for {key}: {e}")
- return {}
-
-
-async def set_db(key: str, data: dict | list, ttl: int | None = None, metadata: dict | None = None) -> bool:
- """Set KV."""
- key = quote_plus(key)
- success = False
- if DB.ENGINE == "Cloudflare-KV":
- success = await set_cf_kv(key, data, ttl=ttl)
- if DB.ENGINE == "Cloudflare-R2":
- success = await set_cf_r2(key, data, metadata, ttl=ttl)
- if success:
- set_memory_kv(key, data, ttl)
- return success
-
-
-def set_memory_kv(key: str, data: dict | list | str, ttl: int | None = None) -> None:
- """Set to memory cache."""
- if ttl is None:
- ttl = 600
- cache.set(key, data, ttl=ttl)
- logger.trace(f"SET KV to memory cache for {key}: {data}")
-
-
-async def set_cf_kv(key: str, data: dict | list | str, ttl: int | None = None, *, skip_in_memory: bool = True, log_success: bool = True) -> bool:
- """Set to Cloudflare KV.
-
- If `skip_in_memory` is True, it will skip setting to CF-KV if the key is already in memory cache.
- """
- if skip_in_memory and cache.get(key):
- logger.trace(f"SKIP SET CF-KV: key is already in memory cache for {key}: {cache.get(key)}")
- return True
- if not DB.CF_KV_ENABLED:
- logger.warning("SKIP SET CF-KV: Cloudflare KV disabled")
- return True
- key = quote_plus(key)
- api = f"https://api.cloudflare.com/client/v4/accounts/{DB.CF_ACCOUNT_ID}/storage/kv/namespaces/{DB.CF_KV_NAMESPACE_ID}/values/{key}"
- if ttl is not None:
- api = f"{api}?expiration_ttl={ttl}"
- headers = {"authorization": f"Bearer {DB.CF_API_TOKEN}", "content-type": "*/*"}
- async with AsyncClient(http2=True, follow_redirects=True, transport=AsyncHTTPTransport(retries=3, http2=True)) as hx:
- try:
- resp = await hx.put(api, headers=headers, json=data, timeout=30)
- resp.raise_for_status()
- except Exception as e:
- logger.warning(f"Failed to SET CF-KV for key={key}: {e}")
- return False
- if resp.json().get("success"):
- if log_success:
- logger.success(f"Successfully SET CF-KV for key={key}: {data}")
- return True
- return False
-
-
-async def set_cf_r2(
- key: str,
- data: dict | list | str | None = None,
- metadata: dict | None = None,
- ttl: int | None = None,
- *,
- compress: bool = False,
- quality: int = 4,
- mime_type: str = "application/json",
- skip_in_memory: bool = True,
- silent: bool = False,
-) -> bool:
- """Set to Cloudflare R2 via boto3.
-
- We do not put data to R2, just use metadata to store data.
-
- If `skip_in_memory` is True, it will skip setting to CF-R2 if the key is already in memory cache.
- """
- if not data and not metadata:
- return False
- if skip_in_memory and cache.get(key):
- logger.trace(f"SKIP SET CF-R2: key is already in memory cache for {key}: {cache.get(key)}")
- return True
- if not DB.CF_R2_ENABLED:
- logger.warning("SKIP SET CF-R2: Cloudflare R2 disabled")
- return True
- key = bare_url(unquote_plus(key)) # remove http(s):// prefix
- payload = {
- "CacheControl": "no-cache",
- "Bucket": DB.CF_R2_BUCKET_NAME,
- "Key": key,
- "ContentType": mime_type,
- }
- if data:
- if isinstance(data, (dict, list)):
- upload = json.dumps(data, ensure_ascii=False).encode("utf-8")
- elif isinstance(data, str):
- upload = data.encode("utf-8")
- payload["ContentType"] = mime_type if mime_type.startswith("text") else "text/plain"
- payload |= {"Body": upload}
-
- if compress:
- payload["Body"] = brotli.compress(upload, quality=min(quality, 11))
- payload["ContentEncoding"] = "br"
-
- metadata = metadata or {}
- if metadata:
- payload |= {"Metadata": stringfy(metadata)}
- if ttl is not None:
- payload |= {"Expires": nowdt() + timedelta(seconds=ttl)}
- async with Session().client(
- service_name="s3",
- endpoint_url=f"https://{DB.CF_ACCOUNT_ID}.r2.cloudflarestorage.com",
- aws_access_key_id=DB.CF_R2_ACCESS_KEY_ID,
- aws_secret_access_key=DB.CF_R2_SECRET_ACCESS_KEY,
- region_name="auto",
- ) as s3: # type: ignore
- try:
- await s3.put_object(**payload)
- if not silent:
- logger.success(f"Successfully SET CF-R2 for {key}: {data=}, {metadata=}")
- except Exception as e:
- logger.warning(f"SET CF-R2 failed for {key}: {e}")
- return False
- return True
-
-
-async def del_db(key: str):
- """Delete KV."""
- if not key:
- return
- del_memory_kv(key)
- if DB.ENGINE == "Cloudflare-KV":
- await del_cf_kv(key)
- if DB.ENGINE == "Cloudflare-R2":
- await del_cf_r2(key)
-
-
-def del_memory_kv(key: str):
- """Delete from memory cache."""
- key = quote_plus(key)
- cache.delete(key)
-
-
-async def del_cf_kv(key: str):
- """Delete from Cloudflare KV."""
- key = quote_plus(key)
- if not DB.CF_KV_ENABLED:
- logger.warning("SKIP SET CF-KV: Cloudflare KV disabled")
- return
- api = f"https://api.cloudflare.com/client/v4/accounts/{DB.CF_ACCOUNT_ID}/storage/kv/namespaces/{DB.CF_KV_NAMESPACE_ID}/values/{key}"
- headers = {"authorization": f"Bearer {DB.CF_API_TOKEN}", "content-type": "application/json"}
- async with AsyncClient(http2=True, follow_redirects=True, transport=AsyncHTTPTransport(retries=3, http2=True)) as hx:
- try:
- resp = await hx.delete(api, headers=headers, timeout=30)
- resp.raise_for_status()
- except Exception as e:
- logger.warning(f"DEL CF-KV failed for key={key}: {e}")
- return
- if resp.json().get("success"):
- logger.success(f"DEL CF-KV for key={key}")
- return
-
-
-async def del_cf_r2(key: str):
- """Delete from Cloudflare R2."""
- if not DB.CF_R2_ENABLED:
- logger.warning("SKIP SET CF-R2: Cloudflare R2 disabled")
- return
- key = bare_url(unquote_plus(key)) # remove http(s):// prefix
- async with Session().client(
- service_name="s3",
- endpoint_url=f"https://{DB.CF_ACCOUNT_ID}.r2.cloudflarestorage.com",
- aws_access_key_id=DB.CF_R2_ACCESS_KEY_ID,
- aws_secret_access_key=DB.CF_R2_SECRET_ACCESS_KEY,
- region_name="auto",
- ) as s3: # type: ignore
- try:
- await s3.delete_object(Bucket=DB.CF_R2_BUCKET_NAME, Key=key)
- logger.success(f"DEL CF-R2 for key={key}")
- except Exception as e:
- logger.warning(f"DEL CF-R2 failed for key={key}: {e}")
- return
- return
-
-
-@cache.memoize(ttl=0)
-async def create_d1_database(name: str = "bennybot", primary_location_hint: str = "", *, silent: bool = False) -> str:
- """Create D1 database and return DatabaseID."""
- if not all([DB.CF_D1_ENABLED, DB.CF_ACCOUNT_ID, DB.CF_API_TOKEN]):
- return ""
- api = f"https://api.cloudflare.com/client/v4/accounts/{DB.CF_ACCOUNT_ID}/d1/database"
- headers = {"authorization": f"Bearer {DB.CF_API_TOKEN}", "content-type": "application/json"}
- payload = {"name": name}
- # check if database exists
- resp = await hx_req(api, method="GET", params=payload, headers=headers, check_kv={"success": True}, proxy=PROXY.D1, max_retry=0, silent=silent)
- if database_id := glom(resp, "result.0.uuid", default=""):
- return database_id
- if primary_location_hint:
- payload |= {"primary_location_hint": primary_location_hint}
- resp = await hx_req(api, method="POST", post_json=payload, headers=headers, check_kv={"success": True}, proxy=PROXY.D1, max_retry=0, silent=silent)
- return glom(resp, "result.uuid", default="")
-
-
-@cache.memoize(ttl=0)
-async def create_d1_table(table_name: str | float, columns: str, db_name: str = "bennybot", *, silent: bool = False) -> None:
- """Create D1 database and return DatabaseID."""
- if not all([DB.CF_D1_ENABLED, DB.CF_ACCOUNT_ID, DB.CF_API_TOKEN]):
- return
- database_id = await create_d1_database(db_name, silent=silent)
- if not database_id:
- return
- tables = await list_d1_tables(db_name, silent=silent)
- if table_name in tables:
- return
-
- sql = f'CREATE TABLE IF NOT EXISTS "{table_name}" ({columns});'
- await query_d1(sql, database_id, silent=silent)
- if not silent:
- logger.success(f"Create Table {table_name} in D1 database {db_name}")
-
-
-@cache.memoize(ttl=600)
-async def list_d1_tables(db_name: str = "bennybot", *, silent: bool = False) -> list[str]:
- """List D1 tables in a database."""
- if not all([DB.CF_D1_ENABLED, DB.CF_ACCOUNT_ID, DB.CF_API_TOKEN]):
- return []
- database_id = await create_d1_database(db_name, silent=silent)
- if not database_id:
- return []
- sql = "SELECT name FROM sqlite_master WHERE type='table';"
- resp = await query_d1(sql, database_id, silent=silent)
- return flatten(glom(resp, "result.*.results.*.name", default=[]))
-
-
-async def query_d1(sql: str, db_id: str | None = None, db_name: str = "bennybot", params: list[str] | None = None, *, silent: bool = False) -> dict:
- """Query D1."""
- if not all([DB.CF_D1_ENABLED, DB.CF_ACCOUNT_ID, DB.CF_API_TOKEN]):
- return {}
- if db_id is None:
- db_id = await create_d1_database(db_name, silent=silent)
- api = f"https://api.cloudflare.com/client/v4/accounts/{DB.CF_ACCOUNT_ID}/d1/database/{db_id}/query"
- headers = {"authorization": f"Bearer {DB.CF_API_TOKEN}", "content-type": "application/json"}
- payload = {"sql": sql}
- if params is not None:
- payload |= {"params": params}
- if not silent:
- logger.trace(f"Query CF-D1: {payload}")
- return await hx_req(api, "POST", post_json=payload, headers=headers, check_kv={"success": True}, proxy=PROXY.D1, max_retry=0, silent=silent)
-
-
-async def list_alist() -> list[dict]:
- """List from Alist."""
- if not DB.ALIST_ENABLED:
- return [{}]
- api = DB.ALIST_SERVER.removesuffix("/") + "/api/fs/list"
- logger.info(f"List Alist from: {api}")
- res = await hx_req(api, method="POST", post_json={"path": "/"}, check_kv={"code": 200})
- return glom(res, "data.content", default=[]) or []
-
-
-async def download_alist(fname: str, save_path: str | Path | None = None) -> str:
- """Download file from Alist."""
- if not DB.ALIST_ENABLED:
- return ""
- save_path = Path(DOWNLOAD_DIR) / fname if save_path is None else Path(save_path)
- ext = Path(fname).suffix
- # Headers DO NOT support Unicode characters
- if any(ord(c) > 127 for c in fname): # has Non-ASCII
- b64str = base64.urlsafe_b64encode(fname.encode("utf-8")).decode("ascii").rstrip("=")
- fname = b64str + ext
- url = DB.ALIST_SERVER.removesuffix("/") + "/d/" + DB.ALIST_BASR_PATH.strip("/") + f"/{fname.lstrip('/')}"
- res = await hx_req(url, rformat="content", check_keys=["content"])
- logger.info(f"Download {fname} to {save_path.name} from: {url}")
- async with await anyio.open_file(save_path, "wb") as f:
- await f.write(res["content"])
- return save_path.as_posix()
-
-
-async def upload_alist(path: str | Path) -> str:
- """Upload to Alist."""
- if not DB.ALIST_ENABLED:
- return ""
- path = Path(path).expanduser().resolve()
- if not path.is_file():
- return ""
- api = DB.ALIST_SERVER.removesuffix("/") + "/api/fs/form"
- # Headers DO NOT support Unicode characters
- if any(ord(c) > 127 for c in path.name): # has Non-ASCII
- new_name = base64.urlsafe_b64encode(path.name.encode("utf-8")).decode("ascii").rstrip("=")
- new_path = path.with_stem(new_name)
- shutil.copy(path, new_path)
- path = new_path
- headers = {"File-Path": f"/{path.name}"}
- logger.info(f"Upload {path.name} to: {api}")
- async with await anyio.open_file(path, "rb") as f:
- content = await f.read()
- await hx_req(api, method="PUT", headers=headers, files={"file": (path.name, io.BytesIO(content))}, check_kv={"code": 200})
- if "new_path" in locals():
- new_path.unlink(missing_ok=True)
- return DB.ALIST_SERVER.removesuffix("/") + "/d/" + DB.ALIST_BASR_PATH.strip("/") + f"/{path.name}"
-
-
-async def delete_alist(fname: str, *, ensure_ascii: bool = True) -> None:
- """Delete from Alist."""
- if not DB.ALIST_ENABLED:
- return
- # Get JWT Token
- payload = {"username": DB.ALIST_USERNAME, "password": DB.ALIST_PASSWORD}
- auth_api = DB.ALIST_SERVER.removesuffix("/") + "/api/auth/login"
- res = await hx_req(auth_api, method="POST", post_json=payload, check_keys=["data.token"])
- token = res["data"]["token"]
-
- # Delete
- api = DB.ALIST_SERVER.removesuffix("/") + "/api/fs/remove"
- headers = {"Content-Type": "application/json", "Authorization": token}
- if ensure_ascii and any(ord(c) > 127 for c in fname): # has Non-ASCII
- b64str = base64.urlsafe_b64encode(fname.encode("utf-8")).decode("ascii").rstrip("=")
- payload = {"names": [b64str + Path(fname).suffix]}
- else:
- payload = {"names": [fname]}
- logger.info(f"Delete {fname} from: {api}")
- res = await hx_req(api, method="POST", headers=headers, post_json=payload, check_kv={"code": 200})
-
-
-def guess_mime(path: str | Path) -> str:
- path = Path(path).expanduser().resolve()
- if not path.is_file():
- return ""
- with contextlib.suppress(Exception):
- import magic # magic needs `libmagic` to be installed.
-
- # `sudo apt-get install libmagic1` or `brew install libmagic`
- return magic.from_file(path, mime=True)
-
- # infer from `magic` failed
- with contextlib.suppress(Exception):
- return puremagic.from_file(path, mime=True)
- return ""
-
-
-async def upload_uguu(path: str | Path) -> str:
- """Upload to https://uguu.se, return the download url."""
- path = Path(path).expanduser().resolve()
- if not path.is_file():
- return ""
- api = "https://uguu.se/upload"
- logger.debug(f"Uploading {path.name} to: https://Uguu.se")
- async with await anyio.open_file(path, "rb") as f:
- content = await f.read()
- res = await hx_req(api, method="POST", files={"files[]": (path.name, io.BytesIO(content), guess_mime(path))}, check_kv={"success": True}, check_keys=["files.0.url"])
- url = res["files"][0]["url"]
- logger.success(f"Uploaded {path.name} to {url}")
- return url
-
-
-async def upload_pastbin(path: str | Path, ttl: int | str | None = None) -> tuple[str, str]:
- """Upload to Pastbin Workders.
-
- Returns:
- download url, manage url
-
- https://github.com/SharzyL/pastebin-worker
- """
- path = Path(path).expanduser().resolve()
- if not path.is_file():
- return "", ""
- if path.stat().st_size > DB.PASTBIN_MAX_BYTES:
- logger.warning(f"File size exceeds {DB.PASTBIN_MAX_BYTES} bytes, skipping: {path.name}")
- return "", ""
- logger.debug(f"Uploading {path.name} to: {DB.PASTBIN_SERVER}")
-
- async with await anyio.open_file(path, "rb") as f:
- content = await f.read()
- payload = {"c": (path.name, io.BytesIO(content)), "p": "1"}
- if ttl is not None:
- payload |= {"e": str(ttl)}
- res = await hx_req(DB.PASTBIN_SERVER, method="POST", files=payload, check_keys=["url", "manageUrl"])
- url = res["url"]
- logger.success(f"Uploaded {path.name} to {url}")
- return url, res["manageUrl"]
-
-
-async def delete_pastbin(url: str):
- """Delete file in Pastbin Workders.
-
- https://github.com/SharzyL/pastebin-worker
- """
- if not url:
- return
- async with AsyncClient(http2=True, follow_redirects=True, transport=AsyncHTTPTransport(retries=3, http2=True)) as hx:
- try:
- resp = await hx.delete(url, timeout=30)
- resp.raise_for_status()
- except Exception as e:
- logger.warning(f"DEL Pastbin failed for url={url}: {e}")
- return
- else:
- logger.success(f"DEL Pastbin for url={url}")
-
-
-if __name__ == "__main__":
- import asyncio
-
- # url, manage_url = asyncio.run(upload_pastbin("测试.mp3", ttl=10))
- # asyncio.run(delete_pastbin(manage_url))
- # asyncio.run(upload_uguu("测试.mp3"))
- # asyncio.run(list_alist())
- # asyncio.run(upload_alist("测试.mp3"))
- # asyncio.run(download_alist("测试.mp3"))
- # asyncio.run(delete_alist("测试.mp3"))
- # asyncio.run(download_alist("test.py"))
- # asyncio.run(set_cf_r2("test2", metadata={"finished": "1"}))
- # asyncio.run(set_cf_r2("test2", data={"finished": "1"}, ttl=60))
- asyncio.run(list_cf_r2("RSS/"))
- # asyncio.run(get_cf_r2("Danmu/2025-05-26"))
- # columns = "id INTEGER PRIMARY KEY, time TEXT, uid INTEGER, user TEXT, text TEXT, sc_amt REAL NULL, sc_ccy TEXT NULL"
- # asyncio.run(create_d1_table("2025", columns))
-
- # sql = 'INSERT INTO "2025" (id, time, uid, user, text, sc_amt, sc_ccy) VALUES (?, ?, ?, ?, ?, ?, ?);'
- # params = [123, "2025-01-01", 456, "username", "hello", 15.5, "USD"]
- # resp = asyncio.run(query_d1(sql, params=params))
- # print(resp)
src/handler.py
@@ -10,7 +10,7 @@ from asr.voice_recognition import voice_to_text
from bridge.ocr import send_to_ocr_bridge
from config import ENABLE, PREFIX, PROXY
from danmu.entrypoint import query_danmu
-from database import del_db
+from database.database import del_db
from history.sync import sync_history_to_d1
from llm.gpt import gpt_response
from llm.summary import ai_summary
src/networking.py
@@ -11,7 +11,7 @@ from urllib.parse import parse_qs, quote_plus, urlparse
import anyio
from httpx import AsyncClient, AsyncHTTPTransport, HTTPStatusError, Request, RequestError, Response
-from httpx._types import RequestContent, RequestFiles # type: ignore
+from httpx._types import RequestContent, RequestData, RequestFiles # type: ignore
from httpx_curl_cffi import AsyncCurlTransport, CurlOpt
from loguru import logger
@@ -39,7 +39,7 @@ async def hx_req(
headers: dict | None = None,
cookies: dict | None = None,
params: dict | None = None,
- data: dict | None = None,
+ data: RequestData | None = None,
post_json: dict | None = None,
post_content: RequestContent | None = None,
files: RequestFiles | None = None,
src/publish.py
@@ -14,7 +14,7 @@ from loguru import logger
from telegraph.aio import Telegraph
from config import DB, TOKEN, TZ
-from database import set_cf_r2
+from database.r2 import set_cf_r2
from utils import nowdt, rand_string
src/utils.py
@@ -1,5 +1,6 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
+import contextlib
import json
import random
import re
@@ -10,6 +11,7 @@ from pathlib import Path
from typing import Any
from zoneinfo import ZoneInfo
+import puremagic
import zhconv
from bilibili_api.utils.aid_bvid_transformer import aid2bvid, bvid2aid
from bs4.element import PageElement
@@ -400,6 +402,22 @@ def is_supported_by_ytdlp(url: str) -> bool:
return any(extractor.suitable(url) for extractor in extractors)
+def guess_mime(path: str | Path) -> str:
+ path = Path(path).expanduser().resolve()
+ if not path.is_file():
+ return ""
+ with contextlib.suppress(Exception):
+ import magic # magic needs `libmagic` to be installed.
+
+ # `sudo apt-get install libmagic1` or `brew install libmagic`
+ return magic.from_file(path, mime=True)
+
+ # infer from `magic` failed
+ with contextlib.suppress(Exception):
+ return puremagic.from_file(path, mime=True)
+ return ""
+
+
def unicode_to_ascii(text: str | float) -> str:
if not text:
return ""