main
  1#!/usr/bin/env python
  2# -*- coding: utf-8 -*-
  3import asyncio
  4from pathlib import Path
  5from typing import Any
  6
  7import anyio
  8from glom import glom
  9from google import genai
 10from google.genai import types
 11from loguru import logger
 12from pyrogram.client import Client
 13from pyrogram.types import Message
 14
 15from ai.images.utils import extract_aspect_ratio
 16from ai.texts.contexts import base64_media
 17from ai.utils import clean_cmd_prefix, literal_eval
 18from config import AI, DOWNLOAD_DIR, PROXY
 19from messages.modify import message_modify
 20from messages.progress import modify_progress
 21from messages.sender import send2tg
 22from messages.utils import delete_message, startswith_prefix
 23from utils import rand_string, strings_list
 24
 25
 26async def gemini_image_generation(
 27    client: Client,
 28    message: Message,
 29    *,
 30    model_id: str = "",
 31    model_name: str = "",
 32    gemini_base_url: str = AI.GEMINI_BASE_URL,
 33    gemini_api_keys: str = AI.GEMINI_API_KEYS,
 34    gemini_default_headers: str | dict = AI.GEMINI_DEFAULT_HEADERS,
 35    gemini_generate_content_config: str | dict = "",
 36    gemini_proxy: str | None = PROXY.GOOGLE,
 37    max_reference_img: int = 0,
 38    num: int = 1,
 39    **kwargs,
 40) -> dict[str, Message | bool]:
 41    """Get Gemini Image Generation.
 42
 43    Return:
 44        dict[str, Message | bool]:
 45            {"success": True}  # Generated image successfully
 46            {"progress": Message}  # Failed to generate image
 47    """
 48    status_msg = kwargs.get("progress") or await message.reply(f"🍌**{model_name}**:\nζ­£εœ¨η”Ÿζˆε›Ύεƒ...", quote=True)
 49
 50    for api_key in strings_list(gemini_api_keys, shuffle=True):
 51        try:
 52            http_options = types.HttpOptions(base_url=gemini_base_url, async_client_args={"proxy": gemini_proxy}, headers=literal_eval(gemini_default_headers), timeout=120_000)
 53            gemini = genai.Client(api_key=api_key, http_options=http_options)
 54            params: dict[str, Any] = {
 55                "model": model_id,
 56                "contents": await get_gemini_contexts(client, message, model_name, max_reference_img=max_reference_img),
 57                "config": {"response_modalities": ["IMAGE"]},
 58            }
 59            if conf := literal_eval(gemini_generate_content_config):
 60                params["config"] |= conf
 61            image_config = glom(params, "config.image_config", default={})
 62            if not image_config.get("aspect_ratio"):
 63                aspect_ratio, _ = extract_aspect_ratio(message.content)
 64                image_config["aspect_ratio"] = aspect_ratio or "16:9"
 65                params["config"]["image_config"] = image_config
 66            logger.debug(f"genai.Client().models.generate_content(**{params})")
 67            tasks = [gemini.aio.models.generate_content(**params) for _ in range(num)]
 68            resp = await asyncio.gather(*tasks, return_exceptions=True)
 69            blobs = glom(resp, "**.inline_data", default=[]) or []
 70            media = await download_generated_images(blobs)
 71            if media:
 72                await send2tg(client, message, texts=f"🍌**{model_name}**:", media=media, caption_above=True, **kwargs)
 73                await delete_message(status_msg)
 74                return {"success": True}
 75            await modify_progress(status_msg, text="βŒη”Ÿζˆε›Ύεƒε€±θ΄₯", force_update=True, **kwargs)
 76        except Exception as e:
 77            logger.error(f"Gemini API error: {e}")
 78            await modify_progress(status_msg, text=f"❌{e}", force_update=True, **kwargs)
 79    return {"progress": status_msg} if isinstance(status_msg, Message) else {}
 80
 81
 82async def get_gemini_contexts(client: Client, message: Message, model_name: str, *, max_reference_img: int = 0) -> list[dict]:
 83    """Generate Gemini image generation contexts.
 84
 85    Returns:
 86        list: [
 87            {"role": "user", "parts": [{"text": "prompt"}]},
 88            {"role": "model", "inline_data": {"mime_type": "image/png", "data": "base64_encoded_image_data"}},
 89            ]
 90    """
 91
 92    def clean(text: str) -> str:
 93        if not text:
 94            return ""
 95        text = clean_cmd_prefix(str(text))
 96        _, text = extract_aspect_ratio(text)
 97        return text.removeprefix(f"🍌{model_name}:").lstrip()
 98
 99    if not max_reference_img:
100        return [{"text": clean(message.content)}]
101    messages = [message]
102    while message.reply_to_message:
103        message = message.reply_to_message
104        if not message.service:  # ignore service messages
105            messages.append(message_modify(message))
106    contents = []
107    num_img = 0
108    for m in messages:  # new to old
109        group_messages = await client.get_media_group(m.chat.id, m.id) if m.media_group_id else [m]
110        role = "model" if any(startswith_prefix(msg.content, f"🍌{model_name}:") for msg in group_messages) else "user"
111        parts = []
112        for msg in group_messages[::-1]:  # new to old
113            if prompt := clean(msg.content):
114                parts.append({"text": prompt})
115            if msg.photo and num_img < max_reference_img:
116                res = await base64_media(client, msg)
117                parts.append({"inline_data": {"mime_type": f"image/{res['ext']}", "data": res["base64"]}})
118                num_img += 1
119        if parts:
120            contents.append({"role": role, "parts": parts[::-1]})  # old to new
121    return contents[::-1]  # old to new
122
123
async def download_generated_images(blobs: list[types.Blob]) -> list[dict]:
    """Save generated image blobs as files under ``DOWNLOAD_DIR``.

    Entries that are ``None`` or have empty ``data`` are skipped. The file
    extension is the blob's MIME subtype, with any MIME parameters stripped.
    When the MIME type is missing it defaults to ``image/png``. Each file gets
    a random 10-character name. The directory is created if it does not exist.

    Args:
        blobs: Inline-data blobs collected from Gemini responses.

    Return:
        One entry per file written, in input order:
        [
            {
                "photo": "/path/to/image.png"
            }
        ]
    """
    download_dir = Path(DOWNLOAD_DIR)
    images = []
    for blob in blobs:
        # Tolerate None: parts without inline data may end up in the list.
        if blob is None or not blob.data:
            continue
        mime_type = blob.mime_type or "image/png"
        # Drop MIME parameters (e.g. "; charset=...") before taking the subtype.
        ext = mime_type.split(";", 1)[0].strip().split("/")[-1]
        if not images:
            # Make sure the target directory exists before the first write.
            download_dir.mkdir(parents=True, exist_ok=True)
        save_path = download_dir / f"{rand_string(10)}.{ext}"
        async with await anyio.open_file(save_path, "wb") as f:
            await f.write(blob.data)
        images.append({"photo": save_path.as_posix()})
    return images
144    return images