Commit 34ccf47
Changed files (3)
src
src/ai/texts/claude.py
@@ -34,7 +34,7 @@ async def anthropic_responses(
anthropic_responses_config: str | dict = "",
anthropic_proxy: str | None = PROXY.ANTHROPIC,
cache_response_ttl: int = 0,
- anthropic_media_send_as: Literal["base64", "file_id"] = "file_id",
+ anthropic_media_send_as: Literal["base64", "file_id"] = "base64",
anthropic_append_citation: bool = True,
skills: str = "",
hide_thinking: bool = False,
src/ai/texts/contexts.py
@@ -4,6 +4,7 @@ import asyncio
import base64
import contextlib
import hashlib
+import json
import mimetypes
from pathlib import Path
from typing import TYPE_CHECKING, Literal
@@ -11,7 +12,7 @@ from typing import TYPE_CHECKING, Literal
from anthropic import AsyncAnthropic
from glom import Coalesce, glom
from google import genai
-from google.genai.types import FileState, Part, UploadFileConfig
+from google.genai.types import FileState, HttpOptions, Part, UploadFileConfig
from loguru import logger
from openai import AsyncOpenAI, DefaultAsyncHttpxClient
from pyrogram.client import Client
@@ -19,7 +20,7 @@ from pyrogram.types import Message
from ai.utils import BOT_TIPS, clean_context
from asr.utils import GEMINI_AUDIO_EXT, downsampe_audio
-from config import AI, DOWNLOAD_DIR, TID
+from config import AI, DOWNLOAD_DIR, PROXY, TID
from database.r2 import head_cf_r2, set_cf_r2
from messages.parser import parse_msg
from utils import convert_md, read_text
@@ -27,6 +28,13 @@ from utils import convert_md, read_text
if TYPE_CHECKING:
from io import BytesIO
+TXT_EXT = [".sh", ".json", ".xml", ".tex"] # treat these as txt file
+MARKDOWN_EXT = [".pdf", ".html", ".docx", ".pptx", ".xls", ".xlsx"] # convert to markdown
+# gemini has built-in support for these extensions
+GEMINI_EXT = [".pdf", ".html", ".css", ".csv", ".xml", ".rtf", ".mp3", ".wav", ".ogg", ".aac", ".flac", ".jpg", ".jpeg", ".webp", ".png", ".heic", ".heif"]
+# gemini has built-in support for these mime types
+GEMINI_MIME = ["application/pdf", "application/x-javascript", "audio/ogg", "audio/mp4", "image/jpeg", "image/png", "image/webp", "image/heic", "image/heif"]
+
async def base64_media(client: Client, message: Message) -> dict:
data: BytesIO = await client.download_media(message, in_memory=True) # type: ignore
@@ -50,59 +58,41 @@ async def base64_media(client: Client, message: Message) -> dict:
async def get_openai_completion_contexts(client: Client, message: Message, *, add_sender: bool | None = None) -> list[dict]:
"""Generate OpenAI chat completion contexts."""
- messages = [message]
- while message.reply_to_message:
- message = message.reply_to_message
- messages.append(message)
+ chains = await full_chain_contexts(client, message, order="asc") # old to new
if add_sender is None:
- add_sender = is_multi_user_chat(messages)
- messages = messages[: int(AI.MAX_CONTEXTS_NUM)][::-1] # old to new
- return [ctx for msg in messages if (ctx := await single_openai_chat_context(client, msg, add_sender=add_sender))]
-
-
-async def single_openai_chat_context(client: Client, message: Message, *, add_sender: bool) -> dict:
- """Generate OpenAI chat completion contexts for a single message.
-
- Returns:
- {
- "role": "user or assistant",
- "content": [],
- }
- """
- info = parse_msg(message, silent=True)
- role = "assistant" if BOT_TIPS in info["text"] else "user"
-
- if info["mtype"] not in ["text", "photo", "audio", "voice", "video", "document", "web_page"]:
- return {}
-
- extra_txt_extensions = [".sh", ".json", ".xml", ".tex"] # treat these as txt file
- extra_markdown_extensions = [".pdf", ".html", ".docx", ".pptx", ".xls", ".xlsx"] # convert to markdown
-
- messages = await client.get_media_group(message.chat.id, message.id) if message.media_group_id else [message]
+ add_sender = is_multi_user_chat(chains)
+ messages = chains[-int(AI.MAX_CONTEXTS_NUM) :]
contexts = []
+
for msg in messages:
info = parse_msg(msg, silent=True)
+ role = "assistant" if BOT_TIPS in info["text"] else "user"
+
+ if info["mtype"] not in ["text", "photo", "audio", "voice", "video", "document", "web_page"]:
+ continue
+
+ context = {"role": role, "content": []}
sender = info["fwd_full_name"] or info["full_name"]
media_path = DOWNLOAD_DIR + "/" + info["file_name"]
try:
if info["mtype"] == "photo":
- res = await base64_media(client, msg)
- contexts.append({"type": "image_url", "image_url": {"url": f"data:image/{res['ext']};base64,{res['base64']}"}})
+ res = await base64_media(client, message)
+ context["content"].append({"type": "image_url", "image_url": {"url": f"data:image/{res['ext']};base64,{res['base64']}"}})
elif info["mtype"] == "document":
guessed_mime, _ = mimetypes.guess_type(info["file_name"])
- if info["mime_type"].startswith("text/") or str(guessed_mime).startswith("text/") or Path(info["file_name"]).suffix in extra_txt_extensions:
- fpath: str = await client.download_media(msg, media_path) # type: ignore
- contexts.append(
+ if info["mime_type"].startswith("text/") or str(guessed_mime).startswith("text/") or Path(info["file_name"]).suffix in TXT_EXT:
+ fpath: str = await client.download_media(message, media_path) # type: ignore
+ context["content"].append(
{
"type": "text",
"text": f"[filename]: {info['file_name']}\n[file content]:\n{read_text(fpath).strip()}",
}
)
- elif Path(info["file_name"]).suffix in extra_markdown_extensions:
- fpath: str = await client.download_media(msg, media_path) # type: ignore
+ elif Path(info["file_name"]).suffix in MARKDOWN_EXT:
+ fpath: str = await client.download_media(message, media_path) # type: ignore
text = convert_md(fpath)
Path(fpath).unlink(missing_ok=True)
- contexts.append(
+ context["content"].append(
{
"type": "text",
"text": f"[filename]: {info['file_name']}\n[file content]:\n{text.strip()}",
@@ -112,17 +102,19 @@ async def single_openai_chat_context(client: Client, message: Message, *, add_se
texts = info["html"] or info["text"] if role == "user" and info["entity_urls"] else info["text"]
clean_texts = clean_context(texts)
if not clean_texts:
+ contexts.append(context)
continue
- if role == "user" and add_sender and sender:
+ if role == "user" and add_sender and sender: # noqa: SIM108
texts = f"<quote>{info['quote_text']}</quote>\n{sender} ({info['time']})\n{clean_texts}"
else:
texts = f"<quote>{info['quote_text']}</quote>\n{clean_texts}"
texts = texts.removeprefix("<quote></quote>\n") # remove quote mark if no quote_text
- contexts.append({"type": "text", "text": texts})
+ context["content"].append({"type": "text", "text": texts})
except Exception as e:
logger.warning(f"Download media from message failed: {e}")
- continue
- return {"role": role, "content": contexts} if contexts else {}
+ contexts.append(context)
+
+ return [ctx for ctx in contexts if ctx.get("content")]
async def get_openai_response_contexts(client: Client, message: Message, params: dict) -> tuple[str, list[dict]]:
@@ -142,19 +134,19 @@ async def get_openai_response_contexts(client: Client, message: Message, params:
if cache_day == 0:
return ""
api_key = params["api_key"]
- model_id = params["model_id"]
key_hash = hashlib.sha256(api_key.encode()).hexdigest()
- resp = await head_cf_r2(f"TTL/{cache_day}d/OpenAI/{msg.chat.id}/{msg.id}/{model_id}/{key_hash}")
+ resp = await head_cf_r2(f"TTL/{cache_day}d/OpenAI/{msg.chat.id}/{msg.id}/{key_hash}")
return glom(resp, "Metadata.response_id", default="") or ""
+ chains = await full_chain_contexts(client, message, order="desc") # new to old
previous_response_id = ""
- messages = [message]
- while message.reply_to_message and not previous_response_id:
- message = message.reply_to_message
- if pid := await get_previous_response_id(message):
+ messages = []
+ for msg in chains:
+ if glom(msg, "from_user.id", default=-1) == TID.ME and (pid := await get_previous_response_id(msg)):
previous_response_id = pid
break
- messages.append(message)
+ messages.append(msg)
+
messages.reverse() # old to new
if params.get("add_sender") is None:
params["add_sender"] = is_multi_user_chat(messages)
@@ -176,86 +168,78 @@ async def single_openai_response_context(client: Client, message: Message, param
if info["mtype"] not in ["text", "photo", "audio", "voice", "video", "document", "web_page"]:
return {}
+ context = {"role": role, "type": "message", "content": []}
+ if role == "assistant":
+ context["status"] = "completed"
- extra_txt_extensions = [".sh", ".json", ".xml", ".tex"] # treat these as txt file
- extra_markdown_extensions = [".html", ".docx", ".pptx", ".xls", ".xlsx"] # convert to markdown
-
- messages = await client.get_media_group(message.chat.id, message.id) if message.media_group_id else [message]
media_send_as = params.get("openai_media_send_as", "base64")
allow_image = bool(params.get("allow_image"))
allow_video = bool(params.get("allow_video"))
allow_audio = bool(params.get("allow_audio"))
allow_file = bool(params.get("allow_file"))
- contexts = []
- for msg in messages:
- info = parse_msg(msg, silent=True)
- sender = info["fwd_full_name"] or info["full_name"]
- media_path = DOWNLOAD_DIR + "/" + info["file_name"]
- file_id = ""
- try:
- if info["mtype"] == "photo" and allow_image:
- if media_send_as == "file_id" and (file_id := await get_openai_file_id(client, msg, params)):
- contexts.append({"type": "input_image", "file_id": file_id})
- if not file_id:
- res = await base64_media(client, msg)
- contexts.append({"type": "input_image", "image_url": f"data:image/{res['ext']};base64,{res['base64']}"})
- elif info["mtype"] == "video" and allow_video:
- if media_send_as == "file_id" and (file_id := await get_openai_file_id(client, msg, params)):
- contexts.append({"type": "input_video", "file_id": file_id})
- if not file_id:
- res = await base64_media(client, msg)
- contexts.append({"type": "input_video", "video_url": f"data:video/{res['ext']};base64,{res['base64']}"})
- elif info["mtype"] == "audio" and allow_audio:
- if media_send_as == "file_id" and (file_id := await get_openai_file_id(client, msg, params)):
- contexts.append({"type": "input_audio", "file_id": file_id})
+ sender = info["fwd_full_name"] or info["full_name"]
+ media_path = DOWNLOAD_DIR + "/" + info["file_name"]
+ file_id = ""
+ try:
+ if info["mtype"] == "photo" and allow_image:
+ if media_send_as == "file_id" and (file_id := await get_openai_file_id(client, message, params)):
+ context["content"].append({"type": "input_image", "file_id": file_id})
+ if not file_id:
+ res = await base64_media(client, message)
+ context["content"].append({"type": "input_image", "image_url": f"data:image/{res['ext']};base64,{res['base64']}"})
+ elif info["mtype"] == "video" and allow_video:
+ if media_send_as == "file_id" and (file_id := await get_openai_file_id(client, message, params)):
+ context["content"].append({"type": "input_video", "file_id": file_id})
+ if not file_id:
+ res = await base64_media(client, message)
+ context["content"].append({"type": "input_video", "video_url": f"data:video/{res['ext']};base64,{res['base64']}"})
+ elif info["mtype"] == "audio" and allow_audio:
+ if media_send_as == "file_id" and (file_id := await get_openai_file_id(client, message, params)):
+ context["content"].append({"type": "input_audio", "file_id": file_id})
+ if not file_id:
+ res = await base64_media(client, message)
+ context["content"].append({"type": "input_audio", "audio_url": f"data:audio/{res['ext']};base64,{res['base64']}"})
+ elif info["mtype"] == "document" and allow_file:
+ guessed_mime, _ = mimetypes.guess_type(info["file_name"])
+ if info["mime_type"] == "application/pdf" or guessed_mime == "application/pdf":
+ if media_send_as == "file_id" and (file_id := await get_openai_file_id(client, message, params)):
+ context["content"].append({"type": "input_file", "file_id": file_id})
if not file_id:
- res = await base64_media(client, msg)
- contexts.append({"type": "input_audio", "audio_url": f"data:audio/{res['ext']};base64,{res['base64']}"})
- elif info["mtype"] == "document" and allow_file:
- guessed_mime, _ = mimetypes.guess_type(info["file_name"])
- if info["mime_type"] == "application/pdf" or guessed_mime == "application/pdf":
- if media_send_as == "file_id" and (file_id := await get_openai_file_id(client, msg, params)):
- contexts.append({"type": "input_file", "file_id": file_id})
- if not file_id:
- res = await base64_media(client, msg)
- contexts.append({"type": "input_file", "file_data": f"data:application/pdf;base64,{res['base64']}", "filename": info["file_name"]})
-
- elif info["mime_type"].startswith("text/") or str(guessed_mime).startswith("text/") or Path(info["file_name"]).suffix in extra_txt_extensions:
- fpath: str = await client.download_media(msg, media_path) # type: ignore
- contexts.append(
- {
- "type": text_type,
- "text": f"[filename]: {info['file_name']}\n[file content]:\n{read_text(fpath).strip()}",
- }
- )
- elif Path(info["file_name"]).suffix in extra_markdown_extensions:
- fpath: str = await client.download_media(msg, media_path) # type: ignore
- text = convert_md(fpath)
- Path(fpath).unlink(missing_ok=True)
- contexts.append(
- {
- "type": text_type,
- "text": f"[filename]: {info['file_name']}\n[file content]:\n{text.strip()}",
- }
- )
- # user message has entity urls, use full html
- texts = info["html"] or info["text"] if role == "user" and info["entity_urls"] else info["text"]
- clean_texts = clean_context(texts)
- if not clean_texts:
- continue
- if role == "user" and params.get("add_sender") and sender:
- texts = f"<quote>{info['quote_text']}</quote>\n{sender} ({info['time']})\n{clean_texts}"
- else:
- texts = f"<quote>{info['quote_text']}</quote>\n{clean_texts}"
- texts = texts.removeprefix("<quote></quote>\n") # remove quote mark if no quote_text
- contexts.append({"type": text_type, "text": texts})
- except Exception as e:
- logger.warning(f"Download media from message failed: {e}")
- continue
- if not contexts:
- return {}
- extra = {"status": "completed"} if role == "assistant" else {}
- return {"role": role, "type": "message", "content": contexts, **extra}
+ res = await base64_media(client, message)
+ context["content"].append({"type": "input_file", "file_data": f"data:application/pdf;base64,{res['base64']}", "filename": info["file_name"]})
+
+ elif info["mime_type"].startswith("text/") or str(guessed_mime).startswith("text/") or Path(info["file_name"]).suffix in TXT_EXT:
+ fpath: str = await client.download_media(message, media_path) # type: ignore
+ context["content"].append(
+ {
+ "type": text_type,
+ "text": f"[filename]: {info['file_name']}\n[file content]:\n{read_text(fpath).strip()}",
+ }
+ )
+ elif Path(info["file_name"]).suffix in MARKDOWN_EXT:
+ fpath: str = await client.download_media(message, media_path) # type: ignore
+ text = convert_md(fpath)
+ Path(fpath).unlink(missing_ok=True)
+ context["content"].append(
+ {
+ "type": text_type,
+ "text": f"[filename]: {info['file_name']}\n[file content]:\n{text.strip()}",
+ }
+ )
+ # user message has entity urls, use full html
+ texts = info["html"] or info["text"] if role == "user" and info["entity_urls"] else info["text"]
+ clean_texts = clean_context(texts)
+ if not clean_texts:
+ return context if context["content"] else {}
+ if role == "user" and params.get("add_sender") and sender:
+ texts = f"<quote>{info['quote_text']}</quote>\n{sender} ({info['time']})\n{clean_texts}"
+ else:
+ texts = f"<quote>{info['quote_text']}</quote>\n{clean_texts}"
+ texts = texts.removeprefix("<quote></quote>\n") # remove quote mark if no quote_text
+ context["content"].append({"type": text_type, "text": texts})
+ except Exception as e:
+ logger.warning(f"Download media from message failed: {e}")
+ return context if context["content"] else {}
async def get_openai_file_id(client: Client, message: Message, params: dict) -> str:
@@ -275,9 +259,8 @@ async def get_openai_file_id(client: Client, message: Message, params: dict) ->
return ""
cache_day = params.get("cache_day", 30)
api_key = params["api_key"]
- model_id = params["model_id"]
key_hash = hashlib.sha256(api_key.encode()).hexdigest()
- r2_key = f"TTL/{cache_day}d/OpenAI/{message.chat.id}/{message.id}/{model_id}/{key_hash}-file_id"
+ r2_key = f"TTL/{cache_day}d/OpenAI/{message.chat.id}/{message.id}/{key_hash}-file_id"
r2 = await head_cf_r2(r2_key)
if file_id := glom(r2, "Metadata.file_id", default=""):
return file_id
@@ -313,159 +296,161 @@ async def get_gemini_contexts(client: Client, message: Message, gemini: genai.Cl
Returns:
contexts: list[dict]
"""
- ctx_messages = [message]
- while message.reply_to_message:
- message = message.reply_to_message
- ctx_messages.append(message)
+ chains = await full_chain_contexts(client, message, order="asc") # old to new
if add_sender is None:
- add_sender = is_multi_user_chat(ctx_messages)
- ctx_messages = ctx_messages[: int(AI.MAX_CONTEXTS_NUM)][::-1] # old to new
+ add_sender = is_multi_user_chat(chains)
+ messages = chains[-int(AI.MAX_CONTEXTS_NUM) :]
contexts = []
- for m in ctx_messages:
- info = parse_msg(m, silent=True)
+ for msg in messages:
+ info = parse_msg(msg, silent=True)
role = "model" if BOT_TIPS in info["text"] else "user"
if info["mtype"] not in ["text", "photo", "audio", "voice", "video", "document", "web_page"]:
continue
- # gemini has built-in support for these extensions
- gemini_extensions = [".pdf", ".html", ".css", ".csv", ".xml", ".rtf", ".mp3", ".wav", ".ogg", ".aac", ".flac", ".jpg", ".jpeg", ".webp", ".png", ".heic", ".heif"]
- # gemini has built-in support for these mime types
- gemini_mime_types = ["application/pdf", "application/x-javascript", "audio/ogg", "audio/mp4", "image/jpeg", "image/png", "image/webp", "image/heic", "image/heif"]
- txt_extensions = [".txt", ".js", ".py", ".md", ".sh", ".json"] # treat these as txt file
- extra_markdown_extensions = [".docx", ".pptx", ".xls", ".xlsx", ".epub"] # convert to markdown
- group_messages = await client.get_media_group(m.chat.id, m.id) if m.media_group_id else [m]
parts = []
- for msg in group_messages:
- info = parse_msg(msg, silent=True)
- sender = info["fwd_full_name"] or info["full_name"]
- media_path = DOWNLOAD_DIR + "/" + info["file_name"]
- try:
- if info["mtype"] in ["video", "photo", "audio", "voice"] or info["mime_type"] in gemini_mime_types or any(info["file_name"].endswith(ext) for ext in gemini_extensions):
+ sender = info["fwd_full_name"] or info["full_name"]
+ media_path = DOWNLOAD_DIR + "/" + info["file_name"]
+ try:
+ if info["mtype"] != "text" and (uploaded := await get_gemini_file_id(client, msg, gemini, info["file_name"], info["mtype"])):
+ parts.append(Part.from_uri(file_uri=uploaded["file_id"], mime_type=uploaded["mime_type"]))
+ elif info["mtype"] == "document":
+ guessed_mime, _ = mimetypes.guess_type(info["file_name"])
+ if info["mime_type"].startswith("text/") or str(guessed_mime).startswith("text/") or Path(info["file_name"]).suffix in TXT_EXT:
+ fpath: str = await client.download_media(msg, media_path) # type: ignore
+ parts.append(Part.from_text(text=f"[filename]: {info['file_name']}\n[file content]:\n{read_text(fpath).strip()}"))
+ if Path(info["file_name"]).suffix in MARKDOWN_EXT:
fpath: str = await client.download_media(msg, media_path) # type: ignore
- if info["mtype"] in ["audio", "voice"] and Path(fpath).suffix not in GEMINI_AUDIO_EXT:
- audio_path = await downsampe_audio(fpath)
- fpath = audio_path.as_posix()
- upload = await gemini.aio.files.upload(file=fpath, config=UploadFileConfig(display_name=info["file_name"] or f"send_from_{sender}"))
- while upload.state == FileState.PROCESSING:
- logger.trace("Waiting for upload to complete...")
- await asyncio.sleep(1)
- upload = await gemini.aio.files.get(name=upload.name) # type: ignore
- if upload.state == FileState.ACTIVE and upload.uri:
- parts.append(Part.from_uri(file_uri=upload.uri, mime_type=upload.mime_type))
+ text = convert_md(fpath)
Path(fpath).unlink(missing_ok=True)
- elif info["mtype"] == "document":
- guessed_mime, _ = mimetypes.guess_type(info["file_name"])
- if info["mime_type"].startswith("text/") or str(guessed_mime).startswith("text/") or Path(info["file_name"]).suffix in txt_extensions:
- fpath: str = await client.download_media(msg, media_path) # type: ignore
- parts.append(Part.from_text(text=f"[filename]: {info['file_name']}\n[file content]:\n{read_text(fpath).strip()}"))
- if Path(info["file_name"]).suffix in extra_markdown_extensions:
- fpath: str = await client.download_media(msg, media_path) # type: ignore
- text = convert_md(fpath)
- Path(fpath).unlink(missing_ok=True)
- parts.append(Part.from_text(text=f"[filename]: {info['file_name']}\n[file content]:\n{text.strip()}"))
- texts = info["html"] or info["text"] if role == "user" and info["entity_urls"] else info["text"]
- clean_texts = clean_context(texts)
- if not clean_texts:
- continue
- if role == "user" and add_sender and sender:
- texts = f"<quote>{info['quote_text']}</quote>\n{sender} ({info['time']})\n{clean_texts}"
- else:
- texts = f"<quote>{info['quote_text']}</quote>\n{clean_texts}"
- texts = texts.removeprefix("<quote></quote>\n") # remove quote mark if no quote_text
- parts.append(Part.from_text(text=texts))
- except Exception as e:
- logger.warning(f"Download media from message failed: {e}")
+ parts.append(Part.from_text(text=f"[filename]: {info['file_name']}\n[file content]:\n{text.strip()}"))
+ texts = info["html"] or info["text"] if role == "user" and info["entity_urls"] else info["text"]
+ clean_texts = clean_context(texts)
+ if not clean_texts:
+ contexts.append({"role": role, "parts": parts})
continue
+ if role == "user" and add_sender and sender:
+ texts = f"<quote>{info['quote_text']}</quote>\n{sender} ({info['time']})\n{clean_texts}"
+ else:
+ texts = f"<quote>{info['quote_text']}</quote>\n{clean_texts}"
+ texts = texts.removeprefix("<quote></quote>\n") # remove quote mark if no quote_text
+ parts.append(Part.from_text(text=texts))
+ except Exception as e:
+ logger.warning(f"Download media from message failed: {e}")
if parts:
contexts.append({"role": role, "parts": parts})
- return contexts
+ return [ctx for ctx in contexts if len(ctx.get("parts"))]
-async def get_anthropic_contexts(client: Client, message: Message, **kwargs) -> list[dict]:
- """Generate Anthropic contexts."""
- messages = [message]
- while message.reply_to_message:
- message = message.reply_to_message
- messages.append(message)
- if kwargs.get("add_sender") is None:
- kwargs["add_sender"] = is_multi_user_chat(messages)
- messages = messages[: int(AI.MAX_CONTEXTS_NUM)][::-1] # old to new
- return [ctx for msg in messages if (ctx := await single_anthropic_context(client, msg, **kwargs))]
+async def get_gemini_file_id(client: Client, message: Message, gemini: genai.Client, fname: str, mtype: str) -> dict:
+ """Get Gemini file id from message.
+ Returns:
+ file_id: str
+ mime_type: str
+ """
+ if mtype not in ["video", "photo", "audio", "voice"] and mtype not in GEMINI_MIME and not any(fname.endswith(ext) for ext in GEMINI_EXT):
+ return {}
-async def single_anthropic_context(
+ cache_hour = AI.GEMINI_FILES_TTL // 3600
+ api_key = glom(gemini, "_api_client.api_key", default="")
+ key_hash = hashlib.sha256(api_key.encode()).hexdigest()
+ r2_key = f"TTL/{cache_hour}h/Gemini/{message.chat.id}/{message.id}/{key_hash}-file_id"
+ r2 = await head_cf_r2(r2_key)
+ app = genai.Client(api_key=api_key, http_options=HttpOptions(async_client_args={"proxy": PROXY.GOOGLE}))
+ if name := glom(r2, "Metadata.name", default=""):
+ try:
+ upload = await app.aio.files.get(name=name)
+ if upload.state == FileState.ACTIVE and upload.uri:
+ return {"file_id": upload.uri, "mime_type": upload.mime_type}
+ except Exception as e:
+ logger.warning(f"Get file id from Gemini failed: {e}")
+ try:
+ fpath: str = await client.download_media(message) # type: ignore
+ if mtype in ["audio", "voice"] and Path(fpath).suffix not in GEMINI_AUDIO_EXT:
+ audio_path = await downsampe_audio(fpath)
+ fpath = audio_path.as_posix()
+ upload = await app.aio.files.upload(file=fpath, config=UploadFileConfig(display_name=fname))
+ while upload.state == FileState.PROCESSING:
+ logger.trace("Waiting for upload to complete...")
+ await asyncio.sleep(1)
+ upload = await app.aio.files.get(name=upload.name) # type: ignore
+ if upload.state == FileState.ACTIVE and upload.uri:
+ await set_cf_r2(r2_key, data=json.loads(upload.model_dump_json()), metadata={"name": upload.name})
+ Path(fpath).unlink(missing_ok=True)
+ return {"file_id": upload.uri, "mime_type": upload.mime_type}
+ except Exception as e:
+ logger.error(f"Upload media to Gemini failed: {e}")
+ return {}
+
+
+async def get_anthropic_contexts(
client: Client,
message: Message,
anthropic: AsyncAnthropic,
cache_hour: int = 0,
- media_send_as: Literal["base64", "file_id"] = "file_id",
+ media_send_as: Literal["base64", "file_id"] = "base64",
*,
- add_sender: bool = True,
-) -> dict:
- """Generate Anthropic contexts for a single message.
-
- Returns:
- {
- "role": "user or assistant",
- "content": [],
- }
- """
- info = parse_msg(message, silent=True)
- role = "assistant" if BOT_TIPS in info["text"] else "user"
-
- if info["mtype"] not in ["text", "photo", "audio", "voice", "video", "document", "web_page"]:
- return {}
-
- extra_txt_extensions = [".sh", ".json", ".xml", ".tex"] # treat these as txt file
- extra_markdown_extensions = [".html", ".docx", ".pptx", ".xls", ".xlsx"] # convert to markdown
+ add_sender: bool | None = None,
+) -> list[dict]:
+ """Generate Anthropic contexts."""
+ chains = await full_chain_contexts(client, message, order="asc") # old to new
+ if add_sender is None:
+ add_sender = is_multi_user_chat(chains)
+ messages = chains[-int(AI.MAX_CONTEXTS_NUM) :]
- messages = await client.get_media_group(message.chat.id, message.id) if message.media_group_id else [message]
contexts = []
for msg in messages:
info = parse_msg(msg, silent=True)
+ role = "assistant" if BOT_TIPS in info["text"] else "user"
+ if info["mtype"] not in ["text", "photo", "audio", "voice", "video", "document", "web_page"]:
+ continue
+
+ context = {"role": role, "content": []}
sender = info["fwd_full_name"] or info["full_name"]
media_path = DOWNLOAD_DIR + "/" + info["file_name"]
file_id = ""
try:
if info["mtype"] == "photo":
if media_send_as == "file_id" and (file_id := await get_anthropic_file_id(client, msg, anthropic, cache_hour)):
- contexts.append({"type": "image", "source": {"type": "file", "file_id": file_id}})
+ context["content"].append({"type": "image", "source": {"type": "file", "file_id": file_id}})
if not file_id:
res = await base64_media(client, msg)
- contexts.append({"type": "image", "source": {"type": "base64", "media_type": f"image/{res['ext']}", "data": res["base64"]}})
+ context["content"].append({"type": "image", "source": {"type": "base64", "media_type": f"image/{res['ext']}", "data": res["base64"]}})
elif info["mtype"] == "document":
guessed_mime, _ = mimetypes.guess_type(info["file_name"])
if info["mime_type"] == "application/pdf" or guessed_mime == "application/pdf":
if media_send_as == "file_id" and (file_id := await get_anthropic_file_id(client, msg, anthropic, cache_hour)):
- contexts.append({"type": "document", "source": {"type": "file", "file_id": file_id}})
+ context["content"].append({"type": "document", "source": {"type": "file", "file_id": file_id}})
if not file_id:
res = await base64_media(client, msg)
- contexts.append({"type": "document", "source": {"type": "base64", "media_type": "application/pdf", "data": res["base64"]}})
+ context["content"].append({"type": "document", "source": {"type": "base64", "media_type": "application/pdf", "data": res["base64"]}})
- elif info["mime_type"].startswith("text/") or str(guessed_mime).startswith("text/") or Path(info["file_name"]).suffix in extra_txt_extensions:
+ elif info["mime_type"].startswith("text/") or str(guessed_mime).startswith("text/") or Path(info["file_name"]).suffix in TXT_EXT:
fpath: str = await client.download_media(msg, media_path) # type: ignore
- contexts.append({"type": "text", "text": f"[filename]: {info['file_name']}\n[file content]:\n{read_text(fpath).strip()}"})
+ context["content"].append({"type": "text", "text": f"[filename]: {info['file_name']}\n[file content]:\n{read_text(fpath).strip()}"})
- elif Path(info["file_name"]).suffix in extra_markdown_extensions:
+ elif Path(info["file_name"]).suffix in MARKDOWN_EXT:
fpath: str = await client.download_media(msg, media_path) # type: ignore
text = convert_md(fpath)
Path(fpath).unlink(missing_ok=True)
- contexts.append({"type": "text", "text": f"[filename]: {info['file_name']}\n[file content]:\n{text.strip()}"})
+ context["content"].append({"type": "text", "text": f"[filename]: {info['file_name']}\n[file content]:\n{text.strip()}"})
# user message has entity urls, use full html
texts = info["html"] or info["text"] if role == "user" and info["entity_urls"] else info["text"]
clean_texts = clean_context(texts)
if not clean_texts:
+ contexts.append(context)
continue
- if role == "user" and add_sender and sender:
+ if role == "user" and add_sender and sender: # noqa: SIM108
texts = f"<quote>{info['quote_text']}</quote>\n{sender} ({info['time']})\n{clean_texts}"
else:
texts = f"<quote>{info['quote_text']}</quote>\n{clean_texts}"
texts = texts.removeprefix("<quote></quote>\n") # remove quote mark if no quote_text
- contexts.append({"type": "text", "text": texts})
+ context["content"].append({"type": "text", "text": texts})
except Exception as e:
logger.warning(f"Download media from message failed: {e}")
- continue
- return {"role": role, "content": contexts} if contexts else {}
+ contexts.append(context)
+
+ return [ctx for ctx in contexts if ctx.get("content")]
async def get_anthropic_file_id(client: Client, message: Message, anthropic: AsyncAnthropic, cache_hour: int) -> str:
@@ -487,7 +472,11 @@ async def get_anthropic_file_id(client: Client, message: Message, anthropic: Asy
return ""
-async def context_bytes(client: Client, message: Message) -> int:
+async def full_chain_contexts(client: Client, message: Message, order: Literal["asc", "desc"] = "asc") -> list[Message]:
+ """Get all messages in the reply chain.
+
+ Default order is from oldest to newest.
+ """
chains = [message]
while message.reply_to_message:
message = message.reply_to_message
@@ -496,10 +485,22 @@ async def context_bytes(client: Client, message: Message) -> int:
for msg in chains:
groups = await client.get_media_group(msg.chat.id, msg.id) if msg.media_group_id else [msg]
messages.extend(groups)
- return sum(message_bytes(m) for m in messages)
+ messages = [m for m in messages if isinstance(m, Message)]
+ return sorted(messages, key=lambda x: x.id, reverse=order == "desc")
+
+
+async def context_bytes(client: Client, message: Message) -> int:
+ """Count bytes of all messages in the reply chain."""
+ chains = await full_chain_contexts(client, message)
+ return sum(message_bytes(m) for m in chains)
def message_bytes(message: Message) -> int:
+ """Count bytes of a message.
+
+ Note:
+ This function only counts bytes of media files, not text messages.
+ """
return glom(message, Coalesce("photo.sizes.-1.file_size", "video.file_size", "document.file_size"), default=0)
@@ -507,4 +508,5 @@ def is_multi_user_chat(messages: list[Message]) -> bool:
"""Check if this chat history group has multiple users."""
uids = {glom(x, "from_user.id", default=0) for x in messages}
uids.discard(TID.ME)
+ uids.discard(0)
return len(uids) > 1
src/ai/texts/openai_response.py
@@ -128,7 +128,7 @@ async def openai_responses_api(
for sent_msg in sent_messages: # save the reponse to R2
key_hash = hashlib.sha256(api_key.encode()).hexdigest()
await set_cf_r2(
- f"TTL/{day}d/OpenAI/{sent_msg.chat.id}/{sent_msg.id}/{model_id}/{key_hash}",
+ f"TTL/{day}d/OpenAI/{sent_msg.chat.id}/{sent_msg.id}/{key_hash}",
data=resp["full_response"],
metadata={"response_id": resp["response_id"]},
silent=silent,