main
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3import asyncio
4from pathlib import Path
5from typing import Any
6
7import anyio
8from glom import glom
9from google import genai
10from google.genai import types
11from loguru import logger
12from pyrogram.client import Client
13from pyrogram.types import Message
14
15from ai.images.utils import extract_aspect_ratio
16from ai.texts.contexts import base64_media
17from ai.utils import clean_cmd_prefix, literal_eval
18from config import AI, DOWNLOAD_DIR, PROXY
19from messages.modify import message_modify
20from messages.progress import modify_progress
21from messages.sender import send2tg
22from messages.utils import delete_message, startswith_prefix
23from utils import rand_string, strings_list
24
25
async def gemini_image_generation(
    client: Client,
    message: Message,
    *,
    model_id: str = "",
    model_name: str = "",
    gemini_base_url: str = AI.GEMINI_BASE_URL,
    gemini_api_keys: str = AI.GEMINI_API_KEYS,
    gemini_default_headers: str | dict = AI.GEMINI_DEFAULT_HEADERS,
    gemini_generate_content_config: str | dict = "",
    gemini_proxy: str | None = PROXY.GOOGLE,
    max_reference_img: int = 0,
    num: int = 1,
    **kwargs,
) -> dict[str, Message | bool]:
    """Generate image(s) with the Gemini API and send them to Telegram.

    Tries each configured API key (shuffled) until one request batch yields at
    least one image; on success the media is sent as a reply and the interim
    status message is deleted.

    Args:
        client: Pyrogram client used for Telegram I/O.
        message: Triggering message; its content supplies the prompt.
        model_id: Gemini model identifier passed to generate_content.
        model_name: Display name used in status/result captions.
        gemini_base_url: Base URL for the Gemini endpoint.
        gemini_api_keys: API key list string (split by strings_list).
        gemini_default_headers: Extra HTTP headers (literal string or dict).
        gemini_generate_content_config: Extra config merged over the default
            generate_content config (literal string or dict).
        gemini_proxy: Optional proxy URL for the Google endpoint.
        max_reference_img: Max reference images gathered from the reply chain.
        num: Number of images requested concurrently per key.
        **kwargs: Forwarded to send2tg/modify_progress; may carry "progress".

    Return:
        dict[str, Message | bool]:
            {"success": True} # Generated image successfully
            {"progress": Message} # Failed to generate image
    """
    # Reuse a caller-provided progress message when present.
    status_msg = kwargs.get("progress") or await message.reply(f"π**{model_name}**:\nζ£ε¨ηζεΎε...", quote=True)

    for api_key in strings_list(gemini_api_keys, shuffle=True):
        try:
            http_options = types.HttpOptions(base_url=gemini_base_url, async_client_args={"proxy": gemini_proxy}, headers=literal_eval(gemini_default_headers), timeout=120_000)
            gemini = genai.Client(api_key=api_key, http_options=http_options)
            params: dict[str, Any] = {
                "model": model_id,
                "contents": await get_gemini_contexts(client, message, model_name, max_reference_img=max_reference_img),
                "config": {"response_modalities": ["IMAGE"]},
            }
            if conf := literal_eval(gemini_generate_content_config):
                params["config"] |= conf
            # Guarantee an aspect ratio: explicit config wins, then one parsed
            # from the message text, then "16:9" as the default.
            image_config = glom(params, "config.image_config", default={})
            if not image_config.get("aspect_ratio"):
                aspect_ratio, _ = extract_aspect_ratio(message.content)
                image_config["aspect_ratio"] = aspect_ratio or "16:9"
            params["config"]["image_config"] = image_config
            logger.debug(f"genai.Client().models.generate_content(**{params})")
            tasks = [gemini.aio.models.generate_content(**params) for _ in range(num)]
            resp = await asyncio.gather(*tasks, return_exceptions=True)
            # gather(return_exceptions=True) mixes exception objects into the
            # result list; log them instead of silently dropping the failures.
            results = []
            for r in resp:
                if isinstance(r, BaseException):
                    logger.warning(f"Gemini generate_content task failed: {r}")
                else:
                    results.append(r)
            blobs = glom(results, "**.inline_data", default=[]) or []
            media = await download_generated_images(blobs)
            if media:
                await send2tg(client, message, texts=f"π**{model_name}**:", media=media, caption_above=True, **kwargs)
                await delete_message(status_msg)
                return {"success": True}
            await modify_progress(status_msg, text="βηζεΎεε€±θ΄₯", force_update=True, **kwargs)
        except Exception as e:
            logger.error(f"Gemini API error: {e}")
            await modify_progress(status_msg, text=f"β{e}", force_update=True, **kwargs)
    return {"progress": status_msg} if isinstance(status_msg, Message) else {}
80
81
async def get_gemini_contexts(client: Client, message: Message, model_name: str, *, max_reference_img: int = 0) -> list[dict]:
    """Build Gemini generate_content contexts from a message and its reply chain.

    Each entry gets role "model" when any message of its media group starts with
    this bot's own "π{model_name}:" prefix, otherwise "user". Up to
    *max_reference_img* photos from the chain are attached as base64 inline_data.

    Returns:
        list: [
            {"role": "user", "parts": [{"text": "prompt"}]},
            {"role": "model", "parts": [{"inline_data": {"mime_type": "image/png", "data": "base64_encoded_image_data"}}]},
        ]
        When max_reference_img is 0, a single bare part is returned instead:
        [{"text": "prompt"}]
    """

    def clean(text: str) -> str:
        # Strip the command prefix, any aspect-ratio token, and this bot's
        # own reply prefix so only the raw prompt text remains.
        if not text:
            return ""
        text = clean_cmd_prefix(str(text))
        _, text = extract_aspect_ratio(text)
        return text.removeprefix(f"π{model_name}:").lstrip()

    if not max_reference_img:
        # No reference images wanted: plain text-to-image prompt only.
        return [{"text": clean(message.content)}]
    messages = [message]
    # Walk up the reply chain; presumably message_modify normalizes each
    # message (e.g. fills .content) — TODO confirm against messages.modify.
    while message.reply_to_message:
        message = message.reply_to_message
        if not message.service:  # ignore service messages
            messages.append(message_modify(message))
    contents = []
    num_img = 0  # photos attached so far, capped at max_reference_img
    for m in messages:  # new to old
        # Expand albums so every message of the media group is considered.
        group_messages = await client.get_media_group(m.chat.id, m.id) if m.media_group_id else [m]
        role = "model" if any(startswith_prefix(msg.content, f"π{model_name}:") for msg in group_messages) else "user"
        parts = []
        for msg in group_messages[::-1]:  # new to old
            if prompt := clean(msg.content):
                parts.append({"text": prompt})
            if msg.photo and num_img < max_reference_img:
                res = await base64_media(client, msg)
                parts.append({"inline_data": {"mime_type": f"image/{res['ext']}", "data": res["base64"]}})
                num_img += 1
        if parts:
            contents.append({"role": role, "parts": parts[::-1]})  # old to new
    return contents[::-1]  # old to new
122
123
async def download_generated_images(blobs: list[types.Blob]) -> list[dict]:
    """Write generated image blobs into DOWNLOAD_DIR under random names.

    Blobs without payload are skipped; the file extension is derived from the
    blob's mime type (defaulting to "image/png").

    Return:
        [
            {
                "photo": "/path/to/image.png"
            }
        ]
    """
    saved: list[dict] = []
    for blob in blobs:
        payload = blob.data
        if not payload:  # nothing to write for this blob
            continue
        # "image/jpeg" -> "jpeg"; fall back to png when mime type is missing.
        suffix = (blob.mime_type or "image/png").split("/")[-1]
        target = Path(DOWNLOAD_DIR) / f"{rand_string(10)}.{suffix}"
        handle = await anyio.open_file(target, "wb")
        async with handle as f:
            await f.write(payload)
        saved.append({"photo": target.as_posix()})
    return saved