Commit 713ef22

benny-dou <60535774+benny-dou@users.noreply.github.com>
2026-02-09 10:12:39
feat(gemini): add image generation with Gemini
1 parent 9864b28
Changed files (2)
src
src/ai/images/gemini.py
@@ -0,0 +1,170 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+import re
+from pathlib import Path
+from typing import Any
+
+import anyio
+from glom import glom
+from google import genai
+from google.genai import types
+from loguru import logger
+from pyrogram.client import Client
+from pyrogram.types import Message
+
+from ai.texts.contexts import base64_media
+from ai.utils import EMOJI_IMG_BOT, clean_cmd_prefix, literal_eval
+from config import AI, DOWNLOAD_DIR, PROXY
+from messages.modify import message_modify
+from messages.progress import modify_progress
+from messages.sender import send2tg
+from messages.utils import delete_message, startswith_prefix
+from utils import rand_string, strings_list
+
+
async def gemini_image_generation(
    client: Client,
    message: Message,
    *,
    model_id: str = "",
    model_name: str = "",
    gemini_base_url: str = AI.GEMINI_BASE_URL,
    gemini_api_keys: str = AI.GEMINI_API_KEYS,
    gemini_default_headers: str | dict = AI.GEMINI_DEFAULT_HEADERS,
    gemini_generate_content_config: str | dict = "",
    gemini_proxy: str | None = PROXY.GOOGLE,
    support_reference_images: int = 0,
    **kwargs,
) -> bool:
    """Generate image(s) with Gemini and send them back to Telegram.

    Tries each configured API key (shuffled) in turn until one call yields
    at least one image; failures are logged and surfaced on the status message.

    Args:
        client: Pyrogram client used to build contexts and send results.
        message: Triggering message; its text/reply chain becomes the prompt.
        model_id: Gemini model identifier passed to the API.
        model_name: Display name used in status and caption text.
        gemini_base_url: API base URL.
        gemini_api_keys: One or more API keys (parsed by ``strings_list``).
        gemini_default_headers: Extra HTTP headers (str or dict, via ``literal_eval``).
        gemini_generate_content_config: Optional config merged over the defaults.
        gemini_proxy: Optional proxy URL for the async HTTP client.
        support_reference_images: Max reference images pulled from the reply chain.
        **kwargs: Pass-through options (may carry a ``progress`` status message).

    Returns:
        bool: True if images were generated and sent, False if every key failed.
    """
    status_msg = kwargs.get("progress") or await message.reply(f"{EMOJI_IMG_BOT}**{model_name}**:\n正在生成图像...", quote=True)

    for api_key in strings_list(gemini_api_keys, shuffle=True):
        try:
            http_options = types.HttpOptions(base_url=gemini_base_url, async_client_args={"proxy": gemini_proxy}, headers=literal_eval(gemini_default_headers))
            gemini = genai.Client(api_key=api_key, http_options=http_options)
            params: dict[str, Any] = {
                "model": model_id,
                "contents": await get_gemini_contexts(client, message, model_name, support_reference_images=support_reference_images),
                "config": {"response_modalities": ["TEXT", "IMAGE"]},
            }
            # Merge the user-supplied generation config over the defaults.
            if conf := literal_eval(gemini_generate_content_config):
                params["config"] |= conf
            # Always ship an aspect ratio: infer one from the prompt when absent.
            image_config = glom(params, "config.image_config", default={})
            if not image_config.get("aspect_ratio"):
                image_config["aspect_ratio"] = infer_aspect_ratio(message.content)
                params["config"]["image_config"] = image_config
            # Fixed: the log previously named generate_content_stream, but the
            # non-streaming generate_content is what is actually called below.
            logger.debug(f"genai.Client().aio.models.generate_content(**{params})")
            response = await gemini.aio.models.generate_content(**params)
            parts = glom(response.model_dump(), "candidates.0.content.parts", default=[]) or []
            # Concatenate visible text parts, skipping model "thought" parts.
            texts = "".join([p.get("text") or "" for p in parts if not p.get("thought")])
            media = await download_generated_images(parts)
            if media:
                await send2tg(client, message, texts=f"{EMOJI_IMG_BOT}**{model_name}**:\n{texts}", media=media, caption_above=True, **kwargs)
                await delete_message(status_msg)
                return True
        except Exception as e:  # any failure: report it and fall through to the next key
            logger.error(f"Gemini API error: {e}")
            await modify_progress(status_msg, text=f"❌{e}", force_update=True, **kwargs)
    return False
+
+
async def get_gemini_contexts(client: Client, message: Message, model_name: str, *, support_reference_images: int = 0) -> list[dict]:
    """Build the Gemini ``contents`` payload for image generation.

    Without reference-image support this is just the cleaned prompt text.
    Otherwise the reply chain is walked and up to ``support_reference_images``
    photos are embedded as base64 inline data.

    Returns:
        list: e.g. [
            {"role": "user", "parts": [{"text": "prompt"}]},
            {"role": "model", "parts": [{"inline_data": {"mime_type": "image/png", "data": "<base64>"}}]},
            ]
    """

    def _strip(text: str) -> str:
        # Drop the command prefix, then any leading "<emoji><model>:" tag.
        return clean_cmd_prefix(text).removeprefix(f"{EMOJI_IMG_BOT}{model_name}:").lstrip()

    if not support_reference_images:
        return [{"text": _strip(message.content)}]

    chain = [message]
    current = message
    while current.reply_to_message:
        current = current.reply_to_message
        if not current.service:  # Telegram service messages carry no prompt content
            chain.append(message_modify(current))

    contents: list[dict] = []
    used_images = 0
    for item in chain:  # newest conversation turn first
        grouped = await client.get_media_group(item.chat.id, item.id) if item.media_group_id else [item]
        from_bot = any(startswith_prefix(g.content, f"{EMOJI_IMG_BOT}{model_name}:") for g in grouped)
        parts: list[dict] = []
        for g in reversed(grouped):  # walk the album newest-first
            if prompt := _strip(g.content):
                parts.append({"text": prompt})
            if g.photo and used_images < support_reference_images:
                encoded = await base64_media(client, g)
                parts.append({"inline_data": {"mime_type": f"image/{encoded['ext']}", "data": encoded["base64"]}})
                used_images += 1
        if parts:
            contents.append({"role": "model" if from_bot else "user", "parts": parts[::-1]})  # restore old-to-new order
    return contents[::-1]  # oldest turn first, as the API expects
+
+
async def download_generated_images(parts: list[dict]) -> list[dict]:
    """Persist inline image payloads from Gemini response parts to disk.

    Args:
        parts: Response parts, e.g.
            [{"text": "Here's a picture of ..."},
             {"inline_data": {"mime_type": "image/png", "data": <binary data>}}]

    Returns:
        list[dict]: One ``{"photo": "/path/to/image.png"}`` entry per saved image.
    """
    saved: list[dict] = []
    for part in parts:
        inline = part.get("inline_data")
        if not inline:
            continue
        # File extension derived from the mime type, e.g. "image/png" -> "png".
        suffix = inline["mime_type"].rpartition("/")[-1]
        target = Path(DOWNLOAD_DIR) / f"{rand_string(10)}.{suffix}"
        async with await anyio.open_file(target, "wb") as fh:
            await fh.write(inline["data"])
        saved.append({"photo": target.as_posix()})
    return saved
+
+
+def infer_aspect_ratio(text: str) -> str:
+    """Infer aspect ratio from text.
+
+    Args:
+        text (str): Text.
+
+    Returns:
+        str: Aspect ratio. (width:height)
+    """
+    r"""
+    (?i): 表示不区分大小写匹配。这意味着 Aspect_Ratio 也能被匹配到。
+    (aspect_ratio|aspect ratio|长宽比) : 匹配 "aspect_ratio", "aspect ratio" 或 "长宽比"
+    \s*                   : 匹配零个或多个空格
+    [=::]                  : 匹配等号 `=` 或冒号 `:` 或全角冒号 `:`
+    \s*                   : 匹配零个或多个空格
+    "?                    : 匹配一个可选的双引号 (0 或 1 次)
+    (\d+:\d+)             : 捕获组, 匹配一个或多个数字, 接着一个冒号, 再接着一个或多个数字 (例如 "5:4", "16:9")
+    "?                    : 匹配一个可选的双引号 (0 或 1 次)
+    """  # noqa: RUF001
+    pattern = r"(?i)(aspect_ratio|aspect ratio|长宽比)(.*?)[=::]\s*\"?(\d+:\d+)\"?"  # noqa: RUF001
+    if match := re.search(pattern, text):
+        return match.group(3)
+
+    text = text.lower()
+    if "portrait" in text:
+        return "9:16"
+    if "landscape" in text:
+        return "16:9"
+    if "square" in text:
+        return "1:1"
+    return "9:16"  # default
src/ai/main.py
@@ -5,6 +5,7 @@ import re
 from pyrogram.client import Client
 from pyrogram.types import Message
 
+from ai.images.gemini import gemini_image_generation
 from ai.images.models import get_image_model_configs
 from ai.images.openai_img import openai_image_generation
 from ai.images.post import http_post_image_generation
@@ -79,6 +80,9 @@ async def ai_image_generation(client: Client, message: Message, **kwargs) -> Non
             case "post":
                 if await http_post_image_generation(client, message, **model_config):
                     return
+            case "gemini":
+                if await gemini_image_generation(client, message, **model_config):
+                    return
 
 
 async def ai_video_generation(client: Client, message: Message, **kwargs) -> None: