main
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3import contextlib
4import os
5import re
6import string
7
8from glom import glom
9from loguru import logger
10from pyrogram.client import Client
11from pyrogram.errors import PeerIdInvalid
12from pyrogram.types import Chat, Message, User
13
14from config import DB, HISTORY, TID, cache, cutter
15from database.d1 import query_d1
16from database.turso import turso_exec, turso_parse_resp
17from messages.sender import send2tg
18from others.emoji import CTYPE_EMOJI
19from utils import find_url, myself, slim_cid, strings_list, to_int, true
20
# Connection settings for the Turso database that stores chat history.
# History-specific credentials win; the generic DB credentials are the
# fallback whenever the history-specific value is empty/falsy.
TURSO_KWARGS: dict = dict(
    db_name=HISTORY.TURSO_DATABASE,
    username=HISTORY.TURSO_USERNAME or DB.TURSO_USERNAME,
    api_token=HISTORY.TURSO_API_TOKEN or DB.TURSO_API_TOKEN,
    group_token=HISTORY.TURSO_GROUP_TOKEN or DB.TURSO_GROUP_TOKEN,
)

# SQL column definitions for chat records (presumably the `chatinfo` table).
CHAT_COLUMNS = "cid INTEGER PRIMARY KEY, ctype TEXT, ctitle TEXT, chandle TEXT, tablename TEXT, tags TEXT"
# SQL column definitions for user records, and the columns to index.
USER_COLUMNS = "ctitle TEXT, full_name TEXT, handle TEXT, tags TEXT, name TEXT, uid INTEGER, cid INTEGER, id INTEGER PRIMARY KEY"
USER_INDEXES = ["uid", "cid"]

# SQL column definitions for message records, and the columns to index.
MSG_COLUMNS = "mid INTEGER PRIMARY KEY, mtype TEXT, time TEXT NOT NULL, fullname TEXT, content TEXT, filename TEXT, urls TEXT, reply INTEGER, mime TEXT, user TEXT, handle TEXT, uid INTEGER, gid INTEGER, segmented TEXT"  # fmt: off
MSG_INDEXES = ["time", "user", "uid", "handle"]
34
35
@cache.memoize(ttl=0)
def check_save_history(ctype: str, cid: int | str) -> bool:
    """Decide whether chat history should be saved for a chat.

    Rules, first match wins (``cid`` is normalized with ``slim_cid``):

    1. env ``HISTORY_IGNORE_{cid}`` is truthy   -> False
    2. env ``HISTORY_INCLUDE_{cid}`` is truthy  -> True
    3. cid listed in ``HISTORY.IGNORE_CHATS``   -> False
    4. cid listed in ``HISTORY.INCLUDE_CHATS``  -> True
    5. the setting for the chat type (``INCLUDE_PRIVATES`` for PRIVATE,
       ``INCLUDE_BOTS`` for BOT, ``INCLUDE_GROUPS`` for GROUP/SUPERGROUP,
       ``INCLUDE_CHANNELS`` for CHANNEL) is ``"all"`` (case-insensitive) or
       lists the cid -> True, otherwise False.

    Any other chat type returns False. Results are memoized through
    ``cache.memoize(ttl=0)``.
    """
    chat_id = slim_cid(cid)

    # Per-chat overrides via environment variables take precedence.
    if true(os.getenv(f"HISTORY_IGNORE_{chat_id}")):
        return False
    if true(os.getenv(f"HISTORY_INCLUDE_{chat_id}")):
        return True

    # Explicit chat lists from the HISTORY config.
    if chat_id in strings_list(HISTORY.IGNORE_CHATS):
        return False
    if chat_id in strings_list(HISTORY.INCLUDE_CHATS):
        return True

    # Fall back to the per-chat-type setting.
    setting_name = {
        "PRIVATE": "INCLUDE_PRIVATES",
        "BOT": "INCLUDE_BOTS",
        "GROUP": "INCLUDE_GROUPS",
        "SUPERGROUP": "INCLUDE_GROUPS",
        "CHANNEL": "INCLUDE_CHANNELS",
    }.get(ctype)
    if setting_name is None:
        return False
    setting = getattr(HISTORY, setting_name)
    return str(setting).lower() == "all" or chat_id in strings_list(setting)
65
66
@cache.memoize(ttl=0)
def can_delete_history(cid: int | str, uid: int | str) -> bool:
    """Return True if deleting saved history is permitted for this chat/user.

    Permission is granted when any of these environment variables is truthy,
    checked in this order: ``HISTORY_CAN_DEL_C{cid}``,
    ``HISTORY_CAN_DEL_U{uid}`` and ``HISTORY_CAN_DEL_C{cid}_U{uid}``. Here
    ``cid`` is normalized with ``slim_cid`` and ``uid`` is used as given.
    Results are memoized through ``cache.memoize(ttl=0)``.
    """
    chat_id = slim_cid(cid)
    env_names = (
        f"HISTORY_CAN_DEL_C{chat_id}",
        f"HISTORY_CAN_DEL_U{uid}",
        f"HISTORY_CAN_DEL_C{chat_id}_U{uid}",
    )
    return any(true(os.getenv(name)) for name in env_names)
78
79
def fine_grained_check(info: dict) -> bool:
    """Return False when a message should be skipped instead of saved to history.

    Some chats do not need every kind of message stored. These per-chat rules
    can only be configured through environment variables, where ``{cid}`` is
    the chat id normalized by ``slim_cid``:

    - ``HISTORY_{cid}_MUST_MTYPE``: keep only these message types (one or more,
      comma-separated; matched case-insensitively and exactly against
      ``info["mtype"]``).
    - ``HISTORY_{cid}_MUST_HAVE_TEXT``: keep only messages that have text.
    - ``HISTORY_{cid}_MUST_HAVE_URL``: keep only messages containing a link
      (found in the text or present in ``info["entity_urls"]``).
    - ``HISTORY_{cid}_SKIP_URL``: skip messages containing a link.
    - ``HISTORY_{cid}_SKIP_KEYWORDS``: skip messages whose text contains any of
      these comma-separated keywords.

    Example: ``HISTORY_1234_MUST_HAVE_TEXT=1`` makes chat ``1234`` skip
    messages without text.

    Args:
        info: message info; reads ``cid``, ``mtype``, ``text`` and the optional
            ``entity_urls`` key.

    Returns:
        True if the message passes every configured rule, False otherwise.
    """
    prefix = f"HISTORY_{slim_cid(info['cid'])}"
    # Treat missing/None text as empty so the URL and keyword checks below
    # cannot fail with a TypeError (`x in None`).
    text = info.get("text") or ""

    if must_mtype := os.getenv(f"{prefix}_MUST_MTYPE"):
        # Exact membership in the configured list. A plain substring test would
        # wrongly accept "video" when only "video_note" is configured.
        allowed = {x.strip().lower() for x in strings_list(must_mtype)}
        if info["mtype"].lower() not in allowed:
            return False

    if true(os.getenv(f"{prefix}_MUST_HAVE_TEXT")) and not text:
        return False

    must_have_url = true(os.getenv(f"{prefix}_MUST_HAVE_URL"))
    skip_url = true(os.getenv(f"{prefix}_SKIP_URL"))
    if must_have_url or skip_url:
        # Only look for URLs when a URL rule is actually configured.
        has_url = bool(find_url(text) or info.get("entity_urls"))
        if must_have_url and not has_url:
            return False
        if skip_url and has_url:
            return False

    skip_keywords = strings_list(os.getenv(f"{prefix}_SKIP_KEYWORDS"))
    return not any(kw in text for kw in skip_keywords)
104
105
async def get_chat(client: Client, chat_id: int | str) -> Chat:
    """Fetch chat info for ``chat_id``, caching the result for one hour.

    If the first lookup raises ``PeerIdInvalid``, a second lookup is tried
    with ``-100`` prepended to the slimmed id. When no lookup succeeds, the
    placeholder ``Chat(id=0)`` is returned. The placeholder is also cached, so
    a failing id is not retried for an hour. ``chat_id`` equal to 0 returns
    the placeholder immediately, without caching.
    """
    cache_key = f"chat-info-{slim_cid(chat_id)}"
    # Single lookup: two separate get() calls could straddle TTL expiry and
    # return None.
    if cached := cache.get(cache_key):
        return cached
    chat = Chat(id=0)  # placeholder returned when the lookup fails
    if str(chat_id) == "0":
        return chat
    try:
        chat = await client.get_chat(to_int(chat_id))
    except PeerIdInvalid:
        # Retry with the "-100" prefix (presumably for ids stored without it).
        try:
            chat = await client.get_chat(to_int(f"-100{slim_cid(chat_id)}"))
        except Exception as e:
            logger.warning(f"Failed to get chat info for {chat_id}: {e}")
    except Exception as e:
        logger.warning(f"Failed to get chat info for {chat_id}: {e}")
    cache.set(cache_key, chat, ttl=3600)  # cache for 1 hour
    return chat
121
122
async def list_chat_ids(client: Client, message: Message, engine: str = "turso"):
    """Send the list of saved chats (table ``chatinfo``) to the current chat.

    Rows are read from Turso when ``engine`` is ``"turso"`` (case-insensitive,
    the default) and from D1 otherwise. One database may be read by several
    Telegram accounts, so per-row tags filter what is listed:

    - ``{my_uid}_SKIP_LIST``: hide the row for the account ``my_uid``.
    - ``SKIP_LIST_IN_{chatid}``: hide the row when listing in chat ``chatid``.
    - ``ONLY_LIST_IN_{chatid}``: show the row only when listing in ``chatid``.

    Each listed row becomes one ``/history #{cid}`` line with the chat-type
    emoji and title. Rows are ordered by ``ctype``.
    """
    if engine.lower() == "turso":
        resp = await turso_exec([{"type": "execute", "stmt": {"sql": "SELECT * FROM chatinfo;"}}], silent=True, retry=2, **TURSO_KWARGS)
        chats = turso_parse_resp(resp)
    else:
        resp = await query_d1("SELECT * FROM chatinfo;", db_name=HISTORY.D1_DATABASE, silent=True)
        chats = glom(resp, "result.0.results", default=[])

    me = await myself(client)
    cid = slim_cid(message.chat.id)
    lines: list[str] = []
    for row in sorted(chats, key=lambda r: r["ctype"]):
        # A NULL tags column may come back as None, and `.get("tags", "")`
        # does not cover that. `"..." in None` would abort the whole listing.
        raw_tags = row.get("tags") or ""
        tags = strings_list(raw_tags)
        if "ONLY_LIST_IN_" in raw_tags and f"ONLY_LIST_IN_{cid}" not in tags:
            continue
        if "SKIP_LIST_IN_" in raw_tags and f"SKIP_LIST_IN_{cid}" in tags:
            continue
        if f"{me.id}_SKIP_LIST" in tags:
            continue
        # Unknown chat types get no emoji instead of raising KeyError.
        emoji = CTYPE_EMOJI.get(row["ctype"], "")
        lines.append(f"`/history #{row['cid']}` {emoji}: {row['ctitle']}\n")
    await send2tg(client, message, texts="".join(lines))
153
154
def is_admin(uid: int) -> bool:
    """Return True if ``uid`` appears in ``TID.HISTORY_ADMIN``.

    Both the configured ids and ``uid`` are normalized with ``slim_cid``
    before comparison.
    """
    for admin in strings_list(TID.HISTORY_ADMIN):
        if slim_cid(admin) == slim_cid(uid):
            return True
    return False
157
158
# NOTE(review): this memoizes a coroutine function. Verify that `cache.memoize`
# supports async callables; otherwise a cached coroutine would be awaited twice.
@cache.memoize(ttl=10)
async def get_user_from_chat(client: Client, uid: int | str, cid: int | str) -> User:
    """Look up a member of chat ``cid`` by numeric user id or by username.

    ``uid`` values containing anything other than ASCII letters, digits and
    ``_`` are rejected immediately. The member is fetched directly first. If
    that fails, the chat's member list is scanned for a matching id or a
    matching username (case-insensitive). Errors during lookup are
    suppressed.

    Returns:
        The matching ``User``, or the placeholder ``User(id=0)`` if none is found.
    """
    user = User(id=0)  # placeholder for "not found"
    allowed = set(string.ascii_letters + string.digits + "_")
    if not set(str(uid)) <= allowed:
        return user
    try:  # get chat member directly
        chat_member = await client.get_chat_member(to_int(cid), to_int(uid))
        user = chat_member.user
    except Exception:
        with contextlib.suppress(Exception):  # scan the chat members instead
            chat_id, target_id = to_int(cid), to_int(uid)
            # A username is a string (or None). Compare it with str(uid),
            # ignoring case, rather than with the to_int() conversion.
            target_name = str(uid).lower()
            async for member in client.get_chat_members(chat_id):  # type: ignore
                username = member.user.username or ""
                if member.user.id == target_id or username.lower() == target_name:
                    user = member.user
                    break
    return user
174
175
def keyword_query(keyword: str) -> str:
    """Build an SQLite FTS5 ``MATCH`` clause from a user-supplied search keyword.

    The keyword is normalized in three steps:

    - ASCII and common full-width/CJK punctuation is turned into spaces.
    - Full-width parentheses become ASCII ones.
    - Whitespace is collapsed and trimmed.

    The result is then segmented with ``cutter.cutword``. Tokens ``OR``,
    ``AND``, ``NOT``, ``(`` and ``)`` pass through as query operators. Every
    other token is emitted as a quoted phrase, and ``AND`` is inserted after a
    phrase unless the next token is ``OR``/``AND``/``NOT``/``)``.

    Quotes in user input are escaped: ``"`` becomes ``""`` inside a phrase, and
    ``'`` becomes ``''`` inside the SQL string literal. This stops the keyword
    from breaking out of the generated SQL.

    Args:
        keyword: raw search text.

    Returns:
        A SQL fragment of the form ``FTS.segmented MATCH '<expression>'``.
    """
    # The full-width / CJK marks are written as escapes so they cannot be
    # silently normalized to ASCII. In order: full-width comma, ideographic
    # full stop, full-width question mark, exclamation mark, colon and
    # semicolon, curly double and single quotes, and CJK angle brackets.
    punctuation = "!#$&*+,-./:;<=>?@[\\]^_`{|}~" + "\uff0c\u3002\uff1f\uff01\uff1a\uff1b\u201c\u201d\u2018\u2019\u300a\u300b"
    table = dict.fromkeys(map(ord, punctuation), " ")
    table[ord("\uff08")] = "("  # full-width left parenthesis
    table[ord("\uff09")] = ")"  # full-width right parenthesis
    # translate() replaces all punctuation in one pass. split()/join()
    # collapses whitespace runs and strips both ends.
    keyword = " ".join(keyword.translate(table).split())

    operators = {"OR", "AND", "NOT", "(", ")"}
    no_and_before = {"OR", "AND", "NOT", ")"}
    segmented = [x for x in cutter.cutword(keyword) if x.strip()]
    final: list[str] = []
    last = len(segmented) - 1
    for i, word in enumerate(segmented):
        if word.strip() in operators:
            final.append(word)
            continue
        # FTS5 escapes a double quote inside a phrase by doubling it.
        final.append('"' + word.replace('"', '""') + '"')
        if i != last and segmented[i + 1].strip() not in no_and_before:
            final.append("AND")
    # SQL escapes a single quote inside a string literal by doubling it.
    match_expr = " ".join(final).replace("'", "''")
    return f"FTS.segmented MATCH '{match_expr}'"
200
201
def filter_response(resp: list[dict], keyword: str) -> list[dict]:
    """Keep rows with non-empty ``content`` that match the search ``keyword``.

    Rows whose ``content`` is empty or None are always dropped. No further
    filtering happens if the keyword is blank, or if it looks like a query
    expression (it contains ``OR``, ``AND``, ``NOT`` or an ASCII/full-width
    parenthesis). Otherwise every whitespace-separated word of the keyword must
    appear in ``content`` in order. Words are matched literally and
    case-insensitively, and anything (including newlines) may appear between
    them.

    Args:
        resp: result rows; each must have a ``content`` key.
        keyword: raw search keyword.

    Returns:
        The matching rows, in their original order.
    """
    filtered = [row for row in resp if row["content"]]  # drop messages without text
    # Deliberately a substring test: keyword_query segments text with `cutter`,
    # so operators need not be separated by spaces. Full-width parentheses are
    # included because keyword_query turns them into grouping operators.
    if any(x in keyword for x in ["OR", "AND", "NOT", "(", ")", "\uff08", "\uff09"]):  # keyword is a SQL query
        return filtered
    words = keyword.split()  # collapses whitespace runs and trims both ends
    if not words:
        return filtered
    # Escape each word so characters such as "+" or "[" match literally instead
    # of raising re.error or acting as regex syntax. DOTALL lets the gap
    # between words span line breaks in multi-line messages.
    pattern = re.compile("(.*?)".join(map(re.escape, words)), flags=re.IGNORECASE | re.DOTALL)
    return [row for row in filtered if pattern.search(row["content"])]