main
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3from glom import flatten, glom
4from loguru import logger
5
6from config import DB, PROXY, cache
7from networking import hx_req
8
9
@cache.memoize(ttl=0)
async def turso_db_url(
    db_name: str = "bennybot",
    group: str = "default",
    *,
    username: str = DB.TURSO_USERNAME,
    api_token: str = DB.TURSO_API_TOKEN,
    enabled: bool = DB.TURSO_ENABLED,
    silent: bool = False,
) -> str:
    """Return the pipeline URL of a Turso database, creating the database if it is missing.

    The result is memoized via ``cache.memoize(ttl=0)``.

    Args:
        db_name: database name inside the organization.
        group: Turso group; only used when the database has to be created.
        username: organization slug used in the Platform API path.
        api_token: Platform API token, sent as a Bearer token.
        enabled: when falsy, Turso is treated as disabled.
        silent: forwarded to ``hx_req``.

    Returns:
        ``"https://<hostname>/v2/pipeline"``, or ``""`` when Turso is disabled or a
        credential is missing.

    Raises:
        KeyError: if a response lacks the expected hostname / database fields.
    """
    if not all([enabled, username, api_token]):
        return ""
    api = f"https://api.turso.tech/v1/organizations/{username}/databases"
    headers = {"authorization": f"Bearer {api_token}", "content-type": "application/json"}

    # Reuse the database if it already exists.
    resp = await hx_req(api, method="GET", headers=headers, proxy=PROXY.TURSO, check_keys=["databases"], silent=silent)
    for database in glom(resp, "databases", default=None) or []:
        if isinstance(database, dict) and database.get("Name") == db_name:
            # The create endpoint below reports "Hostname"; the lower-case spelling is
            # accepted too so either form of the list response works.
            hostname = database.get("Hostname") or database["hostname"]
            return f"https://{hostname}/v2/pipeline"

    # Not found: create it in the requested group.
    payload = {"name": db_name, "group": group}
    resp = await hx_req(api, method="POST", json_data=payload, headers=headers, check_kv={"database.Name": db_name}, check_keys=["database.Hostname"], proxy=PROXY.TURSO, silent=silent)
    return f"https://{resp['database']['Hostname']}/v2/pipeline"
34
35
def _turso_quote_ident(name: str) -> str:
    """Quote *name* as an SQLite identifier (double quotes, embedded quotes doubled)."""
    return '"' + name.replace('"', '""') + '"'


def _turso_execute(sql: str) -> dict:
    """Wrap one SQL string in a Hrana ``execute`` pipeline request."""
    return {"type": "execute", "stmt": {"sql": sql}}


async def _turso_select_names(sql: str, **conn) -> list:
    """Run a one-column query (e.g. ``SELECT name ...``) and return its values as a flat list."""
    resp = await turso_exec([_turso_execute(sql)], **conn)
    return flatten(glom(resp, "results.0.response.result.rows.*.*.value", default=[]))


def _turso_fts_statements(
    table_name: str,
    fts_table: str,
    trigger_prefix: str,
    fts_on_col: str,
    fts_index_col: str,
    *,
    virtual_tables: list,
    triggers: list,
) -> list[dict]:
    """Build the requests that create an external-content FTS5 table and its sync triggers.

    Objects whose names already appear in *virtual_tables* / *triggers* are skipped, so
    existing triggers (including ones created by older versions of this code) are kept.
    """
    fts = _turso_quote_ident(fts_table)
    src = _turso_quote_ident(table_name)
    # Writing to the hidden column that carries the FTS table's own name issues an FTS5
    # command ('rebuild', 'delete', ...).
    add_new = f"INSERT INTO {fts}(rowid, {fts_index_col}) VALUES (NEW.{fts_on_col}, NEW.{fts_index_col});"
    # For an external-content table FTS5 cannot recover the old text itself once the
    # content row is gone or changed, so removal must use the 'delete' command with the
    # OLD values; a plain DELETE inside these triggers leaves the index out of sync.
    drop_old = f"INSERT INTO {fts}({fts}, rowid, {fts_index_col}) VALUES ('delete', OLD.{fts_on_col}, OLD.{fts_index_col});"

    statements = []
    if fts_table not in virtual_tables:
        logger.debug(f"Creating FTS5 virtual table for {table_name}")
        # content=<table> links the FTS index to the source table; content_rowid maps the
        # FTS rowid onto fts_on_col; unicode61 handles many scripts better than the default.
        statements.append(_turso_execute(f"CREATE VIRTUAL TABLE IF NOT EXISTS {fts} USING fts5({fts_index_col}, content='{table_name}', content_rowid={fts_on_col}, tokenize='unicode61');"))
        # Index the rows already in the source table. A full scan of an external-content
        # FTS table reads the content table, so an "INSERT ... WHERE id NOT IN (SELECT rowid
        # FROM fts)" backfill would insert nothing; 'rebuild' is the documented way.
        statements.append(_turso_execute(f"INSERT INTO {fts}({fts}) VALUES ('rebuild');"))

    insert_trigger = f"{trigger_prefix}_ai"
    if insert_trigger not in triggers:
        statements.append(_turso_execute(f"CREATE TRIGGER IF NOT EXISTS {_turso_quote_ident(insert_trigger)} AFTER INSERT ON {src} BEGIN {add_new} END;"))

    delete_trigger = f"{trigger_prefix}_ad"
    if delete_trigger not in triggers:
        statements.append(_turso_execute(f"CREATE TRIGGER IF NOT EXISTS {_turso_quote_ident(delete_trigger)} AFTER DELETE ON {src} BEGIN {drop_old} END;"))

    update_trigger = f"{trigger_prefix}_au"
    if update_trigger not in triggers:
        # Re-index only when the indexed text or the rowid column changes; IS NOT is NULL-safe.
        changed = f"OLD.{fts_on_col} IS NOT NEW.{fts_on_col} OR OLD.{fts_index_col} IS NOT NEW.{fts_index_col}"
        statements.append(_turso_execute(f"CREATE TRIGGER IF NOT EXISTS {_turso_quote_ident(update_trigger)} AFTER UPDATE ON {src} WHEN {changed} BEGIN {drop_old} {add_new} END;"))
    return statements


@cache.memoize(ttl=0)
async def turso_create_table(
    table_name: str,
    columns: str,
    *,
    idx_cols: list[str] | None = None,
    idx_prefix: str = "idx_",
    fts_on_col: str | None = None,
    fts_index_col: str = "segmented",
    fts_name: str | None = None,
    db_name: str = "bennybot",
    username: str = DB.TURSO_USERNAME,
    api_token: str = DB.TURSO_API_TOKEN,
    group_token: str = DB.TURSO_GROUP_TOKEN,
    silent: bool = False,
) -> None:
    """Create a Turso table, optionally with indexes and an FTS5 full-text index.

    Each step first checks what already exists and skips objects that are present.
    All names and *columns* are interpolated into SQL text, so they must come from
    trusted code.

    Args:
        table_name: table to create if it is missing.
        columns: column-definition SQL inserted verbatim into ``CREATE TABLE``.
        idx_cols: columns to index. Each index is named ``f"{idx_prefix}{column}"``, e.g.
            ``idx_prefix="idx_"`` and ``idx_cols=["uid", "time"]`` give ``idx_uid`` and
            ``idx_time``. NOTE(review): this call is memoized; if the cache backend hashes
            arguments, a tuple may be needed instead of a list — confirm.
        idx_prefix: prefix for index names.
        fts_on_col: if given, create an external-content FTS5 table (for Chinese search)
            whose rowid is this column of *table_name*, plus triggers keeping it in sync.
        fts_index_col: column whose text is full-text indexed.
        fts_name: suffix for the FTS table (``fts_<name>``) and its triggers
            (``trigger_<name>_ai/_ad/_au``); defaults to *table_name*.
        db_name, username, api_token, group_token, silent: forwarded to ``turso_exec``.
    """
    conn = {"db_name": db_name, "username": username, "api_token": api_token, "group_token": group_token, "silent": silent}

    tables = await turso_list_tables(db_name, username=username, api_token=api_token, group_token=group_token, silent=silent)
    if table_name not in tables:
        resp = await turso_exec([_turso_execute(f"CREATE TABLE IF NOT EXISTS {_turso_quote_ident(table_name)} ({columns});")], **conn)
        if glom(resp, "results.0.type", default="") == "ok":
            logger.success(f'Create Table "{table_name}" in Turso database "{db_name}"')

    if idx_cols is not None:
        indexes = await _turso_select_names("SELECT name FROM sqlite_master WHERE type='index';", **conn)
        for idx_col in idx_cols:
            # Plain substring test against the column-definition SQL: it catches obvious
            # typos only, not partial-name matches.
            if idx_col not in columns:
                logger.warning(f"Index {idx_col} not in columns {columns}")
                continue
            idx_name = f"{idx_prefix}{idx_col}"
            if idx_name in indexes:
                continue
            resp = await turso_exec([_turso_execute(f"CREATE INDEX IF NOT EXISTS {_turso_quote_ident(idx_name)} ON {_turso_quote_ident(table_name)}({idx_col});")], **conn)
            if glom(resp, "results.0.type", default="") == "ok":
                logger.success(f'Create Index "{idx_name}" of table "{table_name}" in Turso database "{db_name}"')

    if fts_on_col is not None:
        # Single-quoted literals: in SQLite, double quotes denote identifiers.
        virtual_tables = await _turso_select_names("SELECT name FROM pragma_table_list WHERE type='virtual';", **conn)
        triggers = await _turso_select_names("SELECT name FROM sqlite_master WHERE type='trigger';", **conn)
        suffix = table_name if fts_name is None else fts_name
        statements = _turso_fts_statements(
            table_name,
            f"fts_{suffix}",
            f"trigger_{suffix}",
            fts_on_col,
            fts_index_col,
            virtual_tables=virtual_tables,
            triggers=triggers,
        )
        if statements:
            await turso_exec(statements, **conn)
177
178
async def turso_list_tables(
    db_name: str = "bennybot",
    *,
    username: str = DB.TURSO_USERNAME,
    api_token: str = DB.TURSO_API_TOKEN,
    group_token: str = DB.TURSO_GROUP_TOKEN,
    silent: bool = False,
) -> list[str]:
    """Return the names of every table in a Turso database.

    Reads ``sqlite_master`` rows with ``type='table'``. When ``turso_exec`` yields no
    usable result (for example it returns ``{}`` because a credential is empty), the
    list is empty.
    """
    query = {"type": "execute", "stmt": {"sql": "SELECT name FROM sqlite_master WHERE type='table';"}}
    connection = {
        "db_name": db_name,
        "username": username,
        "api_token": api_token,
        "group_token": group_token,
        "silent": silent,
    }
    response = await turso_exec([query], **connection)
    cell_values = glom(response, "results.0.response.result.rows.*.*.value", default=[])
    return flatten(cell_values)
202
203
async def turso_exec(
    statements: list[dict],
    *,
    db_name: str = "bennybot",
    username: str = DB.TURSO_USERNAME,
    api_token: str = DB.TURSO_API_TOKEN,
    group_token: str = DB.TURSO_GROUP_TOKEN,
    retry: int = 0,
    silent: bool = False,
) -> dict:
    """Send pipeline requests to a Turso database and return the decoded response.

    Args:
        statements: pipeline request dicts, e.g.
            ``{"type": "execute", "stmt": {"sql": "SELECT 1;"}}``. The caller's list and
            dicts are not modified. In the copy that is sent, SQL lacking a trailing ``;``
            gets one, and ``{"type": "close"}`` is appended unless it is already last.
        db_name: database to run against; its URL comes from ``turso_db_url``.
        username: organization slug used to resolve the URL.
        api_token: Platform API token used to resolve the URL.
        group_token: Bearer token sent to the database itself.
        retry: forwarded to ``hx_req`` as ``max_retry``.
        silent: suppress trace/success logs; failures are still logged as errors.

    Returns:
        The response from ``hx_req``, or ``{}`` when *statements* or any credential is
        empty, or the database URL cannot be resolved.
    """
    if not all([statements, db_name, username, api_token, group_token]):
        return {}
    db_url = await turso_db_url(db_name, username=username, api_token=api_token, silent=silent)
    if not db_url:
        return {}
    headers = {"authorization": f"Bearer {group_token}", "content-type": "application/json"}

    def needs_semicolon(sql: object) -> bool:
        return isinstance(sql, str) and bool(sql) and not sql.rstrip().endswith(";")

    requests = []
    for request in statements:
        # Hrana "sequence" requests carry SQL at the top level, "execute" requests under
        # "stmt"; normalise both, copying instead of mutating the caller's dicts.
        if needs_semicolon(request.get("sql")):
            request = {**request, "sql": request["sql"] + ";"}
        stmt = request.get("stmt")
        if isinstance(stmt, dict) and needs_semicolon(stmt.get("sql")):
            request = {**request, "stmt": {**stmt, "sql": stmt["sql"] + ";"}}
        requests.append(request)
    if requests[-1] != {"type": "close"}:
        requests.append({"type": "close"})
    if not silent:
        logger.trace(f"Turso Exec: {requests}")

    num_statements = len(requests)
    resp = await hx_req(db_url, "POST", json_data={"requests": requests}, headers=headers, check_keys=["results"], proxy=PROXY.TURSO, max_retry=int(retry), silent=silent, timeout=600)
    results = [r for r in (glom(resp, "results", default=None) or []) if isinstance(r, dict)]
    num_success = sum(1 for r in results if r.get("type") == "ok")
    if not silent:
        rows = glom(resp, "results.0.response.result.rows", default=[])
        log = f"Found {len(rows)} records in Turso."
        log += f" Rows read: {glom(resp, 'results.0.response.result.rows_read', default=0)}"
        log += f", write: {glom(resp, 'results.0.response.result.rows_written', default=0)}"
        logger.success(log)
    if num_statements != num_success:
        # Collect messages per failed result, so successful results (which have no
        # "error" key) do not hide the errors of the others.
        messages = [str(glom(r, "error.message", default="")) for r in results if r.get("type") != "ok"]
        error = "\n".join(m for m in messages if m)
        logger.error(f"Turso Exec: {num_statements} statements, {num_success} success.\n{error}")
    return resp
244
245
def insert_statement(table_name: str, records: dict, update_on_conflict: str = "") -> dict:
    """Build a Hrana ``execute`` request inserting one row into *table_name*.

    Args:
        table_name: target table (single-quoted in the generated SQL).
        records: column name -> value. Supported values are ``None``, ``bool`` (stored as
            integer 1/0), ``int``, ``float`` and ``str``, including their subclasses.
            Column names are inserted into the SQL verbatim and must come from trusted code.
        update_on_conflict: when non-empty, the conflict-target column of an upsert; every
            other column is overwritten with the new value. If it is the only column,
            ``DO NOTHING`` is emitted, since there is nothing to update.

    Returns:
        ``{"type": "execute", "stmt": {"sql": ..., "args": [...]}}``, where ``args`` is
        omitted when *records* is empty.

    Raises:
        KeyError: for a value of an unsupported type.
    """

    def encode(value) -> dict:
        # Hrana value encoding: integers travel as strings (to survive JSON without losing
        # 64-bit precision), floats as JSON numbers. bool is tested before int because it
        # is an int subclass and would otherwise be sent as "True"/"False".
        if value is None:
            return {"type": "null", "value": None}
        if isinstance(value, bool):
            return {"type": "integer", "value": "1" if value else "0"}
        if isinstance(value, int):
            return {"type": "integer", "value": str(int(value))}
        if isinstance(value, float):
            return {"type": "float", "value": float(value)}
        if isinstance(value, str):
            return {"type": "text", "value": value}
        # Unsupported types (e.g. bytes) keep raising KeyError, as the type-table lookup did.
        raise KeyError(type(value).__name__.lower())

    keys = ", ".join(records)
    placeholders = ", ".join("?" for _ in records)
    args = [encode(value) for value in records.values()]

    sql = f"INSERT INTO '{table_name}' ({keys}) VALUES ({placeholders})"
    if update_on_conflict:
        updates = ", ".join(f"{k} = EXCLUDED.{k}" for k in records if k != update_on_conflict)
        # "DO UPDATE SET" with an empty assignment list is a syntax error.
        action = f"DO UPDATE SET {updates}" if updates else "DO NOTHING"
        sql += f" ON CONFLICT ({update_on_conflict}) {action}"
    stmt = {"sql": sql + ";"}
    if args:
        stmt["args"] = args
    return {"type": "execute", "stmt": stmt}
261
262
def turso_parse_resp(resp: dict) -> list[dict]:
    """Turn the first result of a Turso SELECT response into a list of row dicts.

    Reads ``results[0].response.result.cols`` and ``.rows`` (a missing part gives
    ``[]``) and maps each row to ``{column name: value}``.

    Values in columns declared INT/INTEGER are converted with ``int``, and those declared
    FLOAT/REAL with ``float``. The declared type is compared case-insensitively, since
    SQLite reports it as written in ``CREATE TABLE`` (e.g. ``integer``). Cells without a
    ``value`` key, and empty values, become ``""``.

    Raises:
        ValueError: if a row and the column list differ in length, or a value in an
            integer/float column cannot be converted.
    """
    int_decltypes = frozenset({"INT", "INTEGER"})
    float_decltypes = frozenset({"FLOAT", "REAL"})

    def correct_type(decltype: str | None, value):
        if str(value) == "":
            return ""
        declared = (decltype or "").upper()
        if declared in int_decltypes:
            return int(value)
        if declared in float_decltypes:
            return float(value)
        return value

    cols = glom(resp, "results.0.response.result.cols", default=[])
    rows = glom(resp, "results.0.response.result.rows", default=[])
    return [
        {col["name"]: correct_type(col.get("decltype"), cell.get("value", "")) for cell, col in zip(row, cols, strict=True)}
        for row in rows
    ]
277 return [{col["name"]: correct_type(col["decltype"], x.get("value", "")) for x, col in zip(row, cols, strict=True)} for row in rows]