main
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3import asyncio
4import base64
5import contextlib
6import hashlib
7import time
8from pathlib import Path
9from typing import TYPE_CHECKING, Literal
10
11from anthropic import AsyncAnthropic
12from glom import glom
13from google import genai
14from google.genai.types import FileState, Part, UploadFileConfig
15from loguru import logger
16from openai import AsyncOpenAI, DefaultAsyncHttpxClient
17from pyrogram.client import Client
18from pyrogram.types import Message
19
20from ai.utils import BOT_TIPS, clean_context
21from asr.utils import GEMINI_AUDIO_EXT, downsampe_audio
22from config import AI, DOWNLOAD_DIR
23from database.r2 import head_cf_r2, set_cf_r2
24from messages.parser import get_thread_id, parse_msg
25from utils import convert_md, read_text
26
27if TYPE_CHECKING:
28 from io import BytesIO
29
30
async def base64_media(client: Client, message: Message) -> dict:
    """Download a message's media and return it in several encodings.

    Returns:
        {
            "ext": file extension without the dot ("jpg" normalized to "jpeg"),
            "base64": base64-encoded raw bytes (for image/video data URLs),
            "value": utf-8 decoded text, or "" when the bytes are not text,
        }
    """
    data: BytesIO = await client.download_media(message, in_memory=True)  # type: ignore
    logger.debug(f"Downloaded message media: {data.name}")

    # "jpg" -> "jpeg" so the extension doubles as a valid MIME subtype
    ext = Path(data.name).suffix.removeprefix(".").replace("jpg", "jpeg")

    # image, video
    b64_encoding = base64.b64encode(data.getvalue()).decode("utf-8")

    # text document: only a decode failure is expected here, so suppress
    # UnicodeDecodeError specifically instead of swallowing every exception
    value = ""
    with contextlib.suppress(UnicodeDecodeError):
        value = data.getvalue().decode("utf-8")
    return {
        "ext": ext,
        "base64": b64_encoding,
        "value": value,
    }
49
50
async def get_openai_completion_contexts(client: Client, message: Message) -> list[dict]:
    """Walk the reply chain and build OpenAI chat-completion contexts, oldest first."""
    # Follow reply_to_message links from the newest message backwards.
    chain = [message]
    node = message
    while node.reply_to_message:
        node = node.reply_to_message
        chain.append(node)
    # Keep at most MAX_CONTEXTS_NUM of the newest messages, then flip to old-to-new.
    chain = chain[: int(AI.MAX_CONTEXTS_NUM)]
    chain.reverse()
    contexts = []
    for msg in chain:
        if ctx := await single_openai_chat_context(client, msg):
            contexts.append(ctx)
    return contexts
59
60
async def single_openai_chat_context(client: Client, message: Message) -> dict:
    """Generate OpenAI chat completion contexts for a single message.

    Media groups (albums) are expanded so every member message contributes
    content parts: photos become inline base64 data URLs, text documents are
    embedded verbatim, and office/PDF/HTML documents are converted to
    markdown first. Any caption/text is appended as a text part.

    Returns:
        {
            "role": "user or assistant",
            "content": [],
        }
        or {} when the message type is unsupported or yields no content.
    """
    info = parse_msg(message, silent=True)
    # Messages the bot itself sent are tagged with BOT_TIPS in their text.
    role = "assistant" if BOT_TIPS in info["text"] else "user"

    if info["mtype"] not in ["text", "photo", "audio", "voice", "video", "document", "web_page"]:
        return {}

    extra_txt_extensions = [".sh", ".json", ".xml"]  # treat these as txt file
    extra_markdown_extensions = [".pdf", ".html", ".docx", ".pptx", ".xls", ".xlsx"]  # convert to markdown

    # Expand an album into its member messages; otherwise process just this one.
    messages = await client.get_media_group(message.chat.id, message.id) if message.media_group_id else [message]
    contexts = []
    for msg in messages:
        info = parse_msg(msg, silent=True)
        # Prefer the original author's name for forwarded messages.
        sender = info["fwd_full_name"] or info["full_name"]
        media_path = DOWNLOAD_DIR + "/" + info["file_name"]
        try:
            if info["mtype"] == "photo":
                # Photos are inlined as base64 data URLs.
                res = await base64_media(client, msg)
                contexts.append({"type": "image_url", "image_url": {"url": f"data:image/{res['ext']};base64,{res['base64']}"}})
            elif info["mtype"] == "document":
                if info["mime_type"].startswith("text/") or Path(info["file_name"]).suffix in extra_txt_extensions:
                    # Plain-text documents: read and embed the content verbatim.
                    fpath: str = await client.download_media(msg, media_path)  # type: ignore
                    contexts.append(
                        {
                            "type": "text",
                            "text": f"[filename]: {info['file_name']}\n[file content]:\n{read_text(fpath).strip()}",
                        }
                    )
                elif Path(info["file_name"]).suffix in extra_markdown_extensions:
                    # Office/PDF/HTML documents: convert to markdown, then delete the temp file.
                    fpath: str = await client.download_media(msg, media_path)  # type: ignore
                    text = convert_md(fpath)
                    Path(fpath).unlink(missing_ok=True)
                    contexts.append(
                        {
                            "type": "text",
                            "text": f"[filename]: {info['file_name']}\n[file content]:\n{text.strip()}",
                        }
                    )
            # user message has entity urls, use full html
            clean_texts = clean_context(info["html"] or info["text"]) if role == "user" and info["entity_urls"] else clean_context(info["text"])
            if not clean_texts:
                continue
            if role == "user" and sender:  # noqa: SIM108
                texts = f"<quote>{info['quote_text']}</quote>\n[username]: {sender}\n[message]:\n{clean_texts}"
            else:
                texts = f"<quote>{info['quote_text']}</quote>\n{clean_texts}"
            texts = texts.removeprefix("<quote></quote>\n")  # remove quote mark if no quote_text
            contexts.append({"type": "text", "text": texts})
        except Exception as e:
            # Best effort: a failed member of the group must not sink the others.
            logger.warning(f"Download media from message failed: {e}")
            continue
    return {"role": role, "content": contexts} if contexts else {}
122
123
async def get_openai_response_contexts(client: Client, message: Message, openai_params: dict) -> tuple[str, list[dict]]:
    """Generate OpenAI response contexts.

    Returns:
        previous_response_id, contexts
    """

    async def lookup_cached_response_id(msg: Message) -> str:
        """Fetch the cached response id for msg from R2 metadata, "" on a miss."""
        key_hash = hashlib.sha256(openai_params["api_key"].encode()).hexdigest()
        tid = get_thread_id(msg)
        cache_key = f"TTL/{openai_params['cache_day']}d/OpenAI/{openai_params['model_id']}/{key_hash}/{msg.chat.id}/{msg.id}{'/' + str(tid) if tid else ''}"
        head = await head_cf_r2(cache_key)
        return glom(head, "Metadata.response_id", default="") or ""

    previous_response_id = ""
    chain = [message]
    node = message
    # Walk the reply chain until a cached response id is found; messages at or
    # beyond that point are already covered by the previous response.
    while node.reply_to_message:
        node = node.reply_to_message
        if pid := await lookup_cached_response_id(node):
            previous_response_id = pid
            break
        chain.append(node)
    chain.reverse()  # old to new
    contexts = []
    for msg in chain:
        if ctx := await single_openai_response_context(client, msg, openai_params):
            contexts.append(ctx)
    return previous_response_id, contexts
155
156
async def single_openai_response_context(client: Client, message: Message, openai_params: dict) -> dict:
    """Generate OpenAI response contexts for a single message.

    Media groups (albums) are expanded so every member contributes content
    parts. Photos, videos and PDFs are sent via the Files API when
    openai_params["openai_media_send_as"] is "file_id" (falling back to
    inline base64 on failure), text documents are embedded verbatim, and
    office/HTML documents are converted to markdown first.

    Returns:
        {
            "role": "user or assistant",
            "content": [],
        }
        or {} when the message type is unsupported or yields no content.
    """
    info = parse_msg(message, silent=True)
    # Messages the bot itself sent are tagged with BOT_TIPS in their text.
    role = "assistant" if BOT_TIPS in info["text"] else "user"

    if info["mtype"] not in ["text", "photo", "audio", "voice", "video", "document", "web_page"]:
        return {}

    extra_txt_extensions = [".sh", ".json", ".xml"]  # treat these as txt file
    extra_markdown_extensions = [".html", ".docx", ".pptx", ".xls", ".xlsx"]  # convert to markdown

    messages = await client.get_media_group(message.chat.id, message.id) if message.media_group_id else [message]
    contexts = []
    for msg in messages:
        info = parse_msg(msg, silent=True)
        # Prefer the original author's name for forwarded messages.
        sender = info["fwd_full_name"] or info["full_name"]
        media_path = DOWNLOAD_DIR + "/" + info["file_name"]
        media_send_as = openai_params.get("openai_media_send_as", "file_id")
        file_id = ""
        try:
            if info["mtype"] == "photo":
                if media_send_as == "file_id" and (file_id := await get_openai_file_id(client, msg, openai_params, info["mtype"])):
                    contexts.append({"type": "input_image", "file_id": file_id, "detail": "high"})
                if not file_id:
                    # Fallback: inline the image as a base64 data URL.
                    res = await base64_media(client, msg)
                    contexts.append({"type": "input_image", "image_url": f"data:image/{res['ext']};base64,{res['base64']}", "detail": "high"})

            elif info["mtype"] == "video":
                if media_send_as == "file_id" and (file_id := await get_openai_file_id(client, msg, openai_params, info["mtype"])):
                    contexts.append({"type": "input_video", "file_id": file_id})
                if not file_id:
                    res = await base64_media(client, msg)
                    # NOTE(review): "image_url" as the key of an input_video part looks
                    # copy-pasted from the photo branch — confirm the target endpoint
                    # actually expects it (vs e.g. "video_url" / "file_data").
                    contexts.append({"type": "input_video", "image_url": f"data:video/{res['ext']};base64,{res['base64']}"})

            elif info["mtype"] == "document":
                if info["mime_type"] == "application/pdf":
                    if media_send_as == "file_id" and (file_id := await get_openai_file_id(client, msg, openai_params, info["mtype"])):
                        contexts.append({"type": "input_file", "file_id": file_id})
                    if not file_id:
                        res = await base64_media(client, msg)
                        contexts.append({"type": "input_file", "file_data": f"data:application/pdf;base64,{res['base64']}", "filename": info["file_name"]})

                elif info["mime_type"].startswith("text/") or Path(info["file_name"]).suffix in extra_txt_extensions:
                    # Plain-text documents: read and embed the content verbatim.
                    fpath: str = await client.download_media(msg, media_path)  # type: ignore
                    contexts.append(
                        {
                            "type": "input_text",
                            "text": f"[filename]: {info['file_name']}\n[file content]:\n{read_text(fpath).strip()}",
                        }
                    )
                elif Path(info["file_name"]).suffix in extra_markdown_extensions:
                    # Office/HTML documents: convert to markdown, then delete the temp file.
                    fpath: str = await client.download_media(msg, media_path)  # type: ignore
                    text = convert_md(fpath)
                    Path(fpath).unlink(missing_ok=True)
                    contexts.append(
                        {
                            "type": "input_text",
                            "text": f"[filename]: {info['file_name']}\n[file content]:\n{text.strip()}",
                        }
                    )
            # user message has entity urls, use full html
            clean_texts = clean_context(info["html"] or info["text"]) if role == "user" and info["entity_urls"] else clean_context(info["text"])
            if not clean_texts:
                continue
            if role == "user" and sender:  # noqa: SIM108
                texts = f"<quote>{info['quote_text']}</quote>\n[username]: {sender}\n[message]:\n{clean_texts}"
            else:
                texts = f"<quote>{info['quote_text']}</quote>\n{clean_texts}"
            texts = texts.removeprefix("<quote></quote>\n")  # remove quote mark if no quote_text
            contexts.append({"type": "input_text", "text": texts})
        except Exception as e:
            # Best effort: a failed member of the group must not sink the others.
            logger.warning(f"Download media from message failed: {e}")
            continue
    return {"role": role, "content": contexts} if contexts else {}
238
239
async def get_openai_file_id(client: Client, message: Message, openai_params: dict, mtype: str) -> str:
    """Upload a message's media to OpenAI and return its file id.

    The id is cached in R2 under a TTL key so repeated requests for the same
    message reuse the uploaded file. Returns "" when the media type is not
    allowed by openai_params or the upload fails.
    """

    def get_real_baseurl() -> str:
        # Helicone is a logging gateway; file uploads must go to the real
        # upstream recorded in the helicone-target-url header.
        base_url = str(openai_params["base_url"]) or ""
        default_headers = openai_params.get("default_headers", {})
        default_headers = {k.lower(): v for k, v in default_headers.items()}
        if base_url.startswith("https://gateway.helicone.ai"):
            helicone_target_url = default_headers.get("helicone-target-url") or ""
            return base_url.replace("https://gateway.helicone.ai", helicone_target_url.rstrip("/"))
        return base_url

    if mtype not in ["photo", "video", "document"]:
        return ""
    # Respect the per-model capability switches.
    if not openai_params["allow_image"] and mtype == "photo":
        return ""
    if not openai_params["allow_video"] and mtype == "video":
        return ""
    if not openai_params["allow_file"] and mtype == "document":
        return ""

    cache_day = openai_params.get("cache_day", 30)
    api_key = openai_params["api_key"]
    model_id = openai_params["model_id"]
    key_hash = hashlib.sha256(api_key.encode()).hexdigest()
    tid = get_thread_id(message)
    r2_key = f"TTL/{cache_day}d/OpenAI/{model_id}/{key_hash}/{message.chat.id}/{message.id}{'/' + str(tid) if tid else ''}-file_id"
    r2 = await head_cf_r2(r2_key)
    if file_id := glom(r2, "Metadata.file_id", default=""):
        return file_id

    openai = AsyncOpenAI(
        base_url=get_real_baseurl(),
        api_key=api_key,
        http_client=DefaultAsyncHttpxClient(proxy=openai_params["proxy"]) if openai_params.get("proxy") else None,
    )
    fpath: str = await client.download_media(message)  # type: ignore
    extra_body = {"expire_at": int(time.time()) + 3600 * 24 * cache_day}

    preprocess_configs = {}
    if message.video:
        # Sample 0.5 fps per 5 minutes of footage, clamped to the provider's
        # supported 0.5-5.0 fps range (unknown duration defaults to the max).
        duration = glom(message, "video.duration", default=1e8)
        fps = min(max(int(duration // 300) * 0.5, 0.5), 5.0)
        preprocess_configs = {"video": {"fps": fps, "model": openai_params["model_id"]}}
    if preprocess_configs:
        extra_body["preprocess_configs"] = preprocess_configs
    try:
        resp = await openai.files.create(file=Path(fpath), purpose="user_data", extra_body=extra_body)
        # Poll until server-side processing completes.
        while resp.status == "processing":
            logger.trace(f"Upload media to OpenAI processing: {resp.model_dump()}")
            await asyncio.sleep(3)
            resp = await openai.files.retrieve(file_id=resp.id)
        if resp.status == "active":
            await set_cf_r2(r2_key, data=resp.model_dump(), metadata={"file_id": resp.id})
            return resp.id
        logger.error(f"Upload media to OpenAI failed: {resp.model_dump()}")
    except Exception as e:
        logger.error(f"Upload media to OpenAI failed: {e}")
    finally:
        # Always remove the downloaded temp file — the original leaked it on
        # every path that did not reach status "active".
        Path(fpath).unlink(missing_ok=True)
    return ""
303
304
async def get_gemini_contexts(client: Client, message: Message, gemini: genai.Client) -> list[dict]:
    """Generate Gemini contexts from old to new.

    Walks the reply chain from *message*, keeps at most AI.MAX_CONTEXTS_NUM
    messages, and converts each one (expanding media groups) into a
    {"role": ..., "parts": [...]} dict. Natively-supported media is uploaded
    through the Gemini Files API; text documents are embedded inline.

    Returns:
        contexts: list[dict]
    """
    ctx_messages = [message]
    while message.reply_to_message:
        message = message.reply_to_message
        ctx_messages.append(message)
    ctx_messages = ctx_messages[: int(AI.MAX_CONTEXTS_NUM)][::-1]  # old to new
    contexts = []
    for m in ctx_messages:
        info = parse_msg(m, silent=True)
        # Bot-authored messages (tagged with BOT_TIPS) take the "model" role.
        role = "model" if BOT_TIPS in info["text"] else "user"
        if info["mtype"] not in ["text", "photo", "audio", "voice", "video", "document", "web_page"]:
            continue
        # gemini has built-in support for these extensions
        gemini_extensions = [".pdf", ".html", ".css", ".csv", ".xml", ".rtf", ".mp3", ".wav", ".ogg", ".aac", ".flac", ".jpg", ".jpeg", ".webp", ".png", ".heic", ".heif"]
        # gemini has built-in support for these mime types
        gemini_mime_types = ["application/pdf", "application/x-javascript", "audio/ogg", "audio/mp4", "image/jpeg", "image/png", "image/webp", "image/heic", "image/heif"]
        txt_extensions = [".txt", ".js", ".py", ".md", ".sh", ".json"]  # treat these as txt file
        extra_markdown_extensions = [".docx", ".pptx", ".xls", ".xlsx", ".epub"]  # convert to markdown
        group_messages = await client.get_media_group(m.chat.id, m.id) if m.media_group_id else [m]
        parts = []
        for msg in group_messages:
            info = parse_msg(msg, silent=True)
            # Prefer the original author's name for forwarded messages.
            sender = info["fwd_full_name"] or info["full_name"]
            media_path = DOWNLOAD_DIR + "/" + info["file_name"]
            try:
                if info["mtype"] in ["video", "photo", "audio", "voice"] or info["mime_type"] in gemini_mime_types or any(info["file_name"].endswith(ext) for ext in gemini_extensions):
                    # Natively-supported media: upload through the Files API.
                    fpath: str = await client.download_media(msg, media_path)  # type: ignore
                    if info["mtype"] in ["audio", "voice"] and Path(fpath).suffix not in GEMINI_AUDIO_EXT:
                        # Transcode audio containers Gemini cannot ingest directly.
                        audio_path = await downsampe_audio(fpath)
                        fpath = audio_path.as_posix()
                    upload = await gemini.aio.files.upload(file=fpath, config=UploadFileConfig(display_name=info["file_name"] or f"send from {sender}"))
                    # Poll until Gemini finishes server-side processing.
                    while upload.state == FileState.PROCESSING:
                        logger.trace("Waiting for upload to complete...")
                        await asyncio.sleep(1)
                        upload = await gemini.aio.files.get(name=upload.name)  # type: ignore
                    if upload.state == FileState.ACTIVE and upload.uri:
                        parts.append(Part.from_uri(file_uri=upload.uri, mime_type=upload.mime_type))
                    Path(fpath).unlink(missing_ok=True)
                elif info["mtype"] == "document":
                    if info["mime_type"].startswith("text/") or Path(info["file_name"]).suffix in txt_extensions:
                        # Plain-text documents: read and embed the content verbatim.
                        fpath: str = await client.download_media(msg, media_path)  # type: ignore
                        parts.append(Part.from_text(text=f"[filename]: {info['file_name']}\n[file content]:\n{read_text(fpath).strip()}"))
                    # NOTE(review): plain "if" (not elif) — the extension lists don't
                    # overlap with the text branch, so at most one branch fires.
                    if Path(info["file_name"]).suffix in extra_markdown_extensions:
                        fpath: str = await client.download_media(msg, media_path)  # type: ignore
                        text = convert_md(fpath)
                        Path(fpath).unlink(missing_ok=True)
                        parts.append(Part.from_text(text=f"[filename]: {info['file_name']}\n[file content]:\n{text.strip()}"))
                clean_texts = clean_context(info["text"])
                if not clean_texts:
                    continue
                if role == "user" and sender:  # noqa: SIM108
                    texts = f"<quote>{info['quote_text']}</quote>\n[username]: {sender}\n[message]:\n{clean_texts}"
                else:
                    texts = f"<quote>{info['quote_text']}</quote>\n{clean_texts}"
                texts = texts.removeprefix("<quote></quote>\n")  # remove quote mark if no quote_text
                parts.append(Part.from_text(text=texts))
            except Exception as e:
                # Best effort: a failed member of the group must not sink the others.
                logger.warning(f"Download media from message failed: {e}")
                continue
        if parts:
            contexts.append({"role": role, "parts": parts})
    return contexts
372
373
async def get_anthropic_contexts(client: Client, message: Message, **kwargs) -> list[dict]:
    """Walk the reply chain and build Anthropic message contexts, oldest first."""
    chain = [message]
    node = message
    while node.reply_to_message:
        node = node.reply_to_message
        chain.append(node)
    # Keep at most MAX_CONTEXTS_NUM of the newest messages, then flip to old-to-new.
    chain = chain[: int(AI.MAX_CONTEXTS_NUM)]
    chain.reverse()
    contexts = []
    for msg in chain:
        if ctx := await single_anthropic_context(client, msg, **kwargs):
            contexts.append(ctx)
    return contexts
382
383
async def single_anthropic_context(
    client: Client,
    message: Message,
    anthropic: AsyncAnthropic,
    cache_hour: int = 0,
    media_send_as: Literal["base64", "file_id"] = "file_id",
) -> dict:
    """Generate Anthropic contexts for a single message.

    Media groups (albums) are expanded so every member contributes content
    blocks. Photos and PDFs are sent via the Files API when media_send_as is
    "file_id" (falling back to inline base64 on failure), text documents are
    embedded verbatim, and office/HTML documents are converted to markdown.

    Returns:
        {
            "role": "user or assistant",
            "content": [],
        }
        or {} when the message type is unsupported or yields no content.
    """
    info = parse_msg(message, silent=True)
    # Messages the bot itself sent are tagged with BOT_TIPS in their text.
    role = "assistant" if BOT_TIPS in info["text"] else "user"

    if info["mtype"] not in ["text", "photo", "audio", "voice", "video", "document", "web_page"]:
        return {}

    extra_txt_extensions = [".sh", ".json", ".xml"]  # treat these as txt file
    extra_markdown_extensions = [".html", ".docx", ".pptx", ".xls", ".xlsx"]  # convert to markdown

    messages = await client.get_media_group(message.chat.id, message.id) if message.media_group_id else [message]
    contexts = []
    for msg in messages:
        info = parse_msg(msg, silent=True)
        # Prefer the original author's name for forwarded messages.
        sender = info["fwd_full_name"] or info["full_name"]
        media_path = DOWNLOAD_DIR + "/" + info["file_name"]
        file_id = ""
        try:
            if info["mtype"] == "photo":
                if media_send_as == "file_id" and (file_id := await get_anthropic_file_id(client, msg, anthropic, cache_hour)):
                    contexts.append({"type": "image", "source": {"type": "file", "file_id": file_id}})
                if not file_id:
                    # Fallback: inline the image as base64.
                    res = await base64_media(client, msg)
                    contexts.append({"type": "image", "source": {"type": "base64", "media_type": f"image/{res['ext']}", "data": res["base64"]}})

            elif info["mtype"] == "document":
                if info["mime_type"] == "application/pdf":
                    if media_send_as == "file_id" and (file_id := await get_anthropic_file_id(client, msg, anthropic, cache_hour)):
                        contexts.append({"type": "document", "source": {"type": "file", "file_id": file_id}})
                    if not file_id:
                        res = await base64_media(client, msg)
                        contexts.append({"type": "document", "source": {"type": "base64", "media_type": "application/pdf", "data": res["base64"]}})

                elif info["mime_type"].startswith("text/") or Path(info["file_name"]).suffix in extra_txt_extensions:
                    # Plain-text documents: read and embed the content verbatim.
                    fpath: str = await client.download_media(msg, media_path)  # type: ignore
                    contexts.append({"type": "text", "text": f"[filename]: {info['file_name']}\n[file content]:\n{read_text(fpath).strip()}"})

                elif Path(info["file_name"]).suffix in extra_markdown_extensions:
                    # Office/HTML documents: convert to markdown, then delete the temp file.
                    fpath: str = await client.download_media(msg, media_path)  # type: ignore
                    text = convert_md(fpath)
                    Path(fpath).unlink(missing_ok=True)
                    contexts.append({"type": "text", "text": f"[filename]: {info['file_name']}\n[file content]:\n{text.strip()}"})
            # user message has entity urls, use full html
            clean_texts = clean_context(info["html"] or info["text"]) if role == "user" and info["entity_urls"] else clean_context(info["text"])
            if not clean_texts:
                continue
            if role == "user" and sender:  # noqa: SIM108
                texts = f"<quote>{info['quote_text']}</quote>\n[username]: {sender}\n[message]:\n{clean_texts}"
            else:
                texts = f"<quote>{info['quote_text']}</quote>\n{clean_texts}"
            texts = texts.removeprefix("<quote></quote>\n")  # remove quote mark if no quote_text
            contexts.append({"type": "text", "text": texts})
        except Exception as e:
            # Best effort: a failed member of the group must not sink the others.
            logger.warning(f"Download media from message failed: {e}")
            continue
    return {"role": role, "content": contexts} if contexts else {}
454
455
async def get_anthropic_file_id(client: Client, message: Message, anthropic: AsyncAnthropic, cache_hour: int) -> str:
    """Upload a message's media to Anthropic and return its file id.

    The id is cached in R2 under a TTL key (cache_hour, default 12h) so
    repeated requests for the same message reuse the uploaded file.
    Returns "" when the upload fails.
    """
    api_key: str = anthropic.api_key  # ty:ignore[invalid-assignment]
    key_hash = hashlib.sha256(api_key.encode()).hexdigest()
    tid = get_thread_id(message)
    cache_hour = cache_hour or 12
    r2_key = f"TTL/{cache_hour}h/Anthropic/{key_hash}/{message.chat.id}/{message.id}{'/' + str(tid) if tid else ''}-file_id"
    r2 = await head_cf_r2(r2_key)
    if file_id := glom(r2, "Metadata.file_id", default=""):
        return file_id
    fpath: str = await client.download_media(message)  # type: ignore
    try:
        resp = await anthropic.beta.files.upload(file=Path(fpath))
        if glom(resp, "id", default=""):
            # Persist the id so the cache lookup above can actually hit next
            # time — the original read the cache but never wrote it
            # (mirrors the set_cf_r2 call in get_openai_file_id).
            await set_cf_r2(r2_key, data=resp.model_dump(), metadata={"file_id": resp.id})
            return resp.id
        logger.error(f"Upload media to Anthropic failed: {resp.model_dump()}")
    except Exception as e:
        logger.error(f"Upload media to Anthropic failed: {e}")
    finally:
        # Always remove the downloaded temp file — the original never unlinked it.
        Path(fpath).unlink(missing_ok=True)
    return ""