main
  1#!/usr/bin/env python
  2# -*- coding: utf-8 -*-
  3import asyncio
  4import base64
  5import contextlib
  6import hashlib
  7import time
  8from pathlib import Path
  9from typing import TYPE_CHECKING, Literal
 10
 11from anthropic import AsyncAnthropic
 12from glom import glom
 13from google import genai
 14from google.genai.types import FileState, Part, UploadFileConfig
 15from loguru import logger
 16from openai import AsyncOpenAI, DefaultAsyncHttpxClient
 17from pyrogram.client import Client
 18from pyrogram.types import Message
 19
 20from ai.utils import BOT_TIPS, clean_context
 21from asr.utils import GEMINI_AUDIO_EXT, downsampe_audio
 22from config import AI, DOWNLOAD_DIR
 23from database.r2 import head_cf_r2, set_cf_r2
 24from messages.parser import get_thread_id, parse_msg
 25from utils import convert_md, read_text
 26
 27if TYPE_CHECKING:
 28    from io import BytesIO
 29
 30
 31async def base64_media(client: Client, message: Message) -> dict:
 32    data: BytesIO = await client.download_media(message, in_memory=True)  # type: ignore
 33    logger.debug(f"Downloaded message media: {data.name}")
 34
 35    ext = Path(data.name).suffix.removeprefix(".").replace("jpg", "jpeg")
 36
 37    # image, video
 38    b64_encoding = base64.b64encode(data.getvalue()).decode("utf-8")
 39
 40    # text document
 41    value = ""
 42    with contextlib.suppress(Exception):
 43        value = data.getvalue().decode("utf-8")
 44    return {
 45        "ext": ext,
 46        "base64": b64_encoding,
 47        "value": value,
 48    }
 49
 50
 51async def get_openai_completion_contexts(client: Client, message: Message) -> list[dict]:
 52    """Generate OpenAI chat completion contexts."""
 53    messages = [message]
 54    while message.reply_to_message:
 55        message = message.reply_to_message
 56        messages.append(message)
 57    messages = messages[: int(AI.MAX_CONTEXTS_NUM)][::-1]  # old to new
 58    return [ctx for msg in messages if (ctx := await single_openai_chat_context(client, msg))]
 59
 60
 61async def single_openai_chat_context(client: Client, message: Message) -> dict:
 62    """Generate OpenAI chat completion contexts for a single message.
 63
 64    Returns:
 65    {
 66        "role": "user or assistant",
 67        "content": [],
 68    }
 69    """
 70    info = parse_msg(message, silent=True)
 71    role = "assistant" if BOT_TIPS in info["text"] else "user"
 72
 73    if info["mtype"] not in ["text", "photo", "audio", "voice", "video", "document", "web_page"]:
 74        return {}
 75
 76    extra_txt_extensions = [".sh", ".json", ".xml"]  # treat these as txt file
 77    extra_markdown_extensions = [".pdf", ".html", ".docx", ".pptx", ".xls", ".xlsx"]  # convert to markdown
 78
 79    messages = await client.get_media_group(message.chat.id, message.id) if message.media_group_id else [message]
 80    contexts = []
 81    for msg in messages:
 82        info = parse_msg(msg, silent=True)
 83        sender = info["fwd_full_name"] or info["full_name"]
 84        media_path = DOWNLOAD_DIR + "/" + info["file_name"]
 85        try:
 86            if info["mtype"] == "photo":
 87                res = await base64_media(client, msg)
 88                contexts.append({"type": "image_url", "image_url": {"url": f"data:image/{res['ext']};base64,{res['base64']}"}})
 89            elif info["mtype"] == "document":
 90                if info["mime_type"].startswith("text/") or Path(info["file_name"]).suffix in extra_txt_extensions:
 91                    fpath: str = await client.download_media(msg, media_path)  # type: ignore
 92                    contexts.append(
 93                        {
 94                            "type": "text",
 95                            "text": f"[filename]: {info['file_name']}\n[file content]:\n{read_text(fpath).strip()}",
 96                        }
 97                    )
 98                elif Path(info["file_name"]).suffix in extra_markdown_extensions:
 99                    fpath: str = await client.download_media(msg, media_path)  # type: ignore
100                    text = convert_md(fpath)
101                    Path(fpath).unlink(missing_ok=True)
102                    contexts.append(
103                        {
104                            "type": "text",
105                            "text": f"[filename]: {info['file_name']}\n[file content]:\n{text.strip()}",
106                        }
107                    )
108            # user message has entity urls, use full html
109            clean_texts = clean_context(info["html"] or info["text"]) if role == "user" and info["entity_urls"] else clean_context(info["text"])
110            if not clean_texts:
111                continue
112            if role == "user" and sender:  # noqa: SIM108
113                texts = f"<quote>{info['quote_text']}</quote>\n[username]: {sender}\n[message]:\n{clean_texts}"
114            else:
115                texts = f"<quote>{info['quote_text']}</quote>\n{clean_texts}"
116            texts = texts.removeprefix("<quote></quote>\n")  # remove quote mark if no quote_text
117            contexts.append({"type": "text", "text": texts})
118        except Exception as e:
119            logger.warning(f"Download media from message failed: {e}")
120            continue
121    return {"role": role, "content": contexts} if contexts else {}
122
123
async def get_openai_response_contexts(client: Client, message: Message, openai_params: dict) -> tuple[str, list[dict]]:
    """Generate OpenAI response contexts.

    Walks the reply chain backwards until a cached previous response id is
    found; only the messages newer than that cached point are converted into
    contexts. When no cached id exists, the chain is capped at
    ``AI.MAX_CONTEXTS_NUM`` like the other providers' builders.

    Returns:
        previous_response_id, contexts
    """

    async def get_previous_response_id(msg: Message) -> str:
        """Get previous response id from message.

        Returns:
            previous_response_id: str
        """
        api_key = openai_params["api_key"]
        model_id = openai_params["model_id"]
        cache_day = openai_params["cache_day"]
        key_hash = hashlib.sha256(api_key.encode()).hexdigest()
        tid = get_thread_id(msg)
        resp = await head_cf_r2(f"TTL/{cache_day}d/OpenAI/{model_id}/{key_hash}/{msg.chat.id}/{msg.id}{'/' + str(tid) if tid else ''}")
        return glom(resp, "Metadata.response_id", default="") or ""

    previous_response_id = ""
    messages = [message]
    while message.reply_to_message:
        message = message.reply_to_message
        if previous_response_id := await get_previous_response_id(message):
            break  # server-side state covers everything older than this point
        messages.append(message)
    if not previous_response_id:
        # No cached state: bound the context like the sibling builders
        # (messages is newest-first here, so this keeps the newest N).
        messages = messages[: int(AI.MAX_CONTEXTS_NUM)]
    messages.reverse()  # old to new
    return previous_response_id, [ctx for msg in messages if (ctx := await single_openai_response_context(client, msg, openai_params))]
155
156
async def single_openai_response_context(client: Client, message: Message, openai_params: dict) -> dict:
    """Generate OpenAI response contexts for a single message.

    Media is referenced by uploaded file id when ``openai_media_send_as`` is
    "file_id" and the upload succeeds; otherwise it falls back to inline
    base64 data. Handles a whole media group when the message belongs to one.

    Returns:
    {
        "role": "user or assistant",
        "content": [],
    }
    """
    info = parse_msg(message, silent=True)
    # Messages carrying the bot's tip marker were produced by the bot itself.
    role = "assistant" if BOT_TIPS in info["text"] else "user"

    if info["mtype"] not in ["text", "photo", "audio", "voice", "video", "document", "web_page"]:
        return {}

    extra_txt_extensions = [".sh", ".json", ".xml"]  # treat these as txt file
    extra_markdown_extensions = [".html", ".docx", ".pptx", ".xls", ".xlsx"]  # convert to markdown

    messages = await client.get_media_group(message.chat.id, message.id) if message.media_group_id else [message]
    contexts = []
    for msg in messages:
        info = parse_msg(msg, silent=True)
        sender = info["fwd_full_name"] or info["full_name"]
        media_path = DOWNLOAD_DIR + "/" + info["file_name"]
        media_send_as = openai_params.get("openai_media_send_as", "file_id")
        file_id = ""
        try:
            if info["mtype"] == "photo":
                if media_send_as == "file_id" and (file_id := await get_openai_file_id(client, msg, openai_params, info["mtype"])):
                    contexts.append({"type": "input_image", "file_id": file_id, "detail": "high"})
                if not file_id:
                    res = await base64_media(client, msg)
                    contexts.append({"type": "input_image", "image_url": f"data:image/{res['ext']};base64,{res['base64']}", "detail": "high"})

            elif info["mtype"] == "video":
                if media_send_as == "file_id" and (file_id := await get_openai_file_id(client, msg, openai_params, info["mtype"])):
                    contexts.append({"type": "input_video", "file_id": file_id})
                if not file_id:
                    res = await base64_media(client, msg)
                    # NOTE(review): "image_url" carrying a data:video/... URL looks
                    # provider-specific — confirm against the target endpoint.
                    contexts.append({"type": "input_video", "image_url": f"data:video/{res['ext']};base64,{res['base64']}"})

            elif info["mtype"] == "document":
                if info["mime_type"] == "application/pdf":
                    if media_send_as == "file_id" and (file_id := await get_openai_file_id(client, msg, openai_params, info["mtype"])):
                        contexts.append({"type": "input_file", "file_id": file_id})
                    if not file_id:
                        res = await base64_media(client, msg)
                        contexts.append({"type": "input_file", "file_data": f"data:application/pdf;base64,{res['base64']}", "filename": info["file_name"]})

                elif info["mime_type"].startswith("text/") or Path(info["file_name"]).suffix in extra_txt_extensions:
                    fpath: str = await client.download_media(msg, media_path)  # type: ignore
                    text = read_text(fpath)
                    # Remove the download once read, matching the markdown branch
                    # (previously these files accumulated in DOWNLOAD_DIR).
                    Path(fpath).unlink(missing_ok=True)
                    contexts.append(
                        {
                            "type": "input_text",
                            "text": f"[filename]: {info['file_name']}\n[file content]:\n{text.strip()}",
                        }
                    )
                elif Path(info["file_name"]).suffix in extra_markdown_extensions:
                    fpath: str = await client.download_media(msg, media_path)  # type: ignore
                    text = convert_md(fpath)
                    Path(fpath).unlink(missing_ok=True)
                    contexts.append(
                        {
                            "type": "input_text",
                            "text": f"[filename]: {info['file_name']}\n[file content]:\n{text.strip()}",
                        }
                    )
            # user message has entity urls, use full html
            clean_texts = clean_context(info["html"] or info["text"]) if role == "user" and info["entity_urls"] else clean_context(info["text"])
            if not clean_texts:
                continue
            if role == "user" and sender:  # noqa: SIM108
                texts = f"<quote>{info['quote_text']}</quote>\n[username]: {sender}\n[message]:\n{clean_texts}"
            else:
                texts = f"<quote>{info['quote_text']}</quote>\n{clean_texts}"
            texts = texts.removeprefix("<quote></quote>\n")  # remove quote mark if no quote_text
            contexts.append({"type": "input_text", "text": texts})
        except Exception as e:
            logger.warning(f"Download media from message failed: {e}")
            continue
    return {"role": role, "content": contexts} if contexts else {}
238
239
async def get_openai_file_id(client: Client, message: Message, openai_params: dict, mtype: str) -> str:
    """Upload the message's media to the OpenAI Files API and return its file id.

    The id is cached in R2 under a TTL key so repeated context builds reuse the
    same upload. Returns "" when the media type is disallowed or upload fails.
    """

    def get_real_baseurl() -> str:
        """Resolve the actual upstream base URL, bypassing the Helicone gateway."""
        # str(x or "") keeps a None/empty base_url as "" — plain str(None)
        # would yield the truthy string "None" and defeat the fallback.
        base_url = str(openai_params["base_url"] or "")
        default_headers = openai_params.get("default_headers", {})
        default_headers = {k.lower(): v for k, v in default_headers.items()}
        if base_url.startswith("https://gateway.helicone.ai"):
            helicone_target_url = default_headers.get("helicone-target-url") or ""
            return base_url.replace("https://gateway.helicone.ai", helicone_target_url.rstrip("/"))
        return base_url

    if mtype not in ["photo", "video", "document"]:
        return ""
    if not openai_params["allow_image"] and mtype == "photo":
        return ""
    if not openai_params["allow_video"] and mtype == "video":
        return ""
    if not openai_params["allow_file"] and mtype == "document":
        return ""

    cache_day = openai_params.get("cache_day", 30)
    api_key = openai_params["api_key"]
    model_id = openai_params["model_id"]
    key_hash = hashlib.sha256(api_key.encode()).hexdigest()
    tid = get_thread_id(message)
    r2_key = f"TTL/{cache_day}d/OpenAI/{model_id}/{key_hash}/{message.chat.id}/{message.id}{'/' + str(tid) if tid else ''}-file_id"
    r2 = await head_cf_r2(r2_key)
    if file_id := glom(r2, "Metadata.file_id", default=""):
        return file_id

    openai = AsyncOpenAI(
        base_url=get_real_baseurl(),
        api_key=api_key,
        http_client=DefaultAsyncHttpxClient(proxy=openai_params["proxy"]) if openai_params.get("proxy") else None,
    )
    fpath: str = await client.download_media(message)  # type: ignore
    # Let the provider expire the upload together with our R2 cache entry.
    extra_body = {"expire_at": int(time.time()) + 3600 * 24 * cache_day}

    preprocess_configs = {}
    if message.video:
        duration = glom(message, "video.duration", default=1e8)
        # Sample roughly one extra 0.5 fps step per 5 minutes, clamped to [0.5, 5.0].
        fps = min(5.0, max(0.5, int(duration // 300) * 0.5))
        preprocess_configs = {"video": {"fps": fps, "model": openai_params["model_id"]}}
    if preprocess_configs:
        extra_body["preprocess_configs"] = preprocess_configs
    try:
        resp = await openai.files.create(file=Path(fpath), purpose="user_data", extra_body=extra_body)
        while resp.status == "processing":
            logger.trace(f"Upload media to OpenAI processing: {resp.model_dump()}")
            await asyncio.sleep(3)
            resp = await openai.files.retrieve(file_id=resp.id)
        if resp.status == "active":
            await set_cf_r2(r2_key, data=resp.model_dump(), metadata={"file_id": resp.id})
            return resp.id
        logger.error(f"Upload media to OpenAI failed: {resp.model_dump()}")
    except Exception as e:
        logger.error(f"Upload media to OpenAI failed: {e}")
    finally:
        # Previously the download was only removed on success, leaking files
        # on every failed or errored upload.
        Path(fpath).unlink(missing_ok=True)
    return ""
303
304
async def get_gemini_contexts(client: Client, message: Message, gemini: genai.Client) -> list[dict]:
    """Generate Gemini contexts from old to new.

    Natively-supported media is uploaded via the Gemini Files API; text-like
    and convertible documents are inlined as text parts.

    Returns:
        contexts: list[dict]
    """
    # Constant tables hoisted out of the message loop (they were rebuilt per message).
    # gemini has built-in support for these extensions
    gemini_extensions = [".pdf", ".html", ".css", ".csv", ".xml", ".rtf", ".mp3", ".wav", ".ogg", ".aac", ".flac", ".jpg", ".jpeg", ".webp", ".png", ".heic", ".heif"]
    # gemini has built-in support for these mime types
    gemini_mime_types = ["application/pdf", "application/x-javascript", "audio/ogg", "audio/mp4", "image/jpeg", "image/png", "image/webp", "image/heic", "image/heif"]
    txt_extensions = [".txt", ".js", ".py", ".md", ".sh", ".json"]  # treat these as txt file
    extra_markdown_extensions = [".docx", ".pptx", ".xls", ".xlsx", ".epub"]  # convert to markdown

    ctx_messages = [message]
    while message.reply_to_message:
        message = message.reply_to_message
        ctx_messages.append(message)
    ctx_messages = ctx_messages[: int(AI.MAX_CONTEXTS_NUM)][::-1]  # old to new
    contexts = []
    for m in ctx_messages:
        info = parse_msg(m, silent=True)
        role = "model" if BOT_TIPS in info["text"] else "user"
        if info["mtype"] not in ["text", "photo", "audio", "voice", "video", "document", "web_page"]:
            continue
        group_messages = await client.get_media_group(m.chat.id, m.id) if m.media_group_id else [m]
        parts = []
        for msg in group_messages:
            info = parse_msg(msg, silent=True)
            sender = info["fwd_full_name"] or info["full_name"]
            media_path = DOWNLOAD_DIR + "/" + info["file_name"]
            try:
                if info["mtype"] in ["video", "photo", "audio", "voice"] or info["mime_type"] in gemini_mime_types or any(info["file_name"].endswith(ext) for ext in gemini_extensions):
                    fpath: str = await client.download_media(msg, media_path)  # type: ignore
                    if info["mtype"] in ["audio", "voice"] and Path(fpath).suffix not in GEMINI_AUDIO_EXT:
                        # Re-encode unsupported audio formats before uploading.
                        audio_path = await downsampe_audio(fpath)
                        fpath = audio_path.as_posix()
                    upload = await gemini.aio.files.upload(file=fpath, config=UploadFileConfig(display_name=info["file_name"] or f"send from {sender}"))
                    while upload.state == FileState.PROCESSING:
                        logger.trace("Waiting for upload to complete...")
                        await asyncio.sleep(1)
                        upload = await gemini.aio.files.get(name=upload.name)  # type: ignore
                    if upload.state == FileState.ACTIVE and upload.uri:
                        parts.append(Part.from_uri(file_uri=upload.uri, mime_type=upload.mime_type))
                    Path(fpath).unlink(missing_ok=True)
                elif info["mtype"] == "document":
                    if info["mime_type"].startswith("text/") or Path(info["file_name"]).suffix in txt_extensions:
                        fpath: str = await client.download_media(msg, media_path)  # type: ignore
                        text = read_text(fpath)
                        # Remove the download once read (previously leaked in DOWNLOAD_DIR).
                        Path(fpath).unlink(missing_ok=True)
                        parts.append(Part.from_text(text=f"[filename]: {info['file_name']}\n[file content]:\n{text.strip()}"))
                    # elif matches the other context builders and avoids appending the
                    # same document twice when both branches would match.
                    elif Path(info["file_name"]).suffix in extra_markdown_extensions:
                        fpath: str = await client.download_media(msg, media_path)  # type: ignore
                        text = convert_md(fpath)
                        Path(fpath).unlink(missing_ok=True)
                        parts.append(Part.from_text(text=f"[filename]: {info['file_name']}\n[file content]:\n{text.strip()}"))
                clean_texts = clean_context(info["text"])
                if not clean_texts:
                    continue
                if role == "user" and sender:  # noqa: SIM108
                    texts = f"<quote>{info['quote_text']}</quote>\n[username]: {sender}\n[message]:\n{clean_texts}"
                else:
                    texts = f"<quote>{info['quote_text']}</quote>\n{clean_texts}"
                texts = texts.removeprefix("<quote></quote>\n")  # remove quote mark if no quote_text
                parts.append(Part.from_text(text=texts))
            except Exception as e:
                logger.warning(f"Download media from message failed: {e}")
                continue
        if parts:
            contexts.append({"role": role, "parts": parts})
    return contexts
372
373
async def get_anthropic_contexts(client: Client, message: Message, **kwargs) -> list[dict]:
    """Generate Anthropic contexts."""
    # Collect the reply chain newest-first.
    chain = [message]
    current = message
    while current.reply_to_message:
        current = current.reply_to_message
        chain.append(current)
    # Keep at most MAX_CONTEXTS_NUM of the newest messages, oldest first.
    chain = chain[: int(AI.MAX_CONTEXTS_NUM)]
    chain.reverse()
    contexts = []
    for msg in chain:
        if ctx := await single_anthropic_context(client, msg, **kwargs):
            contexts.append(ctx)
    return contexts
382
383
async def single_anthropic_context(
    client: Client,
    message: Message,
    anthropic: AsyncAnthropic,
    cache_hour: int = 0,
    media_send_as: Literal["base64", "file_id"] = "file_id",
) -> dict:
    """Generate Anthropic contexts for a single message.

    Photos and PDFs are referenced by uploaded file id when ``media_send_as``
    is "file_id" (and the upload succeeds); otherwise inline base64 is used.
    Handles a whole media group when the message belongs to one.

    Returns:
    {
        "role": "user or assistant",
        "content": [],
    }
    """
    info = parse_msg(message, silent=True)
    # Messages carrying the bot's tip marker were produced by the bot itself.
    role = "assistant" if BOT_TIPS in info["text"] else "user"

    if info["mtype"] not in ["text", "photo", "audio", "voice", "video", "document", "web_page"]:
        return {}

    extra_txt_extensions = [".sh", ".json", ".xml"]  # treat these as txt file
    extra_markdown_extensions = [".html", ".docx", ".pptx", ".xls", ".xlsx"]  # convert to markdown

    messages = await client.get_media_group(message.chat.id, message.id) if message.media_group_id else [message]
    contexts = []
    for msg in messages:
        info = parse_msg(msg, silent=True)
        sender = info["fwd_full_name"] or info["full_name"]
        media_path = DOWNLOAD_DIR + "/" + info["file_name"]
        file_id = ""
        try:
            if info["mtype"] == "photo":
                if media_send_as == "file_id" and (file_id := await get_anthropic_file_id(client, msg, anthropic, cache_hour)):
                    contexts.append({"type": "image", "source": {"type": "file", "file_id": file_id}})
                if not file_id:
                    res = await base64_media(client, msg)
                    contexts.append({"type": "image", "source": {"type": "base64", "media_type": f"image/{res['ext']}", "data": res["base64"]}})

            elif info["mtype"] == "document":
                if info["mime_type"] == "application/pdf":
                    if media_send_as == "file_id" and (file_id := await get_anthropic_file_id(client, msg, anthropic, cache_hour)):
                        contexts.append({"type": "document", "source": {"type": "file", "file_id": file_id}})
                    if not file_id:
                        res = await base64_media(client, msg)
                        contexts.append({"type": "document", "source": {"type": "base64", "media_type": "application/pdf", "data": res["base64"]}})

                elif info["mime_type"].startswith("text/") or Path(info["file_name"]).suffix in extra_txt_extensions:
                    fpath: str = await client.download_media(msg, media_path)  # type: ignore
                    text = read_text(fpath)
                    # Remove the download once read, matching the markdown branch
                    # (previously these files accumulated in DOWNLOAD_DIR).
                    Path(fpath).unlink(missing_ok=True)
                    contexts.append({"type": "text", "text": f"[filename]: {info['file_name']}\n[file content]:\n{text.strip()}"})

                elif Path(info["file_name"]).suffix in extra_markdown_extensions:
                    fpath: str = await client.download_media(msg, media_path)  # type: ignore
                    text = convert_md(fpath)
                    Path(fpath).unlink(missing_ok=True)
                    contexts.append({"type": "text", "text": f"[filename]: {info['file_name']}\n[file content]:\n{text.strip()}"})
            # user message has entity urls, use full html
            clean_texts = clean_context(info["html"] or info["text"]) if role == "user" and info["entity_urls"] else clean_context(info["text"])
            if not clean_texts:
                continue
            if role == "user" and sender:  # noqa: SIM108
                texts = f"<quote>{info['quote_text']}</quote>\n[username]: {sender}\n[message]:\n{clean_texts}"
            else:
                texts = f"<quote>{info['quote_text']}</quote>\n{clean_texts}"
            texts = texts.removeprefix("<quote></quote>\n")  # remove quote mark if no quote_text
            contexts.append({"type": "text", "text": texts})
        except Exception as e:
            logger.warning(f"Download media from message failed: {e}")
            continue
    return {"role": role, "content": contexts} if contexts else {}
454
455
async def get_anthropic_file_id(client: Client, message: Message, anthropic: AsyncAnthropic, cache_hour: int) -> str:
    """Upload the message's media to the Anthropic Files API and return its file id.

    The id is cached in R2 under a TTL key so repeated context builds reuse
    the same upload. Returns "" on failure.
    """
    api_key: str = anthropic.api_key  # ty:ignore[invalid-assignment]
    key_hash = hashlib.sha256(api_key.encode()).hexdigest()
    tid = get_thread_id(message)
    cache_hour = cache_hour or 12
    r2_key = f"TTL/{cache_hour}h/Anthropic/{key_hash}/{message.chat.id}/{message.id}{'/' + str(tid) if tid else ''}-file_id"
    r2 = await head_cf_r2(r2_key)
    if file_id := glom(r2, "Metadata.file_id", default=""):
        return file_id
    fpath: str = await client.download_media(message)  # type: ignore
    try:
        resp = await anthropic.beta.files.upload(file=Path(fpath))
        if glom(resp, "id", default=""):
            # Populate the cache, mirroring get_openai_file_id; previously the
            # key was checked but never written, so every call re-uploaded.
            await set_cf_r2(r2_key, data=resp.model_dump(), metadata={"file_id": resp.id})
            return resp.id
        logger.error(f"Upload media to Anthropic failed: {resp.model_dump()}")
    except Exception as e:
        logger.error(f"Upload media to Anthropic failed: {e}")
    finally:
        # The download was never removed before, leaking a file per upload.
        Path(fpath).unlink(missing_ok=True)
    return ""