main
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3from glom import flatten, glom
4from loguru import logger
5
6from config import DB, PROXY, cache
7from networking import hx_req
8
9
@cache.memoize(ttl=0)
async def create_d1_database(
    name: str = "bennybot",
    primary_location_hint: str = "",
    account_id: str = DB.CF_ACCOUNT_ID,
    api_token: str = DB.CF_API_TOKEN,
    *,
    enabled: bool = DB.CF_D1_ENABLED,
    silent: bool = False,
) -> str:
    """Ensure a D1 database named *name* exists and return its DatabaseID.

    Returns an empty string when D1 is disabled, credentials are missing,
    or both the lookup and the create call yield no UUID.
    """
    # Bail out early when D1 support is off or credentials are incomplete.
    if not (enabled and account_id and api_token):
        return ""
    endpoint = f"https://api.cloudflare.com/client/v4/accounts/{account_id}/d1/database"
    auth_headers = {"authorization": f"Bearer {api_token}", "content-type": "application/json"}
    body = {"name": name}
    # Reuse the database if one with this name already exists.
    listing = await hx_req(endpoint, method="GET", params=body, headers=auth_headers, check_kv={"success": True}, proxy=PROXY.D1, max_retry=0, silent=silent)
    existing_id = glom(listing, "result.0.uuid", default="")
    if existing_id:
        return existing_id
    if primary_location_hint:
        body["primary_location_hint"] = primary_location_hint
    created = await hx_req(endpoint, method="POST", json_data=body, headers=auth_headers, check_kv={"success": True}, proxy=PROXY.D1, max_retry=0, silent=silent)
    return glom(created, "result.uuid", default="")
34
35
@cache.memoize(ttl=0)
async def create_d1_table(
    table_name: str | float,
    columns: str,
    *,
    idx_cols: list[str] | None = None,
    idx_prefix: str = "idx_",
    fts_on_col: str | None = None,
    fts_index_col: str = "segmented",
    fts_name: str | None = None,
    db_name: str = "bennybot",
    silent: bool = False,
) -> None:
    """Create a D1 table, plus optional indexes and an FTS5 search table.

    If `idx_cols` is provided, create indexes for these columns.

    idx_cols should be a list of strings; the created index names are the
    column names prefixed by `idx_prefix`, for example:
        idx_prefix = "idx_"
        idx_cols = ["uid", "time"]
        indexes = ["idx_uid", "idx_time"]

    If `fts_on_col` is provided, create an FTS5 virtual table (used for
    Chinese full-text search) with `fts_on_col` as the content-rowid column
    and `fts_index_col` as the indexed column, and install triggers that
    keep the FTS table in sync with the source table.
    """
    database_id = await create_d1_database(db_name, silent=silent)
    if not database_id:
        return
    tables = await list_d1_tables(db_name, silent=silent)
    if table_name not in tables:
        sql = f'CREATE TABLE IF NOT EXISTS "{table_name}" ({columns});'
        resp = await query_d1(sql, database_id, silent=silent)
        if resp.get("success"):
            logger.success(f"Create Table {table_name} in D1 database {db_name}")

    # Create indexes for the requested columns, skipping ones that already exist.
    if idx_cols is not None:
        resp = await query_d1("SELECT name FROM sqlite_master WHERE type='index';", db_id=database_id, silent=silent)
        indexs = glom(resp, "result.0.results.*.name", default=[])
        for idx_name in idx_cols:
            # NOTE(review): substring check against the raw column-definition
            # string — a column name that is a substring of another column's
            # name would false-positive; confirm acceptable for callers.
            if idx_name not in columns:
                logger.warning(f"Index {idx_name} not in columns {columns}")
                continue
            if f"{idx_prefix}{idx_name}" not in indexs:
                resp = await query_d1(f'CREATE INDEX IF NOT EXISTS "{idx_prefix}{idx_name}" ON "{table_name}"({idx_name})', db_id=database_id, silent=silent)
                if resp.get("success"):
                    logger.success(f'Create Index "{idx_prefix}{idx_name}" of table "{table_name}" in D1 database "{db_name}"')

    if fts_on_col is not None:
        # List all existing virtual tables. Use single quotes for the SQL
        # string literal: SQLite only accepts double-quoted strings as a
        # deprecated identifier-fallback misfeature.
        resp = await query_d1("SELECT name FROM pragma_table_list WHERE type='virtual';", db_id=database_id, silent=silent)
        virtual_tables = flatten(glom(resp, "result.*.results.*.name", default=[]))
        # Create the FTS5 virtual table:
        #   content='{table_name}'      -> the associated source (content) table
        #   content_rowid={fts_on_col}  -> the source table's row-id column
        #   {fts_index_col}             -> the column we index for search
        #   tokenize='unicode61'        -> unicode61 tokenizer, better multi-language support
        fts_table = f"fts_{table_name}" if fts_name is None else f"fts_{fts_name}"

        if fts_table not in virtual_tables:
            logger.debug(f"Creating FTS5 virtual table for {table_name}")
            sql = f"CREATE VIRTUAL TABLE IF NOT EXISTS '{fts_table}' USING fts5({fts_index_col}, content='{table_name}', content_rowid={fts_on_col}, tokenize='unicode61');"
            await query_d1(sql, db_id=database_id, silent=silent)

        # Backfill: copy existing rows from the source table into the FTS table.
        # We insert rowid (which maps to the column named by content_rowid, i.e.
        # fts_on_col) together with the indexed content, selecting fts_on_col and
        # the indexed column from the source table and skipping rows already present.
        sql = f"INSERT INTO '{fts_table}' (rowid, {fts_index_col}) SELECT {fts_on_col}, {fts_index_col} FROM '{table_name}' WHERE {fts_on_col} NOT IN (SELECT rowid FROM '{fts_table}');"
        await query_d1(sql, db_id=database_id, silent=silent)

        # List all existing triggers (single-quoted literal, as above).
        resp = await query_d1("SELECT name FROM sqlite_master WHERE type='trigger';", db_id=database_id, silent=silent)
        triggers = flatten(glom(resp, "result.*.results.*.name", default=[]))
        # Maintain the FTS table: to keep it in sync with the source table we
        # install triggers that mirror inserts, deletes, and updates.
        trigger_prefix = f"trigger_{table_name}" if fts_name is None else f"trigger_{fts_name}"
        # Trigger: after an INSERT on the source table, insert into the FTS table.
        if f"{trigger_prefix}_ai" not in triggers:
            sql = f"CREATE TRIGGER IF NOT EXISTS '{trigger_prefix}_ai' AFTER INSERT ON '{table_name}' BEGIN INSERT INTO '{fts_table}' (rowid, {fts_index_col}) VALUES (NEW.{fts_on_col}, NEW.{fts_index_col}); END;"
            await query_d1(sql, db_id=database_id, silent=silent)

        # Trigger: after a DELETE on the source table, delete from the FTS table.
        if f"{trigger_prefix}_ad" not in triggers:
            sql = f"CREATE TRIGGER IF NOT EXISTS '{trigger_prefix}_ad' AFTER DELETE ON '{table_name}' BEGIN DELETE FROM '{fts_table}' WHERE rowid = OLD.{fts_on_col}; END;"
            await query_d1(sql, db_id=database_id, silent=silent)

        # Trigger: after an UPDATE on the source table, update the FTS table.
        # An FTS5 update is done as delete-old-then-insert-new, and only when
        # the indexed column actually changed.
        if f"{trigger_prefix}_au" not in triggers:
            sql = f"CREATE TRIGGER IF NOT EXISTS '{trigger_prefix}_au' AFTER UPDATE ON '{table_name}' BEGIN DELETE FROM '{fts_table}' WHERE rowid = OLD.{fts_on_col} AND OLD.{fts_index_col} <> NEW.{fts_index_col}; INSERT INTO '{fts_table}' (rowid, {fts_index_col}) SELECT NEW.{fts_on_col}, NEW.{fts_index_col} WHERE OLD.{fts_index_col} <> NEW.{fts_index_col}; END;"
            await query_d1(sql, db_id=database_id, silent=silent)
133
134
@cache.memoize(ttl=600)
async def list_d1_tables(db_name: str = "bennybot", *, silent: bool = False) -> list[str]:
    """Return the names of all tables in the given D1 database (memoized, ttl=600)."""
    db_id = await create_d1_database(db_name, silent=silent)
    if not db_id:
        # No usable database (disabled, bad credentials, or creation failed).
        return []
    resp = await query_d1("SELECT name FROM sqlite_master WHERE type='table';", db_id, silent=silent)
    return flatten(glom(resp, "result.*.results.*.name", default=[]))
144
145
async def query_d1(
    sql: str,
    db_id: str | None = None,
    db_name: str = "bennybot",
    params: list[str] | None = None,
    account_id: str = DB.CF_ACCOUNT_ID,
    api_token: str = DB.CF_API_TOKEN,
    *,
    enabled: bool = DB.CF_D1_ENABLED,
    silent: bool = False,
) -> dict:
    """Run *sql* against a Cloudflare D1 database and return the raw API response.

    Returns an empty dict when D1 is disabled or credentials are missing.
    """
    if not (enabled and account_id and api_token):
        return {}
    if db_id is None:
        # Resolve (or lazily create) the database by name.
        db_id = await create_d1_database(db_name, silent=silent)
    endpoint = f"https://api.cloudflare.com/client/v4/accounts/{account_id}/d1/database/{db_id}/query"
    auth_headers = {"authorization": f"Bearer {api_token}", "content-type": "application/json"}
    body = {"sql": sql}
    if params is not None:
        body["params"] = params
    if not silent:
        logger.trace(f"Query CF-D1: {body}")
    return await hx_req(endpoint, "POST", json_data=body, headers=auth_headers, check_kv={"success": True}, proxy=PROXY.D1, max_retry=0, silent=silent)
170
171
def insert_d1(table_name: str, records: dict, update_on_conflict: str = "") -> dict:
    """Build a parameterized D1 INSERT statement for one record.

    Args:
        table_name: Target table name. Interpolated into the SQL verbatim —
            must not come from untrusted input.
        records: Column-name -> value mapping; values become bound parameters.
        update_on_conflict: If non-empty, append an upsert clause — on conflict
            on this column, update every other column from the EXCLUDED row.

    Returns:
        dict: {"sql": sql, "params": params}, ready to pass to `query_d1`.
    """
    keys = ", ".join(records)
    # ", ".join on a string iterates its characters: "???" -> "?, ?, ?".
    placeholders = ", ".join("?" * len(records))
    conflict_clause = ""
    if update_on_conflict:
        assignments = ", ".join(f"{k} = EXCLUDED.{k}" for k in records if k != update_on_conflict)
        conflict_clause = f" ON CONFLICT ({update_on_conflict}) DO UPDATE SET {assignments}"
    # Build the statement once, instead of building a base statement and
    # discarding it when the upsert clause is needed.
    sql = f"INSERT INTO '{table_name}' ({keys}) VALUES ({placeholders}){conflict_clause};"
    return {"sql": sql, "params": list(records.values())}