main
  1#!/usr/bin/env python
  2# -*- coding: utf-8 -*-
  3import contextlib
  4import os
  5import re
  6import string
  7
  8from glom import glom
  9from loguru import logger
 10from pyrogram.client import Client
 11from pyrogram.errors import PeerIdInvalid
 12from pyrogram.types import Chat, Message, User
 13
 14from config import DB, HISTORY, TID, cache, cutter
 15from database.d1 import query_d1
 16from database.turso import turso_exec, turso_parse_resp
 17from messages.sender import send2tg
 18from others.emoji import CTYPE_EMOJI
 19from utils import find_url, myself, slim_cid, strings_list, to_int, true
 20
 21TURSO_KWARGS: dict = {
 22    "db_name": HISTORY.TURSO_DATABASE,
 23    "username": HISTORY.TURSO_USERNAME or DB.TURSO_USERNAME,
 24    "api_token": HISTORY.TURSO_API_TOKEN or DB.TURSO_API_TOKEN,
 25    "group_token": HISTORY.TURSO_GROUP_TOKEN or DB.TURSO_GROUP_TOKEN,
 26}
 27
 28CHAT_COLUMNS = "cid INTEGER PRIMARY KEY, ctype TEXT, ctitle TEXT, chandle TEXT, tablename TEXT, tags TEXT"
 29USER_COLUMNS = "ctitle TEXT, full_name TEXT, handle TEXT, tags TEXT, name TEXT, uid INTEGER, cid INTEGER, id INTEGER PRIMARY KEY"
 30USER_INDEXES = ["uid", "cid"]
 31
 32MSG_COLUMNS = "mid INTEGER PRIMARY KEY, mtype TEXT, time TEXT NOT NULL, fullname TEXT, content TEXT, filename TEXT, urls TEXT, reply INTEGER, mime TEXT, user TEXT, handle TEXT, uid INTEGER, gid INTEGER, segmented TEXT"  # fmt: off
 33MSG_INDEXES = ["time", "user", "uid", "handle"]
 34
 35
 36@cache.memoize(ttl=0)
 37def check_save_history(ctype: str, cid: int | str) -> bool:
 38    # ruff: noqa: SIM103
 39    cid = slim_cid(cid)
 40    if true(os.getenv(f"HISTORY_IGNORE_{cid}")):
 41        return False
 42    if true(os.getenv(f"HISTORY_INCLUDE_{cid}")):
 43        return True
 44    if cid in strings_list(HISTORY.IGNORE_CHATS):
 45        return False
 46    if cid in strings_list(HISTORY.INCLUDE_CHATS):
 47        return True
 48    if ctype == "PRIVATE":
 49        if str(HISTORY.INCLUDE_PRIVATES).lower() == "all" or cid in strings_list(HISTORY.INCLUDE_PRIVATES):
 50            return True
 51        return False
 52    if ctype == "BOT":
 53        if str(HISTORY.INCLUDE_BOTS).lower() == "all" or cid in strings_list(HISTORY.INCLUDE_BOTS):
 54            return True
 55        return False
 56    if ctype in ["GROUP", "SUPERGROUP"]:
 57        if str(HISTORY.INCLUDE_GROUPS).lower() == "all" or cid in strings_list(HISTORY.INCLUDE_GROUPS):
 58            return True
 59        return False
 60    if ctype == "CHANNEL":
 61        if str(HISTORY.INCLUDE_CHANNELS).lower() == "all" or cid in strings_list(HISTORY.INCLUDE_CHANNELS):
 62            return True
 63        return False
 64    return False
 65
 66
 67@cache.memoize(ttl=0)
 68def can_delete_history(cid: int | str, uid: int | str) -> bool:
 69    # ruff: noqa: SIM103
 70    cid = slim_cid(cid)
 71    if true(os.getenv(f"HISTORY_CAN_DEL_C{cid}")):
 72        return True
 73    if true(os.getenv(f"HISTORY_CAN_DEL_U{uid}")):
 74        return True
 75    if true(os.getenv(f"HISTORY_CAN_DEL_C{cid}_U{uid}")):
 76        return True
 77    return False
 78
 79
 80def fine_grained_check(info: dict) -> bool:
 81    """由于有些对话不需要保存所有类型的聊天历史, 这里检查是否需要跳过.
 82
 83    这种细粒度的检查, 仅支持通过环境变量设置.
 84    目前支持:
 85        HISTORY_{cid}_MUST_MTYPE: 必须为指定的消息类型, 可以为多个类型, 用逗号分隔
 86        HISTORY_{cid}_MUST_HAVE_TEXT:  必须有文字的消息
 87        HISTORY_{cid}_SKIP_URL: 跳过包含链接的消息
 88        HISTORY_{cid}_SKIP_KEYWORDS: 跳过包含关键词的消息 (其中关键词为逗号分隔的字符串)
 89    例如: 对于`chat_id = 1234` 的对话, 不需要保存没有文字的消息
 90    """
 91    # ruff: noqa: SIM103
 92    cid = slim_cid(info["cid"])
 93    if (mtype := os.getenv(f"HISTORY_{cid}_MUST_MTYPE")) and info["mtype"].lower() not in mtype.lower():
 94        return False
 95    if true(os.getenv(f"HISTORY_{cid}_MUST_HAVE_TEXT")) and not info["text"]:
 96        return False
 97    if true(os.getenv(f"HISTORY_{cid}_MUST_HAVE_URL")) and not (find_url(info["text"]) or info.get("entity_urls")):
 98        return False
 99    if true(os.getenv(f"HISTORY_{cid}_SKIP_URL")) and (find_url(info["text"]) or info.get("entity_urls")):
100        return False
101    if any(x in info["text"] for x in strings_list(os.getenv(f"HISTORY_{cid}_SKIP_KEYWORDS"))):
102        return False
103    return True
104
105
async def get_chat(client: Client, chat_id: int | str) -> Chat:
    """Fetch chat info for *chat_id*, caching the result for one hour.

    If the plain id is rejected with ``PeerIdInvalid``, the ``-100``-prefixed
    form of the (slimmed) id is tried as well. When every lookup fails — or
    *chat_id* is ``0`` — a placeholder ``Chat(id=0)`` is returned. The
    placeholder from a failed lookup is cached too, so that chat is not
    re-queried until the one-hour TTL expires.
    """
    key = f"chat-info-{slim_cid(chat_id)}"
    # Look the entry up once: two separate `cache.get` calls could see it
    # expire in between and return None.
    if cached := cache.get(key):
        return cached
    chat = Chat(id=0)  # placeholder returned when the lookup fails
    if str(chat_id) == "0":
        return chat
    try:
        chat = await client.get_chat(to_int(chat_id))
    except PeerIdInvalid:
        try:
            chat = await client.get_chat(to_int(f"-100{slim_cid(chat_id)}"))
        except Exception as e:  # best effort: keep the placeholder, but leave a trace
            logger.warning(f"Failed to get chat info for {chat_id} (also tried -100 prefix): {e}")
    except Exception:
        logger.warning(f"Failed to get chat info for {chat_id}")
    cache.set(key, chat, ttl=3600)  # cache for 1 hour
    return chat
121
122
async def list_chat_ids(client: Client, message: Message, engine: str = "turso"):
    """Reply to *message* with the chats recorded in the ``chatinfo`` table.

    One history database may be read by several Telegram accounts, so each
    row's ``tags`` can restrict where it is listed:

        {my_uid}_SKIP_LIST     -> never listed by the account whose user id is my_uid
        SKIP_LIST_IN_{chatid}  -> not listed when the command runs in chat chatid
        ONLY_LIST_IN_{chatid}  -> listed only in chat chatid (several may be given)

    Args:
        client: the running pyrogram client (its own user id is used for
            ``{my_uid}_SKIP_LIST``).
        message: the command message; the reply goes to its chat.
        engine: ``"turso"`` (case-insensitive) queries Turso; any other value
            queries the D1 database via ``query_d1``.
    """
    if engine.lower() == "turso":
        resp = await turso_exec([{"type": "execute", "stmt": {"sql": "SELECT * FROM chatinfo;"}}], silent=True, retry=2, **TURSO_KWARGS)
        chats = turso_parse_resp(resp)
    else:
        resp = await query_d1("SELECT * FROM chatinfo;", db_name=HISTORY.D1_DATABASE, silent=True)
        chats = glom(resp, "result.0.results", default=[])

    me = await myself(client)
    cid = slim_cid(message.chat.id)
    lines: list[str] = []
    # NULL ctype values sort first instead of crashing the comparison.
    for row in sorted(chats, key=lambda r: r.get("ctype") or ""):
        raw_tags = row.get("tags") or ""  # the tags column may be NULL
        tags = strings_list(raw_tags)
        if "ONLY_LIST_IN_" in raw_tags and f"ONLY_LIST_IN_{cid}" not in tags:
            continue
        if "SKIP_LIST_IN_" in raw_tags and f"SKIP_LIST_IN_{cid}" in tags:
            continue
        if f"{me.id}_SKIP_LIST" in tags:
            continue
        emoji = CTYPE_EMOJI.get(row.get("ctype"), "")  # unknown types get no emoji instead of KeyError
        lines.append(f"`/history #{row['cid']}` {emoji}: {row['ctitle']}\n")
    await send2tg(client, message, texts="".join(lines))
153
154
def is_admin(uid: int) -> bool:
    """Return True when *uid* appears in ``TID.HISTORY_ADMIN``.

    Both sides are normalized with ``slim_cid`` before comparing.
    """
    for admin in strings_list(TID.HISTORY_ADMIN):
        if slim_cid(admin) == slim_cid(uid):
            return True
    return False
157
158
@cache.memoize(ttl=10)
async def get_user_from_chat(client: Client, uid: int | str, cid: int | str) -> User:
    """Resolve *uid* (numeric user id or username) to a member of chat *cid*.

    First tries ``get_chat_member`` directly. If that fails, it scans the chat's
    member list and matches either the numeric id or the username. Telegram
    usernames are case-insensitive, so the username comparison ignores case.

    Returns:
        The matching ``User``, or a placeholder ``User(id=0)`` when *uid* is
        empty, contains characters other than ASCII letters, digits and ``_``,
        or no member matches.
    """
    user = User(id=0)
    handle = str(uid)
    allowed_chars = f"{string.ascii_letters}_{string.digits}"
    # Reject empty/malformed input early. With the case-insensitive username
    # comparison below, an empty handle would match members without a username.
    if not handle or any(char not in allowed_chars for char in handle):
        return user
    target = to_int(uid)  # hoisted: loop-invariant
    try:  # get chat member directly
        chat_member = await client.get_chat_member(to_int(cid), target)
        user = chat_member.user
    except Exception:
        lowered = handle.lower()
        with contextlib.suppress(Exception):  # best effort: scan chat members
            async for member in client.get_chat_members(to_int(cid)):  # type: ignore
                if member.user.id == target or (member.user.username or "").lower() == lowered:
                    user = member.user
                    break
    return user
174
175
def keyword_query(keyword: str) -> str:
    """Build an SQLite FTS5 ``MATCH`` clause for a user-supplied search *keyword*.

    Steps:
      * ASCII punctuation (except parentheses and quotes) and common
        full-width/CJK punctuation become spaces.
      * Full-width parentheses are normalized to ASCII so they act as FTS5
        grouping.
      * Runs of spaces are collapsed and the text is segmented with ``cutter``.
      * Each segment becomes a quoted FTS5 phrase, while ``OR``/``AND``/``NOT``/
        ``(``/``)`` pass through as operators. ``AND`` is inserted between
        adjacent terms that have no explicit operator.

    *keyword* comes from chat input and the result is spliced directly into
    SQL, so quotes are escaped: ``"`` is doubled inside FTS5 phrases and ``'``
    is doubled for the enclosing SQL string literal.

    An empty (or all-punctuation) keyword yields ``FTS.segmented MATCH ''``.
    """
    operators = ["OR", "AND", "NOT", "(", ")"]
    # The second group is written as escapes so it survives encoding mangling:
    # full-width , ideographic full stop, full-width ? ! : ;, curly double and
    # single quotes, and CJK double angle brackets.
    punctuation = "!#$&*+,-./:;<=>?@[\\]^_`{|}~" + "\uff0c\u3002\uff1f\uff01\uff1a\uff1b\u201c\u201d\u2018\u2019\u300a\u300b"
    keyword = keyword.translate(str.maketrans(dict.fromkeys(punctuation, " ")))  # remove punctuation
    # Full-width parentheses -> ASCII, so users can group terms with either.
    keyword = keyword.replace("\uff08", "(").replace("\uff09", ")")
    # Collapse consecutive spaces, then trim leading/trailing whitespace.
    keyword = re.sub(" {2,}", " ", keyword).strip()

    segmented = [x for x in cutter.cutword(keyword) if x not in string.whitespace]
    final = []
    last = len(segmented) - 1
    for i, word in enumerate(segmented):
        if word.strip() in operators:
            final.append(word)
            continue
        escaped = word.replace('"', '""')  # FTS5: a literal " inside a phrase is written as ""
        final.append(f'"{escaped}"')
        if i != last and segmented[i + 1].strip() not in ["OR", "AND", "NOT", ")"]:
            final.append("AND")
    expression = " ".join(final).replace("'", "''")  # SQL: escape ' inside the '...' literal
    return f"FTS.segmented MATCH '{expression}'"
200
201
def filter_response(resp: list[dict], keyword: str) -> list[dict]:
    """Post-filter search results so that *keyword*'s words appear in order.

    Rows whose ``content`` is empty/None are always dropped. The remaining rows
    are returned unchanged if *keyword* is empty, or if it looks like an FTS
    query: a standalone ``OR``/``AND``/``NOT`` token, or any ASCII or
    full-width parenthesis.

    Otherwise runs of spaces are collapsed, and every space-separated word must
    occur literally (regex metacharacters are escaped) in ``content``, in the
    given order. Matching is case-insensitive, and anything, including
    newlines, may sit between the words.
    """
    filtered = [row for row in resp if row["content"]]  # remove messages without text
    if not keyword:
        return filtered
    # Operators must be whole tokens: a substring test would treat words like
    # "ANDROID" or "NOTE" as queries and skip filtering.
    if any(token in ("OR", "AND", "NOT") for token in keyword.split()) or any(p in keyword for p in "()\uff08\uff09"):
        return filtered
    # Collapse consecutive spaces, then trim leading/trailing whitespace.
    keyword = re.sub(" {2,}", " ", keyword).strip()
    if not keyword:
        return filtered

    pattern = re.compile(
        ".*?".join(re.escape(word) for word in keyword.split(" ")),
        flags=re.IGNORECASE | re.DOTALL,
    )
    return [row for row in filtered if pattern.search(row["content"])]
217    return [row for row in filtered if pattern.search(row["content"])]