Commit 713ef22
src/ai/images/gemini.py
@@ -0,0 +1,170 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+import re
+from pathlib import Path
+from typing import Any
+
+import anyio
+from glom import glom
+from google import genai
+from google.genai import types
+from loguru import logger
+from pyrogram.client import Client
+from pyrogram.types import Message
+
+from ai.texts.contexts import base64_media
+from ai.utils import EMOJI_IMG_BOT, clean_cmd_prefix, literal_eval
+from config import AI, DOWNLOAD_DIR, PROXY
+from messages.modify import message_modify
+from messages.progress import modify_progress
+from messages.sender import send2tg
+from messages.utils import delete_message, startswith_prefix
+from utils import rand_string, strings_list
+
+
+async def gemini_image_generation(
+ client: Client,
+ message: Message,
+ *,
+ model_id: str = "",
+ model_name: str = "",
+ gemini_base_url: str = AI.GEMINI_BASE_URL,
+ gemini_api_keys: str = AI.GEMINI_API_KEYS,
+ gemini_default_headers: str | dict = AI.GEMINI_DEFAULT_HEADERS,
+ gemini_generate_content_config: str | dict = "",
+ gemini_proxy: str | None = PROXY.GOOGLE,
+ support_reference_images: int = 0,
+ **kwargs,
+) -> bool:
+ """Get Gemini Image Generation."""
+ status_msg = kwargs.get("progress") or await message.reply(f"{EMOJI_IMG_BOT}**{model_name}**:\n正在生成图像...", quote=True)
+
+ for api_key in strings_list(gemini_api_keys, shuffle=True):
+ try:
+ http_options = types.HttpOptions(base_url=gemini_base_url, async_client_args={"proxy": gemini_proxy}, headers=literal_eval(gemini_default_headers))
+ gemini = genai.Client(api_key=api_key, http_options=http_options)
+ params: dict[str, Any] = {
+ "model": model_id,
+ "contents": await get_gemini_contexts(client, message, model_name, support_reference_images=support_reference_images),
+ "config": {"response_modalities": ["TEXT", "IMAGE"]},
+ }
+ if conf := literal_eval(gemini_generate_content_config):
+ params["config"] |= conf
+ image_config = glom(params, "config.image_config", default={})
+ if not image_config.get("aspect_ratio"):
+ image_config["aspect_ratio"] = infer_aspect_ratio(message.content)
+ params["config"]["image_config"] = image_config
+ logger.debug(f"genai.Client().models.generate_content_stream(**{params})")
+ response = await gemini.aio.models.generate_content(**params)
+ parts = glom(response.model_dump(), "candidates.0.content.parts", default=[]) or []
+ texts = "".join([p.get("text") or "" for p in parts if not p.get("thought")])
+ media = await download_generated_images(parts)
+ if media:
+ await send2tg(client, message, texts=f"{EMOJI_IMG_BOT}**{model_name}**:\n{texts}", media=media, caption_above=True, **kwargs)
+ await delete_message(status_msg)
+ return True
+ except Exception as e:
+ logger.error(f"Gemini API error: {e}")
+ await modify_progress(status_msg, text=f"❌{e}", force_update=True, **kwargs)
+ return False
+
+
+async def get_gemini_contexts(client: Client, message: Message, model_name: str, *, support_reference_images: int = 0) -> list[dict]:
+    """Generate Gemini image generation contexts from a message reply chain.
+
+    Walks the reply chain of *message* (newest first), expanding media
+    groups into their sibling messages. Entries whose text starts with the
+    bot's own "{EMOJI_IMG_BOT}{model_name}:" header get role "model";
+    everything else gets role "user". At most *support_reference_images*
+    photos are embedded as inline base64 data; with a limit of 0 only the
+    cleaned prompt text of *message* is returned.
+
+    Returns:
+        list: [
+            {"role": "user", "parts": [{"text": "prompt"}]},
+            {"role": "model", "parts": [{"inline_data": {"mime_type": "image/png", "data": "base64_encoded_image_data"}}]},
+        ]
+    """
+
+    def clean(text: str) -> str:
+        # Strip the command prefix and the bot's own reply header so only
+        # the actual prompt text reaches the model.
+        text = clean_cmd_prefix(text)
+        return text.removeprefix(f"{EMOJI_IMG_BOT}{model_name}:").lstrip()
+
+    if not support_reference_images:
+        return [{"text": clean(message.content)}]
+    messages = [message]
+    while message.reply_to_message:
+        message = message.reply_to_message
+        if not message.service: # ignore service messages
+            messages.append(message_modify(message))
+    contents = []
+    num_img = 0  # reference images embedded so far, capped at support_reference_images
+    for m in messages: # new to old
+        # A media-group message stands for the whole album; fetch its siblings.
+        group_messages = await client.get_media_group(m.chat.id, m.id) if m.media_group_id else [m]
+        role = "model" if any(startswith_prefix(msg.content, f"{EMOJI_IMG_BOT}{model_name}:") for msg in group_messages) else "user"
+        parts = []
+        for msg in group_messages[::-1]: # new to old
+            if prompt := clean(msg.content):
+                parts.append({"text": prompt})
+            if msg.photo and num_img < support_reference_images:
+                res = await base64_media(client, msg)
+                parts.append({"inline_data": {"mime_type": f"image/{res['ext']}", "data": res["base64"]}})
+                num_img += 1
+        if parts:
+            contents.append({"role": role, "parts": parts[::-1]}) # old to new
+    return contents[::-1] # old to new
+
+
+async def download_generated_images(parts: list[dict]) -> list[dict]:
+ """Download generated images.
+
+ parts: [
+ {"text": "Here's an picture of ..."},
+ {"inline_data": {"mime_type": "image/png", "data": "binary data"}}
+ ]
+
+ Return:
+ [
+ {
+ "photo": "/path/to/image.png"
+ }
+ ]
+ """
+ images = []
+ for part in parts:
+ if not part.get("inline_data"):
+ continue
+ mime_type = part["inline_data"]["mime_type"]
+ ext = mime_type.split("/")[-1]
+ data = part["inline_data"]["data"]
+ save_path = Path(DOWNLOAD_DIR) / f"{rand_string(10)}.{ext}"
+ async with await anyio.open_file(save_path, "wb") as f:
+ await f.write(data)
+ images.append({"photo": save_path.as_posix()})
+ return images
+
+
+def infer_aspect_ratio(text: str) -> str:
+ """Infer aspect ratio from text.
+
+ Args:
+ text (str): Text.
+
+ Returns:
+ str: Aspect ratio. (width:height)
+ """
+ r"""
+ (?i): 表示不区分大小写匹配。这意味着 Aspect_Ratio 也能被匹配到。
+ (aspect_ratio|aspect ratio|长宽比) : 匹配 "aspect_ratio", "aspect ratio" 或 "长宽比"
+ \s* : 匹配零个或多个空格
+ [=::] : 匹配等号 `=` 或冒号 `:` 或全角冒号 `:`
+ \s* : 匹配零个或多个空格
+ "? : 匹配一个可选的双引号 (0 或 1 次)
+ (\d+:\d+) : 捕获组, 匹配一个或多个数字, 接着一个冒号, 再接着一个或多个数字 (例如 "5:4", "16:9")
+ "? : 匹配一个可选的双引号 (0 或 1 次)
+ """ # noqa: RUF001
+ pattern = r"(?i)(aspect_ratio|aspect ratio|长宽比)(.*?)[=::]\s*\"?(\d+:\d+)\"?" # noqa: RUF001
+ if match := re.search(pattern, text):
+ return match.group(3)
+
+ text = text.lower()
+ if "portrait" in text:
+ return "9:16"
+ if "landscape" in text:
+ return "16:9"
+ if "square" in text:
+ return "1:1"
+ return "9:16" # default
src/ai/main.py
@@ -5,6 +5,7 @@ import re
from pyrogram.client import Client
from pyrogram.types import Message
+from ai.images.gemini import gemini_image_generation
from ai.images.models import get_image_model_configs
from ai.images.openai_img import openai_image_generation
from ai.images.post import http_post_image_generation
@@ -79,6 +80,9 @@ async def ai_image_generation(client: Client, message: Message, **kwargs) -> Non
case "post":
if await http_post_image_generation(client, message, **model_config):
return
+ case "gemini":
+ if await gemini_image_generation(client, message, **model_config):
+ return
async def ai_video_generation(client: Client, message: Message, **kwargs) -> None: