Commit a91f1bf

benny-dou <60535774+benny-dou@users.noreply.github.com>
2025-05-27 09:21:12
feat(gemini): support video and audio messages
1 parent 5684aec
src/llm/contexts.py
@@ -1,12 +1,15 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
+import asyncio
 import base64
 import contextlib
 from pathlib import Path
 from typing import TYPE_CHECKING
 
-from google.genai.types import Part
+from google import genai
+from google.genai.types import FileState, Part, UploadFileConfig
 from loguru import logger
+from openai import AsyncOpenAI
 from pyrogram.client import Client
 from pyrogram.types import Message
 
@@ -28,7 +31,12 @@ def get_conversations(message: Message) -> list[Message]:
     return messages
 
 
-async def get_conversation_contexts(client: Client, conversations: list[Message], ctx_format: str = "openai") -> list[dict]:
+async def get_conversation_contexts(
+    client: Client,
+    conversations: list[Message],
+    ctx_format: str = "openai",
+    app: genai.Client | AsyncOpenAI | None = None,
+) -> list[dict]:
     """Generate contexts for GPT conversation.
 
     From old to new messages.
@@ -38,7 +46,7 @@ async def get_conversation_contexts(client: Client, conversations: list[Message]
         contexts = [await single_gpt_context(client, message) for message in conversations]
         contexts = [x for x in contexts if x.get("content")]
     else:
-        contexts = [await single_gemini_context(client, message) for message in conversations]
+        contexts = [await single_gemini_context(client, message, app) for message in conversations]  # type: ignore
         contexts = [x for x in contexts if x.get("parts")]
 
     return contexts[: int(GPT.HISTORY_CONTEXT)]
@@ -64,7 +72,7 @@ async def single_gpt_context(client: Client, message: Message) -> dict:
         return {}
 
     extra_txt_extensions = [".sh", ".json", ".xml"]  # treat these as txt files
-    extra_markdown_extensions = [".pdf", ".html", ".doc", ".docx", ".ppt", ".pptx", ".xls", ".xlsx"]  # convert to markdown
+    extra_markdown_extensions = [".pdf", ".html", ".docx", ".pptx", ".xls", ".xlsx"]  # convert to markdown
 
     messages = await client.get_media_group(message.chat.id, message.id) if message.media_group_id else [message]
     contexts = []
@@ -84,7 +92,7 @@ async def single_gpt_context(client: Client, message: Message) -> dict:
                         }
                     )
                 if Path(info["file_name"]).suffix in extra_markdown_extensions:
-                    fpath: str = await client.download_media(message)  # type: ignore
+                    fpath: str = await client.download_media(msg)  # type: ignore
                     text = convert_md(fpath)
                     Path(fpath).unlink(missing_ok=True)
                     contexts.append(
@@ -105,7 +113,7 @@ async def single_gpt_context(client: Client, message: Message) -> dict:
     return {"role": role, "content": contexts} if contexts else {}
 
 
-async def single_gemini_context(client: Client, message: Message) -> dict:
+async def single_gemini_context(client: Client, message: Message, app: genai.Client) -> dict:
     """Generate Gemini contexts for a single message (Without considering reply message).
 
     Returns:
@@ -121,30 +129,37 @@ async def single_gemini_context(client: Client, message: Message) -> dict:
     role = "model" if BOT_TIPS in info["text"] else "user"
     if info["mtype"] not in ["text", "photo", "audio", "voice", "video", "document"]:
         return {}
-    extra_mime_types = ["application/pdf", "application/x-javascript"]  # gemini has built-in support for these
-    extra_txt_extensions = [".sh", ".json", ".xml"]  # also treat these as txt file
-    extra_markdown_extensions = [".html", ".doc", ".docx", ".ppt", ".pptx", ".xls", ".xlsx"]  # convert to markdown
+    # gemini has built-in support for these extensions
+    gemini_extensions = [".pdf", ".js", ".py", ".txt", ".html", ".css", ".md", ".csv", ".xml", ".rtf", ".mp3", ".wav", ".ogg", ".aac", ".flac", ".jpg", ".jpeg", ".webp", ".png", ".heic", ".heif"]
+    # gemini has built-in support for these mime types
+    gemini_mime_types = ["application/pdf", "application/x-javascript", "audio/ogg", "audio/mp4", "image/jpeg", "image/png", "image/webp", "image/heic", "image/heif"]
+    extra_txt_extensions = [".sh", ".json"]  # also treat these as txt files
+    extra_markdown_extensions = [".docx", ".pptx", ".xls", ".xlsx"]  # convert to markdown
 
     messages = await client.get_media_group(message.chat.id, message.id) if message.media_group_id else [message]
     parts = []
     for msg in messages:
         info = parse_msg(msg, silent=True)
         try:
-            if info["mtype"] == "photo":
-                res = await base64_media(client, msg)
-                parts.append(Part.from_bytes(mime_type=f"image/{res['ext']}", data=res["base64"]))
+            if info["mtype"] in ["video", "photo", "audio", "voice"] or info["mime_type"] in gemini_mime_types or any(info["file_name"].endswith(ext) for ext in gemini_extensions):
+                fpath: str = await client.download_media(msg, in_memory=False)  # type: ignore
+                upload = await app.aio.files.upload(file=fpath, config=UploadFileConfig(display_name=info["file_name"] or f"sent from {info['full_name']}"))
+                while upload.state == FileState.PROCESSING:
+                    logger.trace("Waiting for upload to complete...")
+                    await asyncio.sleep(1)
+                    upload = await app.aio.files.get(name=upload.name)  # type: ignore
+                if upload.state == FileState.ACTIVE and upload.uri:
+                    parts.append(Part.from_uri(file_uri=upload.uri, mime_type=upload.mime_type))
+                Path(fpath).unlink(missing_ok=True)
             elif info["mtype"] == "document":
                 if info["mime_type"].startswith("text/") or Path(info["file_name"]).suffix in extra_txt_extensions:
                     res = await base64_media(client, msg)
-                    parts.append(Part.from_text(text=f"[fileowner]: {info['full_name']}\n[filename]: {info['file_name']}\n[file content]:\n{res['value'].strip()}"))
-                if info["mime_type"] in extra_mime_types:
-                    data: BytesIO = await client.download_media(message, in_memory=True)  # type: ignore
-                    parts.append(Part.from_bytes(mime_type=info["mime_type"], data=bytes(data.getbuffer())))
+                    parts.append(Part.from_text(text=f"[filename]: {info['file_name']}\n[file content]:\n{res['value'].strip()}"))
                 if Path(info["file_name"]).suffix in extra_markdown_extensions:
-                    fpath: str = await client.download_media(message)  # type: ignore
+                    fpath: str = await client.download_media(msg)  # type: ignore
                     text = convert_md(fpath)
                     Path(fpath).unlink(missing_ok=True)
-                    parts.append(Part.from_text(text=f"[fileowner]: {info['full_name']}\n[filename]: {info['file_name']}\n[file content]:\n{text.strip()}"))
+                    parts.append(Part.from_text(text=f"[filename]: {info['file_name']}\n[file content]:\n{text.strip()}"))
             # user message has entity urls, use full html
             clean_texts = clean_context(info["html"]) if role == "user" and info["entity_urls"] else clean_context(info["text"])
             if not clean_texts:
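
Review note: the media branch above uses the Files API upload-then-poll lifecycle: upload returns a File in PROCESSING state, files.get is polled until the state becomes ACTIVE, and the file is then referenced by URI. A minimal standalone sketch of that pattern, assuming a GEMINI_API_KEY environment variable and a local sample.mp4; the 120-second deadline is an illustration-only addition, the loop in the diff polls indefinitely.

import asyncio
import os

from google import genai
from google.genai.types import FileState, Part, UploadFileConfig


async def upload_media(path: str) -> Part | None:
    # One-off client; the bot builds this per request with proxy/base_url options.
    app = genai.Client(api_key=os.environ["GEMINI_API_KEY"])
    upload = await app.aio.files.upload(file=path, config=UploadFileConfig(display_name=path))
    deadline = asyncio.get_running_loop().time() + 120  # assumption: cap the wait
    while upload.state == FileState.PROCESSING:
        if asyncio.get_running_loop().time() > deadline:
            return None  # give up instead of spinning forever
        await asyncio.sleep(1)
        upload = await app.aio.files.get(name=upload.name)  # refresh the processing state
    if upload.state == FileState.ACTIVE and upload.uri:
        return Part.from_uri(file_uri=upload.uri, mime_type=upload.mime_type)
    return None  # upload failed or has no URI


if __name__ == "__main__":
    print(asyncio.run(upload_media("sample.mp4")))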
src/llm/gemini.py
@@ -83,8 +83,6 @@ async def gemini_response(
         msg = f"🤖**{model_name}**: 思考中...\n👤**[{info['full_name'] or info['ctitle']}](tg://user?id={info['uid']})**: “{clean_cmd_prefix(info['text'])}”"[:TEXT_LENGTH]
         status_msg = (await send2tg(client, message, texts=msg, **kwargs))[0]
         kwargs["progress"] = status_msg
-        contexts = await get_conversation_contexts(client, conversations, ctx_format="gemini")
-        gemini_logging(contexts)
         genconfig |= {"response_modalities": response_modalities}
         if tools:
             genconfig |= {"tools": tools}
@@ -93,7 +91,7 @@ async def gemini_response(
         if thinking_budget is not None and not disable_thinking:
             thinking_budget = min(round(float(thinking_budget)), GEMINI.MAX_THINKING_BUDGET)
             genconfig |= {"thinking_config": ThinkingConfig(include_thoughts=include_thoughts, thinking_budget=thinking_budget)}
-        params = {"model": model, "contents": contexts, "config": GenerateContentConfig(**genconfig)}
+        params = {"model": model, "conversations": conversations, "config": GenerateContentConfig(**genconfig)}
         logger.trace(params)
         if modality == "image":
             return await gemini_nonstream(client, message, model_name, params, clean_marks=True, append_grounding=append_grounding, **kwargs)
@@ -144,6 +142,10 @@ async def gemini_stream(
         http_options = HttpOptions(base_url=GEMINI.BASR_URL, async_client_args={"proxy": GEMINI.PROXY})
         http_options = hook_gemini_httpoptions(http_options, message)
         app = genai.Client(api_key=api_key, http_options=http_options)
+        # Construct the request params
+        if "contents" not in params and "conversations" in params:  # convert conversations to contents
+            params["contents"] = await get_conversation_contexts(client, params.pop("conversations"), ctx_format="gemini", app=app)
+        gemini_logging(params["contents"])
         sent_messages = []
         is_reasoning = False
         is_reasoning_conversation = None  # indicates whether this is a reasoning conversation
@@ -263,6 +265,9 @@ async def gemini_nonstream(
         http_options = HttpOptions(base_url=GEMINI.BASR_URL, async_client_args={"proxy": GEMINI.PROXY})
         http_options = hook_gemini_httpoptions(http_options, message)
         app = genai.Client(api_key=api_key, http_options=http_options)
+        # Construct the request params
+        if "contents" not in params and "conversations" in params:  # convert conversations to contents
+            params["contents"] = await get_conversation_contexts(client, params.pop("conversations"), ctx_format="gemini")
         response = await app.aio.models.generate_content(**params)
         prefix = f"🤖**{model_name}**:{BOT_TIPS}\n"
         res = parse_response(response.model_dump(), append_grounding=append_grounding)
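
Review note: context construction moves out of gemini_response because media parts now need an authenticated genai.Client for uploads, and that client is only built inside gemini_stream/gemini_nonstream (with per-message HttpOptions). A sketch of the hand-off that both hunks implement; the helper name resolve_contents and the llm.contexts import path are assumptions.

from google import genai
from pyrogram.client import Client

from llm.contexts import get_conversation_contexts  # assumed import path


async def resolve_contents(client: Client, params: dict, app: genai.Client) -> dict:
    # "conversations" is the lazy form stored by gemini_response; "contents" is
    # the resolved form, built once an upload-capable client is available.
    if "contents" not in params and "conversations" in params:
        params["contents"] = await get_conversation_contexts(client, params.pop("conversations"), ctx_format="gemini", app=app)
    return params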
src/llm/gpt.py
@@ -22,23 +22,22 @@ from messages.utils import count_without_entities, equal_prefix, startswith_pref
 HELP = f"""🤖**GPT对话**
 `{PREFIX.GPT}` 后接提示词即可与GPT对话
 以 `{PREFIX.GPT}` 回复消息可将其加入上下文
-暂不支持视频/音频, 可先用`{PREFIX.ASR}`命令转为文字后再调用`{PREFIX.GPT}`
+暂不支持视频, 可先用`{PREFIX.ASR}`命令转为文字后再调用`{PREFIX.GPT}`
 
 ⚙️模型配置:
-`{PREFIX.GPT}`默认使用 **{GPT.DEFAULT_PROVIDER.lower()}** 模型
+`{PREFIX.GPT}`: 默认使用 **{GPT.DEFAULT_PROVIDER.lower()}** 模型
 
 🔄使用以下命令强制切换模型:
 `/gpt`: **{GPT.OPENAI_MODEL_NAME}** {image_emoji(GPT.OPENAI_ACCEPT_IMAGE)}
-`/gemini`: **{GPT.GEMINI_MODEL_NAME}** {image_emoji(GPT.GEMINI_ACCEPT_IMAGE)}
+`/gemini`: **{GPT.GEMINI_MODEL_NAME}** 🎬🎧{image_emoji(GPT.GEMINI_ACCEPT_IMAGE)}
 `/ds`: **{GPT.DEEPSEEK_MODEL_NAME}** {image_emoji(GPT.DEEPSEEK_ACCEPT_IMAGE)}
 `/qwen`: **{GPT.QWEN_MODEL_NAME}** {image_emoji(GPT.QWEN_ACCEPT_IMAGE)}
 `/doubao`: **{GPT.DOUBAO_MODEL_NAME}** {image_emoji(GPT.DOUBAO_ACCEPT_IMAGE)}
 `/grok`: **{GPT.GROK_MODEL_NAME}** {image_emoji(GPT.GROK_ACCEPT_IMAGE)}
 
 ⚠️注意:
-若对话历史包含图片
-但模型不支持图片(无🏞图标)
-会自动切换为 **{GPT.OMNI_PROVIDER.lower()}** 模型
+若对话历史包含图片, 但模型不支持图片 (无🏞图标), 会自动切换为 **{GPT.OMNI_PROVIDER.lower()}** 模型
+若对话历史包含视频/音频, 但模型不支持视频/音频 (无🎬/🎧图标), 会自动切换为 **{GPT.GEMINI_MODEL_NAME}** 模型
 """
 
 
@@ -104,7 +103,7 @@ async def gpt_response(client: Client, message: Message, *, gpt_stream: bool = T
     kwargs["message_info"] = info  # save trigger message info
     conversations = get_conversations(message)
     context_type = get_context_type(conversations)  # "text", "image", or "gemini"
-    model_id, resp_modality, sdk = get_model_id(info["text"], reply_text, context_type["type"])
+    model_id, resp_modality, sdk = get_model_id(info["text"], reply_text, context_type)
     if "gemini" in model_id.lower() and sdk == "gemini":
         return await gemini_response(client, message, conversations, resp_modality, **kwargs)
 
@@ -120,9 +119,6 @@ async def gpt_response(client: Client, message: Message, *, gpt_stream: bool = T
     msg = f"🤖**{config['friendly_name']}**: 思考中...\n👤**[{info['full_name'] or info['ctitle']}](tg://user?id={info['uid']})**: “{clean_cmd_prefix(info['text'])}”"[:TEXT_LENGTH]
     status_msg = (await send2tg(client, message, texts=msg, **kwargs))[0]
     kwargs["progress"] = status_msg
-    if context_type.get("error"):
-        logger.warning(context_type["error"])
-        await modify_progress(message=status_msg, text=f"{msg}\n{context_type['error']}", force_update=True, **kwargs)
     config, response = await merge_tools_response(config, **kwargs)
     # skip send a new request if tool_model is the same as the current model
     if response and config["completions"]["model"] == GPT.TOOLS_MODEL and response.get("content"):
src/llm/models.py
@@ -10,22 +10,16 @@ from messages.parser import parse_msg
 from messages.utils import startswith_prefix
 
 
-def get_context_type(conversations: list[Message]) -> dict:
+def get_context_type(conversations: list[Message]) -> str:
     """Get model type based on conversation messages."""
-    has_video = False
-    has_audio = False
-    res = {"type": "text"}
+    context_type = "text"
     for message in conversations:
         info = parse_msg(message, silent=True)
         if info["mtype"] == "photo":
-            res["type"] = "image"
-        if info["mtype"] == "video":
-            has_video = True
-        if info["mtype"] == "audio":
-            has_audio = True
-    if has_audio or has_video:
-        res["error"] = f"⚠️已忽略上下文中的视频/音频消息\n可以先用 `{PREFIX.ASR}` 命令转为文字后再使用AI功能"
-    return res
+            context_type = "image"
+        if info["mtype"] in ["video", "audio", "voice"]:
+            context_type = "gemini"  # only Gemini supports audio/video
+    return context_type
 
 
 def get_model_id(text: str, reply_text: str, context_type: str) -> tuple[str, str, str]:
@@ -101,6 +95,10 @@ def get_model_id(text: str, reply_text: str, context_type: str) -> tuple[str, st
         sdk = "openai"
     if model_id and context_type == "text":  # no need to fall back if context type is text
         return model_id, response_modality, sdk
+
+    if context_type == "gemini":  # force gemini
+        return GPT.GEMINI_MODEL, "text", "gemini"
+
     if (
         (model_id == GPT.OPENAI_MODEL and not GPT.OPENAI_ACCEPT_IMAGE)
         or (model_id == GPT.GEMINI_MODEL and not GPT.GEMINI_ACCEPT_IMAGE)
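
Review note: with the early return above, media precedence is video/audio/voice over photo over text, regardless of message order. A standalone illustration over bare media-type strings (not the module's code, which inspects pyrogram Messages):

def context_type_of(mtypes: list[str]) -> str:
    # Mirrors get_context_type over plain strings.
    context_type = "text"
    for mtype in mtypes:
        if mtype == "photo":
            context_type = "image"
        if mtype in ["video", "audio", "voice"]:
            return "gemini"  # early return: a later photo cannot downgrade this
    return context_type


assert context_type_of(["text", "photo"]) == "image"
assert context_type_of(["video", "photo"]) == "gemini"
assert context_type_of(["text", "text"]) == "text"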
src/llm/utils.py
@@ -3,16 +3,19 @@
 import random
 import re
 import tempfile
+from datetime import datetime
 from pathlib import Path
 
 import markdown
 import tiktoken
+from google import genai
+from google.genai.types import HttpOptions
 from loguru import logger
 from markitdown import MarkItDown
 from pyrogram.parser.markdown import BLOCKQUOTE_EXPANDABLE_DELIM, BLOCKQUOTE_EXPANDABLE_END_DELIM
 
-from config import DOWNLOAD_DIR, GPT, PREFIX
-from utils import number_to_emoji, remove_consecutive_newlines, remove_dash, remove_pound, zhcn
+from config import DOWNLOAD_DIR, GEMINI, GPT, PREFIX, cache
+from utils import nowdt, number_to_emoji, remove_consecutive_newlines, remove_dash, remove_pound, zhcn
 
 BOT_TIPS = "(回复以继续)"  # noqa: RUF001
 REASONING_BEGIN = "🤔"  # use emoji to separate model reasoning and content
@@ -171,7 +174,7 @@ def add_search_results_to_response(search_results: list[dict], response: str) ->
 
 def image_emoji(capability: bool) -> str:  # noqa: FBT001
     """Get image capability emoji."""
-    return "(🏞)" if capability else ""
+    return "🏞" if capability else ""
 
 
 def clean_cmd_prefix(text: str) -> str:
@@ -287,3 +290,23 @@ def sample_key(keys: str | list[str]) -> str:
     if not keys:
         return ""
     return random.choice(keys)
+
+
+@cache.memoize(ttl=1800)
+async def clean_gemini_files() -> None:
+    """Delete stale Gemini uploads.
+
+    The Gemini Files API caps storage at 20 GB per project, so uploads older
+    than the configured age are deleted; memoization limits this to one real
+    run per half hour.
+    """
+    if GEMINI.CLEAN_FILES_AFTER_SECONDS >= 48 * 3600:
+        return  # the Files API auto-deletes uploads after 48 hours anyway
+    now = nowdt()
+    for api_key in [x.strip() for x in GEMINI.API_KEY.split(",") if x.strip()]:
+        app = genai.Client(api_key=api_key, http_options=HttpOptions(async_client_args={"proxy": GEMINI.PROXY}))
+        async for f in await app.aio.files.list():
+            if isinstance(f.update_time, datetime) and isinstance(f.name, str):
+                delta = now - f.update_time
+                if delta.total_seconds() > GEMINI.CLEAN_FILES_AFTER_SECONDS:
+                    logger.debug(f"Delete Gemini file: {f.name}")
+                    await app.aio.files.delete(name=f.name)
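
Review note: @cache.memoize(ttl=1800) throttles the cleanup to one real run per half hour even though scheduling() may call it more often; this assumes the project's cache backend can memoize coroutines. If it cannot, a plain monotonic-clock guard is equivalent (sketch; maybe_clean_gemini_files is a hypothetical name):

import time

_last_run = float("-inf")


async def maybe_clean_gemini_files() -> None:
    # Run the real cleanup at most once every 1800 seconds.
    global _last_run
    if time.monotonic() - _last_run < 1800:
        return
    _last_run = time.monotonic()
    await clean_gemini_files()  # the coroutine added in this commit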
src/config.py
@@ -301,6 +301,7 @@ class GEMINI:  # Official Gemini
     PROXY = os.getenv("GEMINI_PROXY", None)
     PREFER_LANG = os.getenv("GEMINI_PREFER_LANG", "")  # Set a prefer response language for Gemini
     MAX_THINKING_BUDGET = int(os.getenv("GEMINI_MAX_THINKING_BUDGET", "24576"))  # 24K
+    CLEAN_FILES_AFTER_SECONDS = int(os.getenv("GEMINI_CLEAN_FILES_AFTER_SECONDS", "172800"))  # defaults to 48 hours, Gemini's own auto-expiry
 
     # response modality: text
     TEXT_MODEL = os.getenv("GEMINI_TEXT_MODEL", "gemini-2.5-flash-preview-05-20")
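
Review note: the new knob in .env form (value assumed). Anything at or above 172800 seconds turns the cleanup into a no-op, since the Files API already expires uploads after 48 hours:

GEMINI_CLEAN_FILES_AFTER_SECONDS=86400  # delete uploads older than 24 hours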
src/main.py
@@ -24,6 +24,7 @@ from bridge.social import forward_social_media_results
 from config import DAILY_MESSAGES, DEVICE_NAME, ENABLE, PROXY, TOKEN, TZ, cache
 from handler import handle_social_media, handle_utilities
 from llm.summary import daily_summary
+from llm.utils import clean_gemini_files
 from messages.parser import parse_msg
 from others.podcast import summary_pods
 from permission import check_permission
@@ -126,6 +127,7 @@ async def scheduling(client: Client):
                 logger.info(f"Sending daily message to {chat_id}: {msg}")
                 await client.send_message(to_int(chat_id), msg)
     await summary_pods(client)
+    await clean_gemini_files()
 
 
 if __name__ == "__main__":