main
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3import asyncio
4import base64
5import contextlib
6import hashlib
7import json
8import mimetypes
9from pathlib import Path
10from typing import TYPE_CHECKING, Literal
11
12import anyio
13from anthropic import AsyncAnthropic
14from glom import Coalesce, glom
15from google import genai
16from google.genai.types import FileData, FileState, HttpOptions, Part, UploadFileConfig
17from loguru import logger
18from openai import AsyncOpenAI, DefaultAsyncHttpxClient
19from pyrogram.client import Client
20from pyrogram.types import Chat, Document, Message
21
22from ai.utils import BOT_TIPS, clean_context
23from asr.utils import GEMINI_AUDIO_EXT, downsampe_audio
24from config import AI, DOWNLOAD_DIR, PROXY, TID
25from database.r2 import head_cf_r2, set_cf_r2
26from messages.parser import parse_msg
27from others.download_external import AUDIO_FORMAT, VIDEO_FORMAT
28from utils import convert2md, digest, guess_mime, read_text
29
30if TYPE_CHECKING:
31 from io import BytesIO
32
# Documents with these extensions are read as plain text.
TXT_EXT = [
    ".sh",
    ".json",
    ".xml",
    ".tex",
]
# Documents with these extensions are converted to markdown first.
MARKDOWN_EXT = [
    ".html",
    ".docx",
    ".pptx",
    ".xls",
    ".xlsx",
    ".epub",
]
# Gemini has built-in support for these extensions.
GEMINI_EXT = [
    ".pdf",
    ".html",
    ".css",
    ".csv",
    ".xml",
    ".rtf",
    ".mp3",
    ".wav",
    ".ogg",
    ".aac",
    ".flac",
    ".jpg",
    ".jpeg",
    ".webp",
    ".png",
    ".heic",
    ".heif",
]
# Gemini has built-in support for these mime types.
GEMINI_MIME = [
    "application/pdf",
    "application/x-javascript",
    "audio/ogg",
    "audio/mp4",
    "image/jpeg",
    "image/png",
    "image/webp",
    "image/heic",
    "image/heif",
]
39
40
async def base64_media(client: Client, message: Message) -> dict:
    """Download the media of a message into memory and return it encoded.

    Returns:
        A dict with three keys:
            "ext": file extension without the leading dot, with "jpg" replaced by "jpeg"
            "base64": base64 text of the raw bytes (used for images, videos, ...)
            "value": the bytes decoded as UTF-8, or "" when they cannot be decoded
    """
    buffer: BytesIO = await client.download_media(message, in_memory=True)  # type: ignore
    logger.debug(f"Downloaded message media: {buffer.name}")

    suffix = Path(buffer.name).suffix
    raw = buffer.getvalue()
    encoded = base64.b64encode(raw).decode("utf-8")

    # Text documents: best effort, binary payloads simply yield an empty string.
    try:
        decoded = raw.decode("utf-8")
    except Exception:
        decoded = ""

    return {
        "ext": suffix.removeprefix(".").replace("jpg", "jpeg"),
        "base64": encoded,
        "value": decoded,
    }
59
60
async def get_openai_completion_contexts(client: Client, message: Message, params: dict) -> list[dict]:
    """Generate OpenAI chat completion contexts.

    The reply chain of ``message`` is walked from old to new, truncated to the
    last ``AI.MAX_CONTEXTS_NUM`` messages, and every supported message is
    converted into one chat-completion message dict.

    Args:
        client: Client used to download the media of each chain message.
        message: Message whose reply chain is converted.
        params: Request options. ``add_sender`` decides whether the sender name
            is prepended to user texts; when it is missing/``None`` the value
            comes from ``is_multi_user_chat``. The dict is also handed to
            ``parse_openai_chat_additional_contexts``.

    Returns:
        The contexts whose ``content`` is non-empty, ordered old to new.
    """
    chains = await full_chain_contexts(client, message, order="asc")  # old to new
    add_sender = params.get("add_sender")
    if add_sender is None:
        add_sender = is_multi_user_chat(chains)
    messages = chains[-int(AI.MAX_CONTEXTS_NUM) :]
    contexts: list = []

    for msg in messages:
        info = parse_msg(msg, silent=True)
        role = "assistant" if BOT_TIPS in info["text"] else "user"

        if info["mtype"] not in ["text", "photo", "audio", "voice", "video", "document", "web_page"]:
            continue

        context = {"role": role, "content": []}
        sender = info["fwd_full_name"] or info["full_name"]
        media_path = DOWNLOAD_DIR + "/" + info["file_name"]
        try:
            # NOTE: media must be fetched from `msg` (the chain entry being
            # processed), not from `message` (the message that started the request).
            if info["mtype"] == "photo":
                res = await base64_media(client, msg)
                context["content"].append({"type": "image_url", "image_url": {"url": f"data:image/{res['ext']};base64,{res['base64']}"}})
            elif info["mtype"] == "document":
                guessed_mime, _ = mimetypes.guess_type(info["file_name"])
                if info["mime_type"].startswith("text/") or str(guessed_mime).startswith("text/") or Path(info["file_name"]).suffix in TXT_EXT:
                    fpath: str = await client.download_media(msg, media_path)  # type: ignore
                    context["content"].append(
                        {
                            "type": "text",
                            "text": f"[filename]: {info['file_name']}\n[file content]:\n{read_text(fpath).strip()}",
                        }
                    )
                elif Path(info["file_name"]).suffix in MARKDOWN_EXT:
                    fpath: str = await client.download_media(msg, media_path)  # type: ignore
                    text = convert2md(path=fpath)
                    Path(fpath).unlink(missing_ok=True)
                    context["content"].append(
                        {
                            "type": "text",
                            "text": f"[filename]: {info['file_name']}\n[file content]:\n{text.strip()}",
                        }
                    )
            # user message has entity urls, use full html
            texts = (info["html"] or info["text"]) if role == "user" and info["entity_urls"] else info["text"]
            clean_texts = clean_context(texts)
            if clean_texts:
                if role == "user" and add_sender and sender:
                    texts = f"<quote>{info['quote_text']}</quote>\n{sender} ({info['time']})\n{clean_texts}"
                else:
                    texts = f"<quote>{info['quote_text']}</quote>\n{clean_texts}"
                texts = texts.removeprefix("<quote></quote>\n")  # remove quote mark if no quote_text
                context["content"].append({"type": "text", "text": texts})
        except Exception as e:
            # Best effort: keep whatever content was collected before the failure.
            logger.warning(f"Download media from message failed: {e}")
        contexts.append(context)

    # Extra contexts join the trailing user turn, or form a new one.
    additional_contexts = await parse_openai_chat_additional_contexts(params)
    if contexts and contexts[-1]["role"] == "user":
        contexts[-1]["content"].extend(additional_contexts)
    else:
        contexts.append({"role": "user", "content": additional_contexts})

    return [ctx for ctx in contexts if ctx.get("content")]
127
128
async def parse_openai_chat_additional_contexts(params: dict) -> list[dict]:
    """Parse additional contexts into chat-completion content parts.

    ``params["additional_contexts"]`` is a list[dict] such as:
        {"type": "image", "path": "/path/to/image.jpg", "mime_type": "image/jpeg"}
        {"type": "video", "path": "/path/to/video.mp4", "mime_type": "video/mp4"}
        {"type": "audio", "path": "/path/to/audio.mp3", "mime_type": "audio/mpeg"}
        {"type": "file", "path": "/path/to/file.pdf", "mime_type": "application/pdf"}

    An item is used only when the matching ``allow_<type>`` flag in ``params``
    is truthy and ``path`` is an existing file. Files are converted to markdown
    and sent as text; a file that yields no markdown is skipped.

    Returns:
        The content parts, possibly empty.
    """
    if not params.get("additional_contexts"):
        return []

    # For OpenAI, OpenRouter and Volcengine, please use the Responses API instead.
    # Currently, this is for xiaomi mimo only.
    perms = {
        "image": bool(params.get("allow_image")),
        "video": bool(params.get("allow_video")),
        "audio": bool(params.get("allow_audio")),
        "file": bool(params.get("allow_file")),
    }
    types = {
        "image": "image_url",
        "video": "video_url",
        "audio": "input_audio",
    }

    contexts = []

    for item in params["additional_contexts"]:
        item_type = item.get("type")
        path = Path(item.get("path", ""))
        if not (perms.get(item_type) and path.is_file()):
            continue
        if item_type == "file":
            # Files have no media part type here: send them as markdown text,
            # and skip them when the conversion yields nothing.
            if md := convert2md(path=path):
                contexts.append({"type": "text", "text": md})
            continue
        mime = item.get("mime_type") or guess_mime(path)
        data_uri = f"data:{mime};base64,{await encode_file(path)}"
        part_type = types[item_type]
        # e.g. {"type": "image_url", "image_url": "data:image/jpeg;base64,..."}
        # NOTE(review): the data URI is stored as a plain string, not as {"url": ...} — confirm against the target API.
        contexts.append({"type": part_type, part_type: data_uri})

    return contexts
171
172
async def get_openai_response_contexts(client: Client, message: Message, params: dict) -> tuple[str, list[dict]]:
    """Generate OpenAI response contexts.

    The reply chain is scanned from new to old. Scanning stops at the first
    message sent by ``TID.ME`` that has a cached response id; only the messages
    newer than it are converted into contexts.

    Note:
        ``params["add_sender"]`` is filled in (mutating ``params``) when it is
        missing or ``None``.

    Returns:
        previous_response_id, contexts
    """

    async def lookup_response_id(msg: Message) -> str:
        """Return the response id cached for ``msg``, or "" when there is none."""
        cache_day = params["cache_day"]
        if cache_day == 0:
            return ""
        key_hash = hashlib.sha256(params["api_key"].encode()).hexdigest()
        head = await head_cf_r2(f"TTL/{cache_day}d/OpenAI/{msg.chat.id}/{msg.id}/{key_hash}")
        return glom(head, "Metadata.response_id", default="") or ""

    previous_response_id = ""
    pending: list = []
    for msg in await full_chain_contexts(client, message, order="desc"):  # new to old
        if glom(msg, "from_user.id", default=-1) == TID.ME:
            found = await lookup_response_id(msg)
            if found:
                previous_response_id = found
                break
        pending.append(msg)

    pending.reverse()  # old to new
    if params.get("add_sender") is None:
        params["add_sender"] = is_multi_user_chat(pending)

    contexts = []
    for msg in pending:
        ctx = await single_openai_response_context(client, msg, params)
        if ctx:
            contexts.append(ctx)

    # Extra contexts join the trailing user turn, or form a new one.
    extras = await parse_openai_response_additional_contexts(client, params)
    if contexts and contexts[-1]["role"] == "user":
        contexts[-1]["content"].extend(extras)
    else:
        contexts.append({"role": "user", "type": "message", "content": extras})
    return previous_response_id, contexts
213
214
async def single_openai_response_context(client: Client, message: Message, params: dict) -> dict:
    """Generate the OpenAI response context for a single message.

    Media is attached only when the matching ``allow_*`` flag in ``params`` is
    truthy. With ``openai_media_send_as == "file_id"`` an uploaded file id is
    tried first and base64 is the fallback.

    Returns:
        {
            "role": "user or assistant",
            "type": "message",
            "content": [],
        }
        or an empty dict when the message type is unsupported or no content
        could be collected.
    """
    info = parse_msg(message, silent=True)
    from_bot = BOT_TIPS in info["text"]
    role = "assistant" if from_bot else "user"
    text_type = "output_text" if from_bot else "input_text"

    if info["mtype"] not in ["text", "photo", "audio", "voice", "video", "document", "web_page"]:
        return {}
    context = {"role": role, "type": "message", "content": []}
    if from_bot:
        context["status"] = "completed"
    content: list = context["content"]

    media_send_as = params.get("openai_media_send_as", "base64")
    allow_image = bool(params.get("allow_image"))
    allow_video = bool(params.get("allow_video"))
    allow_audio = bool(params.get("allow_audio"))
    allow_file = bool(params.get("allow_file"))
    sender = info["fwd_full_name"] or info["full_name"]
    file_name = info["file_name"]
    media_path = DOWNLOAD_DIR + "/" + file_name
    mtype = info["mtype"]

    async def uploaded_file_id() -> str:
        """Return an uploaded file id in file_id mode, otherwise a falsy value."""
        if media_send_as != "file_id":
            return ""
        return await get_openai_file_id(client, message, params)

    try:
        if mtype == "photo" and allow_image:
            if file_id := await uploaded_file_id():
                content.append({"type": "input_image", "file_id": file_id})
            else:
                res = await base64_media(client, message)
                content.append({"type": "input_image", "image_url": f"data:image/{res['ext']};base64,{res['base64']}"})
        elif mtype == "video" and allow_video:
            if file_id := await uploaded_file_id():
                content.append({"type": "input_video", "file_id": file_id})
            else:
                res = await base64_media(client, message)
                content.append({"type": "input_video", "video_url": f"data:video/{res['ext']};base64,{res['base64']}"})
        elif mtype == "audio" and allow_audio:
            if file_id := await uploaded_file_id():
                content.append({"type": "input_audio", "file_id": file_id})
            else:
                res = await base64_media(client, message)
                content.append({"type": "input_audio", "audio_url": f"data:audio/{res['ext']};base64,{res['base64']}"})
        elif mtype == "document" and allow_file:
            guessed_mime, _ = mimetypes.guess_type(file_name)
            if "application/pdf" in (info["mime_type"], guessed_mime):
                if file_id := await uploaded_file_id():
                    content.append({"type": "input_file", "file_id": file_id})
                else:
                    res = await base64_media(client, message)
                    content.append({"type": "input_file", "file_data": f"data:application/pdf;base64,{res['base64']}", "filename": file_name})
            elif info["mime_type"].startswith("text/") or str(guessed_mime).startswith("text/") or Path(file_name).suffix in TXT_EXT:
                fpath: str = await client.download_media(message, media_path)  # type: ignore
                body = read_text(fpath).strip()
                content.append({"type": text_type, "text": f"[filename]: {info['file_name']}\n[file content]:\n{body}"})
            elif Path(file_name).suffix in MARKDOWN_EXT:
                fpath: str = await client.download_media(message, media_path)  # type: ignore
                text = convert2md(path=fpath)
                Path(fpath).unlink(missing_ok=True)
                content.append({"type": text_type, "text": f"[filename]: {info['file_name']}\n[file content]:\n{text.strip()}"})

        # user message has entity urls, use full html
        if role == "user" and info["entity_urls"]:
            texts = info["html"] or info["text"]
        else:
            texts = info["text"]
        clean_texts = clean_context(texts)
        if clean_texts:
            quoted = f"<quote>{info['quote_text']}</quote>\n"
            if role == "user" and params.get("add_sender") and sender:
                texts = f"{quoted}{sender} ({info['time']})\n{clean_texts}"
            else:
                texts = f"{quoted}{clean_texts}"
            texts = texts.removeprefix("<quote></quote>\n")  # remove quote mark if no quote_text
            content.append({"type": text_type, "text": texts})
    except Exception as e:
        # Best effort: whatever was collected before the failure is still returned.
        logger.warning(f"Download media from message failed: {e}")
    return context if content else {}
302
303
async def parse_openai_response_additional_contexts(client: Client, params: dict) -> list[dict]:
    """Parse additional contexts into Responses API content parts.

    ``params["additional_contexts"]`` is a list[dict] such as:
        {"type": "image", "path": "/path/to/image.jpg", "mime_type": "image/jpeg"}
        {"type": "video", "path": "/path/to/video.mp4", "mime_type": "video/mp4"}
        {"type": "audio", "path": "/path/to/audio.mp3", "mime_type": "audio/mpeg"}
        {"type": "file", "path": "/path/to/file.pdf", "mime_type": "application/pdf"}

    An item is used only when its ``allow_<type>`` flag is truthy and its path
    is an existing file. ``params["openai_media_send_as"]`` ("base64" by
    default, or "file_id") selects how the file is delivered.
    """
    if not params.get("additional_contexts"):
        return []

    media_send_as = params.get("openai_media_send_as", "base64")
    perms = {
        "image": bool(params.get("allow_image")),
        "video": bool(params.get("allow_video")),
        "audio": bool(params.get("allow_audio")),
        "file": bool(params.get("allow_file")),
    }

    contexts = []
    for item in params["additional_contexts"]:
        item_type = item.get("type")
        path = Path(item.get("path", ""))
        if not perms.get(item_type) or not path.is_file():
            continue

        if media_send_as == "base64":
            mime = item.get("mime_type") or guess_mime(path)
            data_uri = f"data:{mime};base64,{await encode_file(path)}"
            if item_type == "file":
                payload = {"type": f"input_{item_type}", "filename": path.name, "file_data": data_uri}
            else:
                # input_image / input_video / input_audio with image_url / video_url / audio_url
                payload = {"type": f"input_{item_type}", f"{item_type}_url": data_uri}
            contexts.append(payload)
            continue

        if media_send_as == "file_id":
            # A synthetic message derived from the file digest supplies the ids
            # and the size that get_openai_file_id reads from a message.
            hash_id = int(digest(path, length=11, to_int=True))
            document = Document(file_id="", file_unique_id="", file_size=path.stat().st_size)
            message = Message(id=hash_id, chat=Chat(id=hash_id), document=document)
            # handle audio special case
            kwargs = {"force_audio_to_aac": True} if item_type == "audio" else {}
            file_id = await get_openai_file_id(client, message, params, fpath=path, keep_file=True, **kwargs)
            if file_id:
                contexts.append({"type": f"input_{item_type}", "file_id": file_id})

    return contexts
350
351
async def get_openai_file_id(
    client: Client,
    message: Message,
    params: dict,
    fpath: str | Path | None = None,
    *,
    force_audio_to_aac: bool = False,
    keep_file: bool = False,
) -> str:
    """Return an OpenAI file id for the media of ``message``, uploading it when needed.

    A file id cached in R2 (keyed by chat id, message id and a hash of the api
    key) is reused while the remote file is still "active"/"processed".
    Otherwise the media is uploaded with purpose "user_data" and the new id is
    cached.

    Args:
        client: Client used to download the media when ``fpath`` is None.
        message: Message owning the media; its ids form the cache key.
        params: Must contain ``api_key`` and ``base_url``. Optional keys read
            here: ``default_headers``, ``proxy``, ``max_upload_size``,
            ``cache_day`` (default 30).
        fpath: Local file to upload instead of downloading the message media.
        force_audio_to_aac: Always convert the file to aac before uploading.
        keep_file: Keep the local file after a successful upload.

    Returns:
        The file id, or "" when the media is too large or the upload failed.
    """

    def get_real_baseurl() -> str:
        """Resolve the upstream base url when requests go through a known gateway."""
        # `str(None)` would yield the literal "None", so normalize before converting.
        base_url = str(params["base_url"] or "")
        default_headers = {k.lower(): v for k, v in (params.get("default_headers") or {}).items()}
        if base_url.startswith("https://gateway.helicone.ai"):
            helicone_target_url = default_headers.get("helicone-target-url") or ""
            return base_url.replace("https://gateway.helicone.ai", helicone_target_url.rstrip("/"))
        if base_url == "https://api.portkey.ai/v1":
            return default_headers.get("x-portkey-custom-host") or ""
        return base_url

    if params.get("max_upload_size") and message_bytes(message) > int(params["max_upload_size"]):
        logger.warning(f"Message-{message.id} size {message_bytes(message)} bytes exceeds max_upload_size {params['max_upload_size']}")
        return ""
    api_key = params["api_key"]
    openai = AsyncOpenAI(
        base_url=get_real_baseurl(),
        api_key=api_key,
        http_client=DefaultAsyncHttpxClient(proxy=params["proxy"]) if params.get("proxy") else None,
    )
    cache_day = params.get("cache_day", 30)
    key_hash = hashlib.sha256(api_key.encode()).hexdigest()
    r2_key = f"TTL/{cache_day}d/OpenAI/{message.chat.id}/{message.id}/{key_hash}-file_id"
    r2 = await head_cf_r2(r2_key)
    if file_id := glom(r2, "Metadata.file_id", default=""):
        # A cached id may point to a file that no longer exists remotely;
        # treat any lookup failure as a cache miss and upload again.
        try:
            resp = await openai.files.retrieve(file_id=file_id)
            if resp.status in ["active", "processed"]:
                return file_id
        except Exception as e:
            logger.warning(f"Get file id from OpenAI failed: {e}")

    if fpath is None:
        fpath = await client.download_media(message)  # type: ignore
    try:
        mime = guess_mime(fpath)
        # hotfix: convert audio to aac
        if force_audio_to_aac or (message.audio and not str(fpath).endswith(".aac")):
            fpath = await downsampe_audio(fpath, ext="aac", codec="aac")
        resp = await openai.files.create(file=Path(fpath), purpose="user_data")
        # skip waiting for image file
        if not mime.startswith("image/"):
            while resp.status in ["processing", "uploaded"]:
                logger.trace(f"Upload media to OpenAI processing: {resp.model_dump()}")
                await asyncio.sleep(1)
                resp = await openai.files.retrieve(file_id=resp.id)
        if mime.startswith("image/") or resp.status in ["active", "processed"]:
            await set_cf_r2(r2_key, data=resp.model_dump(), metadata={"file_id": resp.id})
            if not keep_file:
                Path(fpath).unlink(missing_ok=True)
            return resp.id
        logger.error(f"Upload media to OpenAI failed: {resp.model_dump()}")
    except Exception as e:
        logger.error(f"Upload media to OpenAI failed: {e}")
    return ""
413
414
async def get_gemini_contexts(
    client: Client,
    message: Message,
    gemini: genai.Client,
    *,
    add_sender: bool | None = None,
    additional_contexts: list[dict] | None = None,
) -> list[dict]:
    """Generate Gemini contexts from old to new.

    Every non-text message is first offered to ``get_gemini_file_id``; when no
    upload is returned, documents are inlined as text (plain text read as-is,
    office/html formats converted to markdown).

    Args:
        client: Client used to download the media of each chain message.
        message: Message whose reply chain is converted.
        gemini: Gemini client forwarded to the upload helpers.
        add_sender: Prepend the sender name to user texts. When ``None`` the
            value is derived from ``is_multi_user_chat``.
        additional_contexts: Extra items parsed by ``parse_gemini_additional_contexts``.

    Returns:
        contexts: list[dict] with non-empty ``parts``
    """
    chains = await full_chain_contexts(client, message, order="asc")  # old to new
    if add_sender is None:
        add_sender = is_multi_user_chat(chains)
    messages = chains[-int(AI.MAX_CONTEXTS_NUM) :]
    contexts: list = []
    for msg in messages:
        info = parse_msg(msg, silent=True)
        role = "model" if BOT_TIPS in info["text"] else "user"
        if info["mtype"] not in ["text", "photo", "audio", "voice", "video", "document", "web_page"]:
            continue
        parts = []
        sender = info["fwd_full_name"] or info["full_name"]
        media_path = DOWNLOAD_DIR + "/" + info["file_name"]
        try:
            if info["mtype"] != "text" and (uploaded := await get_gemini_file_id(client, msg, gemini, info["file_name"], info["mtype"])):
                parts.append(Part.from_uri(file_uri=uploaded["file_id"], mime_type=uploaded["mime_type"]))
            elif info["mtype"] == "document":
                guessed_mime, _ = mimetypes.guess_type(info["file_name"])
                if info["mime_type"].startswith("text/") or str(guessed_mime).startswith("text/") or Path(info["file_name"]).suffix in TXT_EXT:
                    fpath: str = await client.download_media(msg, media_path)  # type: ignore
                    parts.append(Part.from_text(text=f"[filename]: {info['file_name']}\n[file content]:\n{read_text(fpath).strip()}"))
                # `elif`, like the other providers: a file matching both rules
                # (e.g. .html) must be inlined once, not downloaded and added twice.
                elif Path(info["file_name"]).suffix in MARKDOWN_EXT:
                    fpath: str = await client.download_media(msg, media_path)  # type: ignore
                    text = convert2md(path=fpath)
                    Path(fpath).unlink(missing_ok=True)
                    parts.append(Part.from_text(text=f"[filename]: {info['file_name']}\n[file content]:\n{text.strip()}"))
            # user message has entity urls, use full html
            texts = (info["html"] or info["text"]) if role == "user" and info["entity_urls"] else info["text"]
            clean_texts = clean_context(texts)
            if not clean_texts:
                contexts.append({"role": role, "parts": parts})
                continue
            if role == "user" and add_sender and sender:
                texts = f"<quote>{info['quote_text']}</quote>\n{sender} ({info['time']})\n{clean_texts}"
            else:
                texts = f"<quote>{info['quote_text']}</quote>\n{clean_texts}"
            texts = texts.removeprefix("<quote></quote>\n")  # remove quote mark if no quote_text
            parts.append(Part.from_text(text=texts))
        except Exception as e:
            # Best effort: keep the parts collected before the failure.
            logger.warning(f"Download media from message failed: {e}")
        if parts:
            contexts.append({"role": role, "parts": parts})

    # Extra parts join the trailing user turn, or form a new one.
    additional_parts = await parse_gemini_additional_contexts(client, gemini, additional_contexts)
    if contexts and contexts[-1]["role"] == "user":
        contexts[-1]["parts"].extend(additional_parts)
    else:
        contexts.append({"role": "user", "parts": additional_parts})
    return [ctx for ctx in contexts if len(ctx.get("parts"))]
476
477
async def parse_gemini_additional_contexts(client: Client, gemini: genai.Client, contexts: list[dict] | None = None) -> list:
    """Parse additional contexts into Gemini parts.

    ``contexts`` is a list[dict] such as:
        {"type": "image", "path": "/path/to/image.jpg", "mime_type": "image/jpeg"}
        {"type": "video", "path": "/path/to/video.mp4", "mime_type": "video/mp4"}
        {"type": "audio", "path": "/path/to/audio.mp3", "mime_type": "audio/mpeg"}
        {"type": "file", "path": "/path/to/file.pdf", "mime_type": "application/pdf"}
        {"type": "youtube", "url": "https://www.youtube.com/watch?v=videoid"}

    YouTube items are referenced by url; every other item must point to an
    existing local file, which is uploaded through ``get_gemini_file_id``.
    """
    if not contexts:
        return []

    parts = []
    for entry in contexts:
        kind = entry.get("type", "")
        if kind == "youtube":
            parts.append(Part(file_data=FileData(file_uri=entry["url"])))
        else:
            local = Path(entry.get("path", ""))
            if local.is_file():
                # A synthetic message derived from the file digest supplies
                # the chat/message ids used in the upload cache key.
                hash_id = int(digest(local, length=11, to_int=True))
                message = Message(id=hash_id, chat=Chat(id=hash_id))
                uploaded = await get_gemini_file_id(client, message, gemini, local, kind, keep_file=True)
                if uploaded:
                    parts.append(Part.from_uri(file_uri=uploaded["file_id"], mime_type=uploaded["mime_type"]))
    return parts
507
508
async def get_gemini_file_id(
    client: Client,
    message: Message,
    gemini: genai.Client,
    fpath: str | Path,
    mtype: str,
    *,
    keep_file: bool = False,
) -> dict:
    """Get the Gemini file uri for a message's media, uploading it when needed.

    An upload cached in R2 is reused while the remote file is ACTIVE. When
    ``fpath`` is an existing ``Path`` it is uploaded directly; otherwise the
    media is downloaded from ``message`` first.

    Returns:
        {"file_id": str, "mime_type": str}, or an empty dict when the file type
        is not supported or the upload failed.
    """
    fname = Path(fpath).name
    is_media = mtype in ["video", "photo", "image", "audio", "voice"]
    if not is_media and not fname.endswith(tuple(GEMINI_EXT)):
        return {}

    api_key = glom(gemini, "_api_client.api_key", default="")
    key_hash = hashlib.sha256(api_key.encode()).hexdigest()
    cache_hour = AI.GEMINI_FILES_TTL // 3600
    r2_key = f"TTL/{cache_hour}h/Gemini/{message.chat.id}/{message.id}/{key_hash}-file_id"
    cached = await head_cf_r2(r2_key)
    app = genai.Client(api_key=api_key, http_options=HttpOptions(async_client_args={"proxy": PROXY.GOOGLE}))

    cached_name = glom(cached, "Metadata.name", default="")
    if cached_name:
        try:
            upload = await app.aio.files.get(name=cached_name)
            if upload.state == FileState.ACTIVE and upload.uri:
                return {"file_id": upload.uri, "mime_type": upload.mime_type}
        except Exception as e:
            # A failed lookup is treated as a cache miss.
            logger.warning(f"Get file id from Gemini failed: {e}")

    try:
        if isinstance(fpath, Path) and fpath.is_file():
            fpath = fpath.as_posix()
        else:
            fpath = await client.download_media(message)  # type: ignore
        needs_conversion = mtype in ["audio", "voice"] and Path(fpath).suffix not in GEMINI_AUDIO_EXT
        if needs_conversion:
            fpath = (await downsampe_audio(fpath)).as_posix()
        upload = await app.aio.files.upload(file=fpath, config=UploadFileConfig(display_name=fname))
        # Poll once per second until processing is finished.
        while upload.state == FileState.PROCESSING:
            logger.trace("Waiting for upload to complete...")
            await asyncio.sleep(1)
            upload = await app.aio.files.get(name=upload.name)  # type: ignore
        if upload.state == FileState.ACTIVE and upload.uri:
            await set_cf_r2(r2_key, data=json.loads(upload.model_dump_json()), metadata={"name": upload.name})
            if not keep_file:
                Path(fpath).unlink(missing_ok=True)
            return {"file_id": upload.uri, "mime_type": upload.mime_type}
    except Exception as e:
        logger.error(f"Upload media to Gemini failed: {e}")
    return {}
562
563
async def get_anthropic_contexts(
    client: Client,
    message: Message,
    anthropic: AsyncAnthropic,
    cache_hour: int = 0,
    media_send_as: Literal["base64", "file_id"] = "base64",
    *,
    add_sender: bool | None = None,
) -> list[dict]:
    """Generate Anthropic contexts.

    The reply chain of ``message`` is converted from old to new, limited to the
    last ``AI.MAX_CONTEXTS_NUM`` messages. Photos and PDF documents are sent as
    an uploaded file id when ``media_send_as == "file_id"`` and the upload
    succeeds, otherwise as base64. Other documents are inlined as text.

    Returns:
        The contexts whose ``content`` is non-empty.
    """
    chains = await full_chain_contexts(client, message, order="asc")  # old to new
    if add_sender is None:
        add_sender = is_multi_user_chat(chains)

    async def uploaded_file_id(target: Message) -> str:
        """Return an uploaded file id in file_id mode, otherwise a falsy value."""
        if media_send_as != "file_id":
            return ""
        return await get_anthropic_file_id(client, target, anthropic, cache_hour)

    contexts = []
    for msg in chains[-int(AI.MAX_CONTEXTS_NUM) :]:
        info = parse_msg(msg, silent=True)
        role = "assistant" if BOT_TIPS in info["text"] else "user"
        if info["mtype"] not in ["text", "photo", "audio", "voice", "video", "document", "web_page"]:
            continue

        context = {"role": role, "content": []}
        content: list = context["content"]
        sender = info["fwd_full_name"] or info["full_name"]
        file_name = info["file_name"]
        media_path = DOWNLOAD_DIR + "/" + file_name
        try:
            if info["mtype"] == "photo":
                if file_id := await uploaded_file_id(msg):
                    content.append({"type": "image", "source": {"type": "file", "file_id": file_id}})
                else:
                    res = await base64_media(client, msg)
                    content.append({"type": "image", "source": {"type": "base64", "media_type": f"image/{res['ext']}", "data": res["base64"]}})

            elif info["mtype"] == "document":
                guessed_mime, _ = mimetypes.guess_type(file_name)
                if "application/pdf" in (info["mime_type"], guessed_mime):
                    if file_id := await uploaded_file_id(msg):
                        content.append({"type": "document", "source": {"type": "file", "file_id": file_id}})
                    else:
                        res = await base64_media(client, msg)
                        content.append({"type": "document", "source": {"type": "base64", "media_type": "application/pdf", "data": res["base64"]}})

                elif info["mime_type"].startswith("text/") or str(guessed_mime).startswith("text/") or Path(file_name).suffix in TXT_EXT:
                    fpath: str = await client.download_media(msg, media_path)  # type: ignore
                    body = read_text(fpath).strip()
                    content.append({"type": "text", "text": f"[filename]: {info['file_name']}\n[file content]:\n{body}"})

                elif Path(file_name).suffix in MARKDOWN_EXT:
                    fpath: str = await client.download_media(msg, media_path)  # type: ignore
                    text = convert2md(path=fpath)
                    Path(fpath).unlink(missing_ok=True)
                    content.append({"type": "text", "text": f"[filename]: {info['file_name']}\n[file content]:\n{text.strip()}"})

            # user message has entity urls, use full html
            if role == "user" and info["entity_urls"]:
                texts = info["html"] or info["text"]
            else:
                texts = info["text"]
            clean_texts = clean_context(texts)
            if clean_texts:
                quoted = f"<quote>{info['quote_text']}</quote>\n"
                if role == "user" and add_sender and sender:
                    texts = f"{quoted}{sender} ({info['time']})\n{clean_texts}"
                else:
                    texts = f"{quoted}{clean_texts}"
                texts = texts.removeprefix("<quote></quote>\n")  # remove quote mark if no quote_text
                content.append({"type": "text", "text": texts})
        except Exception as e:
            # Best effort: keep the content collected before the failure.
            logger.warning(f"Download media from message failed: {e}")
        contexts.append(context)

    return [ctx for ctx in contexts if ctx.get("content")]
633
634
async def get_anthropic_file_id(client: Client, message: Message, anthropic: AsyncAnthropic, cache_hour: int) -> str:
    """Return an Anthropic file id for the media of ``message``, uploading it when needed.

    The id is cached in R2 under a key built from a hash of the api key, the
    chat id and the message id, so the same media is not uploaded again while
    the cache entry exists.

    Args:
        client: Client used to download the media.
        message: Message owning the media.
        anthropic: Anthropic client used for the upload.
        cache_hour: Cache TTL bucket in hours; 0 falls back to 12.

    Returns:
        The file id, or "" when the upload failed.
    """
    api_key: str = anthropic.api_key  # ty:ignore[invalid-assignment]
    key_hash = hashlib.sha256(api_key.encode()).hexdigest()
    cache_hour = cache_hour or 12
    r2_key = f"TTL/{cache_hour}h/Anthropic/{key_hash}/{message.chat.id}/{message.id}-file_id"
    r2 = await head_cf_r2(r2_key)
    if file_id := glom(r2, "Metadata.file_id", default=""):
        return file_id
    fpath: str = await client.download_media(message)  # type: ignore
    try:
        resp = await anthropic.beta.files.upload(file=Path(fpath))
        if file_id := glom(resp, "id", default=""):
            # Store the id so the cache lookup above can hit next time.
            # model_dump_json() keeps non-JSON-native fields (e.g. timestamps) serializable.
            await set_cf_r2(r2_key, data=json.loads(resp.model_dump_json()), metadata={"file_id": file_id})
            return file_id
        logger.error(f"Upload media to Anthropic failed: {resp.model_dump()}")
    except Exception as e:
        logger.error(f"Upload media to Anthropic failed: {e}")
    finally:
        # The file was downloaded only for this upload; do not leave it on disk.
        if fpath:
            Path(fpath).unlink(missing_ok=True)
    return ""
652
653
async def full_chain_contexts(client: Client, message: Message, order: Literal["asc", "desc"] = "asc") -> list[Message]:
    """Collect every message in the reply chain of ``message``.

    The chain follows ``reply_to_message`` links. A chain entry that belongs to
    a media group is replaced by the whole group. The result is sorted by
    message id: ascending (oldest first) by default, descending for "desc".
    """
    chain = [message]
    cursor = message
    while parent := cursor.reply_to_message:
        chain.append(parent)
        cursor = parent

    collected: list = []
    for entry in chain:
        if entry.media_group_id:
            collected.extend(await client.get_media_group(entry.chat.id, entry.id))
        else:
            collected.append(entry)

    result: list[Message] = [item for item in collected if isinstance(item, Message)]
    result.sort(key=lambda item: item.id, reverse=(order == "desc"))
    return result
669
670
async def context_types(client: Client, message: Message, additional_contexts: list[dict]) -> dict:
    """Report which kinds of content appear in a message's reply chain.

    Returns:
        A dict of booleans keyed by "text", "video", "audio", "image", "file"
        and "youtube". Documents are classified by mime type and file
        extension; "youtube" can only come from ``additional_contexts``.
    """
    found = {"text": False, "video": False, "audio": False, "image": False, "file": False, "youtube": False}

    for msg in await full_chain_contexts(client, message):
        if msg.audio:
            found["audio"] = True
        if msg.photo:
            found["image"] = True
        if msg.video:
            found["video"] = True
        if msg.document:
            mime = glom(msg, "document.mime_type", default="") or ""
            fname = glom(msg, "document.file_name", default="") or ""
            suffix = Path(fname).suffix
            if mime.startswith("image/"):
                kind = "image"
            elif mime.startswith("audio/") or suffix in AUDIO_FORMAT:
                kind = "audio"
            elif mime.startswith("video/") or suffix in VIDEO_FORMAT:
                kind = "video"
            elif mime.startswith("text/") or suffix in TXT_EXT or suffix in MARKDOWN_EXT:
                kind = "text"
            else:
                kind = "file"
            found[kind] = True
        if msg.text or msg.caption:
            found["text"] = True

    # additional_contexts are parsed from function `parse_summary_sources` in `src/summarize/utils.py`
    for ctx in additional_contexts:
        ctx_type = ctx["type"]
        if ctx_type in ("image", "video", "audio", "file", "youtube"):
            found[ctx_type] = True

    return found
715
716
def message_bytes(message: Message) -> int:
    """Count bytes of a message.

    Note:
        This function only counts bytes of media files (photo, video, audio,
        voice, document), not text messages.

    Returns:
        The size of the first media attribute that exposes ``file_size``,
        or 0 when there is none or the size is unknown.
    """
    size = glom(
        message,
        Coalesce(
            "photo.sizes.-1.file_size",
            "photo.file_size",  # fallback when the photo object has no `sizes` list
            "video.file_size",
            "audio.file_size",
            "voice.file_size",
            "document.file_size",
        ),
        default=0,
    )
    # An unset (None) size must not leak out of an `int` function.
    return size or 0
724
725
def is_multi_user_chat(messages: list[Message]) -> bool:
    """Tell whether more than one user, other than ``TID.ME``, sent these messages.

    Messages without a sender id are counted as id 0 and ignored.
    """
    sender_ids = set()
    for msg in messages:
        sender_ids.add(glom(msg, "from_user.id", default=0))
    others = sender_ids - {TID.ME, 0}
    return len(others) > 1
732
733
async def encode_file(path: str | Path) -> str:
    """Read a file asynchronously and return its content as base64 text."""
    encoded: list[str] = []
    async with await anyio.Path(path).open("rb") as fh:
        while True:
            # 65535 is a multiple of 3, so a full chunk encodes without "="
            # padding and the pieces can simply be concatenated.
            # (assumes read() returns full chunks except at end of file)
            block = await fh.read(65535)
            if not block:
                break
            encoded.append(base64.b64encode(block).decode("utf-8"))
    return "".join(encoded)