Commit 8b70eae
Changed files (1)
src
ai
images
src/ai/images/gemini.py
@@ -1,5 +1,6 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
+import asyncio
from pathlib import Path
from typing import Any
@@ -34,6 +35,7 @@ async def gemini_image_generation(
gemini_generate_content_config: str | dict = "",
gemini_proxy: str | None = PROXY.GOOGLE,
max_reference_img: int = 0,
+ num: int = 1,
**kwargs,
) -> dict[str, Message | bool]:
"""Get Gemini Image Generation.
@@ -47,7 +49,7 @@ async def gemini_image_generation(
for api_key in strings_list(gemini_api_keys, shuffle=True):
try:
- http_options = types.HttpOptions(base_url=gemini_base_url, async_client_args={"proxy": gemini_proxy}, headers=literal_eval(gemini_default_headers))
+ http_options = types.HttpOptions(base_url=gemini_base_url, async_client_args={"proxy": gemini_proxy}, headers=literal_eval(gemini_default_headers), timeout=120_000)
gemini = genai.Client(api_key=api_key, http_options=http_options)
params: dict[str, Any] = {
"model": model_id,
@@ -62,15 +64,15 @@ async def gemini_image_generation(
image_config["aspect_ratio"] = aspect_ratio or "16:9"
params["config"]["image_config"] = image_config
logger.debug(f"genai.Client().models.generate_content(**{params})")
- response = await gemini.aio.models.generate_content(**params)
- parts = glom(response.model_dump(), "candidates.0.content.parts", default=[]) or []
- texts = "".join([p.get("text") or "" for p in parts if not p.get("thought")])
- media = await download_generated_images(parts)
+ tasks = [gemini.aio.models.generate_content(**params) for _ in range(num)]
+ resp = await asyncio.gather(*tasks, return_exceptions=True)
+ blobs = glom(resp, "**.inline_data", default=[]) or []
+ media = await download_generated_images(blobs)
if media:
- await send2tg(client, message, texts=f"π**{model_name}**:\n{texts}", media=media, caption_above=True, **kwargs)
+ await send2tg(client, message, texts=f"π**{model_name}**:", media=media, caption_above=True, **kwargs)
await delete_message(status_msg)
return {"success": True}
- await modify_progress(status_msg, text=f"β{response.model_dump()}", force_update=True, **kwargs)
+ await modify_progress(status_msg, text="βηζεΎεε€±θ΄₯", force_update=True, **kwargs)
except Exception as e:
logger.error(f"Gemini API error: {e}")
await modify_progress(status_msg, text=f"β{e}", force_update=True, **kwargs)
@@ -119,14 +121,9 @@ async def get_gemini_contexts(client: Client, message: Message, model_name: str,
return contents[::-1] # old to new
-async def download_generated_images(parts: list[dict]) -> list[dict]:
+async def download_generated_images(blobs: list[types.Blob]) -> list[dict]:
"""Download generated images.
- parts: [
- {"text": "Here's an picture of ..."},
- {"inline_data": {"mime_type": "image/png", "data": "binary data"}}
- ]
-
Return:
[
{
@@ -135,14 +132,13 @@ async def download_generated_images(parts: list[dict]) -> list[dict]:
]
"""
images = []
- for part in parts:
- if not part.get("inline_data"):
+ for blob in blobs:
+ if not blob.data:
continue
- mime_type = part["inline_data"]["mime_type"]
+ mime_type = blob.mime_type or "image/png"
ext = mime_type.split("/")[-1]
- data = part["inline_data"]["data"]
save_path = Path(DOWNLOAD_DIR) / f"{rand_string(10)}.{ext}"
async with await anyio.open_file(save_path, "wb") as f:
- await f.write(data)
+ await f.write(blob.data)
images.append({"photo": save_path.as_posix()})
return images