main
  1#!/usr/bin/env python
  2# -*- coding: utf-8 -*-
  3from glom import flatten, glom
  4from loguru import logger
  5
  6from config import DB, PROXY, cache
  7from networking import hx_req
  8
  9
 10@cache.memoize(ttl=0)
 11async def create_d1_database(
 12    name: str = "bennybot",
 13    primary_location_hint: str = "",
 14    account_id: str = DB.CF_ACCOUNT_ID,
 15    api_token: str = DB.CF_API_TOKEN,
 16    *,
 17    enabled: bool = DB.CF_D1_ENABLED,
 18    silent: bool = False,
 19) -> str:
 20    """Create D1 database and return DatabaseID."""
 21    if not all([enabled, account_id, api_token]):
 22        return ""
 23    api = f"https://api.cloudflare.com/client/v4/accounts/{account_id}/d1/database"
 24    headers = {"authorization": f"Bearer {api_token}", "content-type": "application/json"}
 25    payload = {"name": name}
 26    # check if database exists
 27    resp = await hx_req(api, method="GET", params=payload, headers=headers, check_kv={"success": True}, proxy=PROXY.D1, max_retry=0, silent=silent)
 28    if database_id := glom(resp, "result.0.uuid", default=""):
 29        return database_id
 30    if primary_location_hint:
 31        payload |= {"primary_location_hint": primary_location_hint}
 32    resp = await hx_req(api, method="POST", json_data=payload, headers=headers, check_kv={"success": True}, proxy=PROXY.D1, max_retry=0, silent=silent)
 33    return glom(resp, "result.uuid", default="")
 34
 35
 36@cache.memoize(ttl=0)
 37async def create_d1_table(
 38    table_name: str | float,
 39    columns: str,
 40    *,
 41    idx_cols: list[str] | None = None,
 42    idx_prefix: str = "idx_",
 43    fts_on_col: str | None = None,
 44    fts_index_col: str = "segmented",
 45    fts_name: str | None = None,
 46    db_name: str = "bennybot",
 47    silent: bool = False,
 48) -> None:
 49    """Create a D1 table.
 50
 51    If `idx_cols` is provided, create indexs for these columns.
 52
 53    idx_cols should be a list of strings, the created index names prefixed by `idx_prefix`
 54    for example:
 55        idx_prefix = "idx_"
 56        idx_cols = ["uid", "time"]
 57        indexs = ["idx_uid", "idx_time"]
 58
 59    # create FTS table for Chinese search
 60    If `fts_on_col` is provided, create a FTS5 table with `fts_on_col` as the on column.
 61    the `fts_index_col` is the column used for FTS5 indexing.
 62    """
 63    database_id = await create_d1_database(db_name, silent=silent)
 64    if not database_id:
 65        return
 66    tables = await list_d1_tables(db_name, silent=silent)
 67    if table_name not in tables:
 68        sql = f'CREATE TABLE IF NOT EXISTS "{table_name}" ({columns});'
 69        resp = await query_d1(sql, database_id, silent=silent)
 70        if resp.get("success"):
 71            logger.success(f"Create Table {table_name} in D1 database {db_name}")
 72
 73    # create indexs if idx_cols is not None
 74    if idx_cols is not None:
 75        resp = await query_d1("SELECT name FROM sqlite_master WHERE type='index';", db_id=database_id, silent=silent)
 76        indexs = glom(resp, "result.0.results.*.name", default=[])
 77        for idx_name in idx_cols:
 78            if idx_name not in columns:
 79                logger.warning(f"Index {idx_name} not in columns {columns}")
 80                continue
 81            if f"{idx_prefix}{idx_name}" not in indexs:
 82                resp = await query_d1(f'CREATE INDEX IF NOT EXISTS "{idx_prefix}{idx_name}" ON "{table_name}"({idx_name})', db_id=database_id, silent=silent)
 83                if resp.get("success"):
 84                    logger.success(f'Create Index "{idx_prefix}{idx_name}" of table "{table_name}" in D1 database "{db_name}"')
 85
 86    if fts_on_col is not None:
 87        # 列出所有虚拟表
 88        resp = await query_d1('SELECT name FROM pragma_table_list WHERE type="virtual";', db_id=database_id, silent=silent)
 89        virtual_tables = flatten(glom(resp, "result.*.results.*.name", default=[]))
 90        """创建 FTS5 虚拟表
 91        -- content=table_name 指明关联的原表
 92        -- content_rowid=fts_on_col 指明原表的行 ID 列是 fts_on_col
 93        -- fts_index_col 是我们要索引的列
 94        -- tokenize='unicode61' 使用 unicode61 分词器, 对多种语言支持更好
 95        """
 96        fts_table = f"fts_{table_name}" if fts_name is None else f"fts_{fts_name}"
 97
 98        if fts_table not in virtual_tables:
 99            logger.debug(f"Creating FTS5 virtual table for {table_name}")
100            sql = f"CREATE VIRTUAL TABLE IF NOT EXISTS '{fts_table}' USING fts5({fts_index_col}, content='{table_name}', content_rowid={fts_on_col}, tokenize='unicode61');"
101            await query_d1(sql, db_id=database_id, silent=silent)
102
103            """将现有数据从原表复制到 FTS 表
104            注意, 我们在这里插入的是 rowid (它会对应到 content_rowid=fts_on_col 指定的列) 和 content
105            从原表中选择 fts_on_col 和 segmented 列。fts_on_col 列的值会被插入到 FTS 表中对应原表 rowid (或 content_rowid) 的位置。
106            """
107            sql = f"INSERT INTO '{fts_table}' (rowid, {fts_index_col}) SELECT {fts_on_col}, {fts_index_col} FROM '{table_name}' WHERE {fts_on_col} NOT IN (SELECT rowid FROM '{fts_table}');"
108            await query_d1(sql, db_id=database_id, silent=silent)
109
110        # 列出所有触发器
111        resp = await query_d1('SELECT name FROM sqlite_master WHERE type="trigger";', db_id=database_id, silent=silent)
112        triggers = flatten(glom(resp, "result.*.results.*.name", default=[]))
113        """维护 FTS 表
114        为了让 FTS 表与原表保持同步, 需要在原表上创建触发器。
115        在原表插入、删除、更新时, 同步更新 FTS 表
116        """
117        trigger_prefix = f"trigger_{table_name}" if fts_name is None else f"trigger_{fts_name}"
118        # 创建触发器, 在原表插入数据时, 同步从 FTS 表插入
119        if f"{trigger_prefix}_ai" not in triggers:
120            sql = f"CREATE TRIGGER IF NOT EXISTS '{trigger_prefix}_ai' AFTER INSERT ON '{table_name}' BEGIN INSERT INTO '{fts_table}' (rowid, {fts_index_col}) VALUES (NEW.{fts_on_col}, NEW.{fts_index_col}); END;"
121            await query_d1(sql, db_id=database_id, silent=silent)
122
123        # 创建触发器, 在原表删除数据时, 同步从 FTS 表删除
124        if f"{trigger_prefix}_ad" not in triggers:
125            sql = f"CREATE TRIGGER IF NOT EXISTS '{trigger_prefix}_ad' AFTER DELETE ON '{table_name}' BEGIN DELETE FROM '{fts_table}' WHERE rowid = OLD.{fts_on_col}; END;"
126            await query_d1(sql, db_id=database_id, silent=silent)
127
128        # 创建触发器, 在原表更新数据时, 同步更新 FTS 表
129        # FTS5 的更新通常是先删除旧的, 再插入新的
130        if f"{trigger_prefix}_au" not in triggers:
131            sql = f"CREATE TRIGGER IF NOT EXISTS '{trigger_prefix}_au' AFTER UPDATE ON '{table_name}' BEGIN DELETE FROM '{fts_table}' WHERE rowid = OLD.{fts_on_col} AND OLD.{fts_index_col} <> NEW.{fts_index_col}; INSERT INTO '{fts_table}' (rowid, {fts_index_col}) SELECT NEW.{fts_on_col}, NEW.{fts_index_col} WHERE OLD.{fts_index_col} <> NEW.{fts_index_col}; END;"
132            await query_d1(sql, db_id=database_id, silent=silent)
133
134
@cache.memoize(ttl=600)
async def list_d1_tables(db_name: str = "bennybot", *, silent: bool = False) -> list[str]:
    """Return the table names found in the D1 database `db_name`.

    An empty list is returned when no database ID can be obtained.
    """
    db_id = await create_d1_database(db_name, silent=silent)
    if db_id:
        listing = await query_d1("SELECT name FROM sqlite_master WHERE type='table';", db_id, silent=silent)
        # glom yields one list of names per result entry; merge them into one flat list
        names_per_result = glom(listing, "result.*.results.*.name", default=[])
        return flatten(names_per_result)
    return []
144
145
async def query_d1(
    sql: str,
    db_id: str | None = None,
    db_name: str = "bennybot",
    params: list[str] | None = None,
    account_id: str = DB.CF_ACCOUNT_ID,
    api_token: str = DB.CF_API_TOKEN,
    *,
    enabled: bool = DB.CF_D1_ENABLED,
    silent: bool = False,
) -> dict:
    """Run one SQL statement against a D1 database through the Cloudflare API.

    Args:
        sql: SQL text to execute.
        db_id: Target database ID; when empty it is resolved from `db_name`.
        db_name: Database name used to resolve `db_id`.
        params: Optional positional parameters bound to `?` placeholders.
        account_id: Cloudflare account ID used in the API URL.
        api_token: Cloudflare API token, sent as a Bearer token.
        enabled: When falsy, no request is made.
        silent: Suppresses the trace log and is forwarded to `hx_req`.

    Returns:
        The value returned by `hx_req`, or {} when D1 is disabled, credentials
        are missing, or no database ID is available.
    """
    if not all([enabled, account_id, api_token]):
        return {}
    if not db_id:
        # Resolve the database with the same credentials the query itself uses.
        db_id = await create_d1_database(db_name, account_id=account_id, api_token=api_token, enabled=enabled, silent=silent)
    if not db_id:
        # Without an ID the URL would degrade to ".../database//query".
        return {}
    api = f"https://api.cloudflare.com/client/v4/accounts/{account_id}/d1/database/{db_id}/query"
    headers = {"authorization": f"Bearer {api_token}", "content-type": "application/json"}
    payload = {"sql": sql}
    if params is not None:
        payload |= {"params": params}
    if not silent:
        logger.trace(f"Query CF-D1: {payload}")
    return await hx_req(api, "POST", json_data=payload, headers=headers, check_kv={"success": True}, proxy=PROXY.D1, max_retry=0, silent=silent)
170
171
def insert_d1(table_name: str, records: dict, update_on_conflict: str = "") -> dict:
    """Build a parameterized D1 INSERT (optionally an upsert) statement.

    Column names come from the keys of `records`; the values are bound through
    `?` placeholders and never interpolated into the SQL text.

    Args:
        table_name: Target table.
        records: Mapping of column name to value for the new row.
        update_on_conflict: Conflict column. When set, an `ON CONFLICT` clause
            is appended that overwrites every other column in `records` with
            the incoming value, or does nothing if there is no other column.

    Returns:
        dict: {"sql": sql, "params": params}
    """
    columns = ", ".join(records)
    placeholders = ", ".join("?" * len(records))
    sql = f"INSERT INTO '{table_name}' ({columns}) VALUES ({placeholders})"
    if update_on_conflict:
        assignments = ", ".join(f"{col} = EXCLUDED.{col}" for col in records if col != update_on_conflict)
        # With no column left to update, "DO UPDATE SET" would be empty and the
        # statement invalid; fall back to DO NOTHING.
        action = f"DO UPDATE SET {assignments}" if assignments else "DO NOTHING"
        sql += f" ON CONFLICT ({update_on_conflict}) {action}"
    return {"sql": f"{sql};", "params": list(records.values())}