main
  1#!/usr/bin/env python
  2# -*- coding: utf-8 -*-
  3from glom import flatten, glom
  4from loguru import logger
  5
  6from config import DB, PROXY, cache
  7from networking import hx_req
  8
  9
 10@cache.memoize(ttl=0)
 11async def turso_db_url(
 12    db_name: str = "bennybot",
 13    group: str = "default",
 14    *,
 15    username: str = DB.TURSO_USERNAME,
 16    api_token: str = DB.TURSO_API_TOKEN,
 17    enabled: bool = DB.TURSO_ENABLED,
 18    silent: bool = False,
 19) -> str:
 20    """Get Turso database url."""
 21    if not all([enabled, username, api_token]):
 22        return ""
 23    api = f"https://api.turso.tech/v1/organizations/{username}/databases"
 24    headers = {"authorization": f"Bearer {api_token}", "content-type": "application/json"}
 25    payload = {"name": db_name, "group": group}
 26    # check if database exists
 27    resp = await hx_req(api, method="GET", headers=headers, proxy=PROXY.TURSO, check_keys=["databases"], silent=silent)
 28    if db_name in glom(resp, "databases.*.Name", default=[]):
 29        return "https://" + next(x["hostname"] for x in resp["databases"] if x["Name"] == db_name) + "/v2/pipeline"
 30
 31    # not exists, create it
 32    resp = await hx_req(api, method="POST", json_data=payload, headers=headers, check_kv={"database.Name": db_name}, check_keys=["database.Hostname"], proxy=PROXY.TURSO, silent=silent)
 33    return "https://" + resp["database"]["Hostname"] + "/v2/pipeline"
 34
 35
 36@cache.memoize(ttl=0)
 37async def turso_create_table(
 38    table_name: str,
 39    columns: str,
 40    *,
 41    idx_cols: list[str] | None = None,
 42    idx_prefix: str = "idx_",
 43    fts_on_col: str | None = None,
 44    fts_index_col: str = "segmented",
 45    fts_name: str | None = None,
 46    db_name: str = "bennybot",
 47    username: str = DB.TURSO_USERNAME,
 48    api_token: str = DB.TURSO_API_TOKEN,
 49    group_token: str = DB.TURSO_GROUP_TOKEN,
 50    silent: bool = False,
 51) -> None:
 52    """Create a turso table.
 53
 54    If `idx_cols` is provided, create indexs for these columns.
 55
 56    idx_cols should be a list of strings, the created index names prefixed by `idx_prefix`
 57    for example:
 58        idx_prefix = "idx_"
 59        idx_cols = ["uid", "time"]
 60        indexs = ["idx_uid", "idx_time"]
 61
 62    # create FTS table for Chinese search
 63    If `fts_on_col` is provided, create a FTS5 table with `fts_on_col` as the on column.
 64    the `fts_index_col` is the column used for FTS5 indexing.
 65    """
 66    tables = await turso_list_tables(db_name, username=username, api_token=api_token, group_token=group_token, silent=silent)
 67    if table_name not in tables:
 68        resp = await turso_exec(
 69            [{"type": "execute", "stmt": {"sql": f'CREATE TABLE IF NOT EXISTS "{table_name}" ({columns});'}}],
 70            db_name=db_name,
 71            username=username,
 72            api_token=api_token,
 73            group_token=group_token,
 74            silent=silent,
 75        )
 76        if glom(resp, "results.0.type", default="") == "ok":
 77            logger.success(f'Create Table "{table_name}" in Turso database "{db_name}"')
 78
 79    # create indexs if idx_cols is not None
 80    if idx_cols is not None:
 81        resp = await turso_exec(
 82            [{"type": "execute", "stmt": {"sql": "SELECT name FROM sqlite_master WHERE type='index';"}}],
 83            db_name=db_name,
 84            username=username,
 85            api_token=api_token,
 86            group_token=group_token,
 87            silent=silent,
 88        )
 89        indexs = flatten(glom(resp, "results.0.response.result.rows.*.*.value", default=[]))
 90        for idx_name in idx_cols:
 91            if idx_name not in columns:
 92                logger.warning(f"Index {idx_name} not in columns {columns}")
 93                continue
 94            if f"{idx_prefix}{idx_name}" not in indexs:
 95                resp = await turso_exec(
 96                    [{"type": "execute", "stmt": {"sql": f'CREATE INDEX IF NOT EXISTS "{idx_prefix}{idx_name}" ON "{table_name}"({idx_name})'}}],
 97                    db_name=db_name,
 98                    username=username,
 99                    api_token=api_token,
100                    group_token=group_token,
101                    silent=silent,
102                )
103                if glom(resp, "results.0.type", default="") == "ok":
104                    logger.success(f'Create Index "{idx_prefix}{idx_name}" of table "{table_name}" in Turso database "{db_name}"')
105
106    if fts_on_col is not None:
107        # 列出所有虚拟表
108        resp = await turso_exec(
109            [{"type": "execute", "stmt": {"sql": 'SELECT name FROM pragma_table_list WHERE type="virtual";'}}],
110            db_name=db_name,
111            username=username,
112            api_token=api_token,
113            group_token=group_token,
114            silent=silent,
115        )
116        virtual_tables = flatten(glom(resp, "results.0.response.result.rows.*.*.value", default=[]))
117
118        """创建 FTS5 虚拟表
119        -- content=table_name 指明关联的原表
120        -- content_rowid=fts_on_col 指明原表的行 ID 列是 fts_on_col
121        -- fts_index_col 是我们要索引的列
122        -- tokenize='unicode61' 使用 unicode61 分词器, 对多种语言支持更好
123        """
124        fts_table = f"fts_{table_name}" if fts_name is None else f"fts_{fts_name}"
125
126        statements = []
127        if fts_table not in virtual_tables:
128            logger.debug(f"Creating FTS5 virtual table for {table_name}")
129            sql = f"CREATE VIRTUAL TABLE IF NOT EXISTS '{fts_table}' USING fts5({fts_index_col}, content='{table_name}', content_rowid={fts_on_col}, tokenize='unicode61');"
130            statements.append({"type": "execute", "stmt": {"sql": sql}})
131
132            """将现有数据从原表复制到 FTS 表
133            注意, 我们在这里插入的是 rowid (它会对应到 content_rowid=fts_on_col 指定的列) 和 content
134            从原表中选择 fts_on_col 和 segmented 列。fts_on_col 列的值会被插入到 FTS 表中对应原表 rowid (或 content_rowid) 的位置。
135            """
136            sql = f"INSERT INTO '{fts_table}' (rowid, {fts_index_col}) SELECT {fts_on_col}, {fts_index_col} FROM '{table_name}' WHERE {fts_on_col} NOT IN (SELECT rowid FROM '{fts_table}');"
137            statements.append({"type": "execute", "stmt": {"sql": sql}})
138
139        # 列出所有触发器
140        resp = await turso_exec(
141            [{"type": "execute", "stmt": {"sql": 'SELECT name FROM sqlite_master WHERE type="trigger";'}}],
142            db_name=db_name,
143            username=username,
144            api_token=api_token,
145            group_token=group_token,
146            silent=silent,
147        )
148        triggers = flatten(glom(resp, "results.0.response.result.rows.*.*.value", default=[]))
149        """维护 FTS 表
150        为了让 FTS 表与原表保持同步, 需要在原表上创建触发器。
151        在原表插入、删除、更新时, 同步更新 FTS 表
152        """
153        trigger_prefix = f"trigger_{table_name}" if fts_name is None else f"trigger_{fts_name}"
154        # 创建触发器, 在原表插入数据时, 同步从 FTS 表插入
155        if f"{trigger_prefix}_ai" not in triggers:
156            sql = f"CREATE TRIGGER IF NOT EXISTS '{trigger_prefix}_ai' AFTER INSERT ON '{table_name}' BEGIN INSERT INTO '{fts_table}' (rowid, {fts_index_col}) VALUES (NEW.{fts_on_col}, NEW.{fts_index_col}); END;"
157            statements.append({"type": "execute", "stmt": {"sql": sql}})
158
159        # 创建触发器, 在原表删除数据时, 同步从 FTS 表删除
160        if f"{trigger_prefix}_ad" not in triggers:
161            sql = f"CREATE TRIGGER IF NOT EXISTS '{trigger_prefix}_ad' AFTER DELETE ON '{table_name}' BEGIN DELETE FROM '{fts_table}' WHERE rowid = OLD.{fts_on_col}; END;"
162            statements.append({"type": "execute", "stmt": {"sql": sql}})
163
164        # 创建触发器, 在原表更新数据时, 同步更新 FTS 表
165        # FTS5 的更新通常是先删除旧的, 再插入新的
166        if f"{trigger_prefix}_au" not in triggers:
167            sql = f"CREATE TRIGGER IF NOT EXISTS '{trigger_prefix}_au' AFTER UPDATE ON '{table_name}' BEGIN DELETE FROM '{fts_table}' WHERE rowid = OLD.{fts_on_col} AND OLD.{fts_index_col} <> NEW.{fts_index_col}; INSERT INTO '{fts_table}' (rowid, {fts_index_col}) SELECT NEW.{fts_on_col}, NEW.{fts_index_col} WHERE OLD.{fts_index_col} <> NEW.{fts_index_col}; END;"
168            statements.append({"type": "execute", "stmt": {"sql": sql}})
169        await turso_exec(
170            statements,
171            db_name=db_name,
172            username=username,
173            api_token=api_token,
174            group_token=group_token,
175            silent=silent,
176        )
177
178
async def turso_list_tables(
    db_name: str = "bennybot",
    *,
    username: str = DB.TURSO_USERNAME,
    api_token: str = DB.TURSO_API_TOKEN,
    group_token: str = DB.TURSO_GROUP_TOKEN,
    silent: bool = False,
) -> list[str]:
    """Return the names of the tables in a Turso database.

    Names are read from ``sqlite_master`` rows whose ``type`` is ``'table'``.
    When ``turso_exec`` yields no rows (or nothing at all), the result is ``[]``.
    """
    query = {"type": "execute", "stmt": {"sql": "SELECT name FROM sqlite_master WHERE type='table';"}}
    credentials = {"db_name": db_name, "username": username, "api_token": api_token, "group_token": group_token, "silent": silent}
    resp = await turso_exec([query], **credentials)
    cell_values = glom(resp, "results.0.response.result.rows.*.*.value", default=[])
    return flatten(cell_values)
202
203
async def turso_exec(
    statements: list[dict],
    *,
    db_name: str = "bennybot",
    username: str = DB.TURSO_USERNAME,
    api_token: str = DB.TURSO_API_TOKEN,
    group_token: str = DB.TURSO_GROUP_TOKEN,
    retry: int = 0,
    silent: bool = False,
) -> dict:
    """Execute Hrana pipeline requests against a Turso database.

    Args:
        statements: Pipeline requests, e.g. ``{"type": "execute", "stmt": {"sql": ...}}``.
            Neither the list nor its dicts are modified; the request list actually sent is a
            copy with ``{"type": "close"}`` appended unless it is already last.
        db_name: Database whose URL is resolved via ``turso_db_url``.
        username, api_token: Platform API credentials used to resolve the URL.
        group_token: Bearer token for the pipeline request itself.
        retry: Forwarded to ``hx_req`` as ``max_retry``.
        silent: Skip the trace/success logs here; also forwarded to helpers.

    Returns:
        The response returned by ``hx_req`` (checked for a ``results`` key), or ``{}`` when
        any argument/credential is empty or no database URL could be resolved.
    """
    if not all([statements, db_name, username, api_token, group_token]):
        return {}
    db_url = await turso_db_url(db_name, username=username, api_token=api_token, silent=silent)
    if not db_url:
        return {}
    headers = {"authorization": f"Bearer {group_token}", "content-type": "application/json"}

    # Work on copies so the caller's list/dicts are not changed (the close request used to be
    # appended to the caller's list in place).
    requests = []
    for request in statements:
        # Requests with a top-level "sql" (e.g. "sequence") get a terminating semicolon.
        if (sql := request.get("sql")) and not sql.endswith(";"):
            request = {**request, "sql": sql + ";"}
        requests.append(request)
    if requests[-1] != {"type": "close"}:
        requests.append({"type": "close"})
    if not silent:
        logger.trace(f"Turso Exec: {requests}")

    num_statements = len(requests)
    resp = await hx_req(db_url, "POST", json_data={"requests": requests}, headers=headers, check_keys=["results"], proxy=PROXY.TURSO, max_retry=int(retry), silent=silent, timeout=600)
    results = [r for r in (glom(resp, "results", default=[]) or []) if isinstance(r, dict)]
    num_success = sum(1 for r in results if r.get("type") == "ok")
    if not silent:
        rows = glom(resp, "results.0.response.result.rows", default=[])
        log = f"Found {len(rows)} records in Turso."
        log += f" Rows read: {glom(resp, 'results.0.response.result.rows_read', default=0)}"
        log += f", write: {glom(resp, 'results.0.response.result.rows_written', default=0)}"
        logger.success(log)
    if num_statements != num_success:
        # Collect messages per result: "ok" results carry no "error", so a single star path
        # over all results could fail as a whole and hide every message.
        messages = [glom(r, "error.message", default="") for r in results]
        error = "\n".join(str(m) for m in messages if m)
        logger.error(f"Turso Exec: {num_statements} statements, {num_success} success.\n{error}")
    return resp
244
245
def insert_statement(table_name: str, records: dict, update_on_conflict: str = "") -> dict:
    """Create a Hrana ``execute`` request that inserts one row into ``table_name``.

    Args:
        table_name: Target table; interpolated into the SQL, so it must be trusted.
        records: Column name -> value. Values are bound as ``?`` parameters; supported
            types are ``str``, ``int``, ``bool`` (stored as 0/1), ``float`` and ``None``.
        update_on_conflict: Optional conflict-target column. When set, the statement is an
            upsert that overwrites every other given column; if there is no other column,
            the conflicting insert is skipped (``DO NOTHING``).

    Returns:
        ``{"type": "execute", "stmt": {"sql": ..., "args": [...]}}``; ``args`` is omitted
        when ``records`` is empty.

    Raises:
        KeyError: if a value has an unsupported type.
    """
    SQL_TYPES = {"str": "text", "int": "integer", "float": "float", "nonetype": "null"}

    def to_hrana(value) -> dict:
        """Encode one Python value as a Hrana ``Value`` object."""
        if isinstance(value, bool):
            # bool is an int subclass; SQLite has no boolean type, so store 0/1.
            return {"type": "integer", "value": str(int(value))}
        sql_type = SQL_TYPES[type(value).__name__.lower()]  # KeyError for unsupported types
        if sql_type == "integer":
            # Hrana carries integers as strings so 64-bit values survive JSON.
            return {"type": sql_type, "value": str(value)}
        if sql_type == "null":
            return {"type": sql_type}
        # Hrana floats are JSON numbers; text values are JSON strings.
        return {"type": sql_type, "value": value}

    columns = list(records)
    keys = ", ".join(columns)
    placeholders = ", ".join("?" * len(columns))
    args = [to_hrana(value) for value in records.values()]

    sql = f"INSERT INTO '{table_name}' ({keys}) VALUES ({placeholders})"
    if update_on_conflict:
        updates = ", ".join(f"{k} = EXCLUDED.{k}" for k in columns if k != update_on_conflict)
        # With nothing left to overwrite, "DO UPDATE SET" would be a syntax error.
        action = f"DO UPDATE SET {updates}" if updates else "DO NOTHING"
        sql += f" ON CONFLICT ({update_on_conflict}) {action}"
    stmt = {"sql": sql + ";"}
    if args:
        stmt |= {"args": args}
    return {"type": "execute", "stmt": stmt}
261
262
def turso_parse_resp(resp: dict) -> list[dict]:
    """Parse the first result of a Turso SELECT response into a list of row dicts.

    Each row becomes ``{column_name: value}``. Values of columns declared INT/INTEGER
    (any case) are converted with ``int``, FLOAT/REAL with ``float``. Columns without a
    declared type (e.g. ``COUNT(*)``) are converted according to the Hrana cell type.
    Cells without a ``value`` (NULL, blob) become ``""``.

    Raises:
        ValueError: if a row and the column list differ in length (``zip(strict=True)``),
            or a value cannot be converted to its declared numeric type.
    """

    def correct_type(decltype: str | None, cell: dict):
        value = cell.get("value", "")
        if str(value) == "":
            return ""
        declared = (decltype or "").upper()  # declared types keep the user's casing
        if declared in ["INT", "INTEGER"]:
            return int(value)
        if declared in ["FLOAT", "REAL"]:
            return float(value)
        if not declared:
            # Expression columns have no decltype; fall back to the Hrana value type.
            if cell.get("type") == "integer":
                return int(value)
            if cell.get("type") == "float":
                return float(value)
        return value

    cols = glom(resp, "results.0.response.result.cols", default=[])
    rows = glom(resp, "results.0.response.result.rows", default=[])
    return [{col["name"]: correct_type(col.get("decltype"), cell) for cell, col in zip(row, cols, strict=True)} for row in rows]
277    return [{col["name"]: correct_type(col["decltype"], x.get("value", "")) for x, col in zip(row, cols, strict=True)} for row in rows]