Commit 8b70eae

benny-dou <60535774+benny-dou@users.noreply.github.com>
2026-03-31 09:42:45
feat(gemini): support generating multiple images
1 parent 0c3de72
Changed files (1)
src
ai
images
src/ai/images/gemini.py
@@ -1,5 +1,6 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
+import asyncio
 from pathlib import Path
 from typing import Any
 
@@ -34,6 +35,7 @@ async def gemini_image_generation(
     gemini_generate_content_config: str | dict = "",
     gemini_proxy: str | None = PROXY.GOOGLE,
     max_reference_img: int = 0,
+    num: int = 1,
     **kwargs,
 ) -> dict[str, Message | bool]:
     """Get Gemini Image Generation.
@@ -47,7 +49,7 @@ async def gemini_image_generation(
 
     for api_key in strings_list(gemini_api_keys, shuffle=True):
         try:
-            http_options = types.HttpOptions(base_url=gemini_base_url, async_client_args={"proxy": gemini_proxy}, headers=literal_eval(gemini_default_headers))
+            http_options = types.HttpOptions(base_url=gemini_base_url, async_client_args={"proxy": gemini_proxy}, headers=literal_eval(gemini_default_headers), timeout=120_000)
             gemini = genai.Client(api_key=api_key, http_options=http_options)
             params: dict[str, Any] = {
                 "model": model_id,
@@ -62,15 +64,15 @@ async def gemini_image_generation(
                 image_config["aspect_ratio"] = aspect_ratio or "16:9"
                 params["config"]["image_config"] = image_config
             logger.debug(f"genai.Client().models.generate_content(**{params})")
-            response = await gemini.aio.models.generate_content(**params)
-            parts = glom(response.model_dump(), "candidates.0.content.parts", default=[]) or []
-            texts = "".join([p.get("text") or "" for p in parts if not p.get("thought")])
-            media = await download_generated_images(parts)
+            tasks = [gemini.aio.models.generate_content(**params) for _ in range(num)]
+            resp = await asyncio.gather(*tasks, return_exceptions=True)
+            blobs = glom(resp, "**.inline_data", default=[]) or []
+            media = await download_generated_images(blobs)
             if media:
-                await send2tg(client, message, texts=f"🍌**{model_name}**:\n{texts}", media=media, caption_above=True, **kwargs)
+                await send2tg(client, message, texts=f"🍌**{model_name}**:", media=media, caption_above=True, **kwargs)
                 await delete_message(status_msg)
                 return {"success": True}
-            await modify_progress(status_msg, text=f"❌{response.model_dump()}", force_update=True, **kwargs)
+            await modify_progress(status_msg, text="βŒη”Ÿζˆε›Ύεƒε€±θ΄₯", force_update=True, **kwargs)
         except Exception as e:
             logger.error(f"Gemini API error: {e}")
             await modify_progress(status_msg, text=f"❌{e}", force_update=True, **kwargs)
@@ -119,14 +121,9 @@ async def get_gemini_contexts(client: Client, message: Message, model_name: str,
     return contents[::-1]  # old to new
 
 
-async def download_generated_images(parts: list[dict]) -> list[dict]:
+async def download_generated_images(blobs: list[types.Blob]) -> list[dict]:
     """Download generated images.
 
-    parts: [
-        {"text": "Here's an picture of ..."},
-        {"inline_data": {"mime_type": "image/png", "data": "binary data"}}
-    ]
-
     Return:
     [
         {
@@ -135,14 +132,13 @@ async def download_generated_images(parts: list[dict]) -> list[dict]:
     ]
     """
     images = []
-    for part in parts:
-        if not part.get("inline_data"):
+    for blob in blobs:
+        if not blob.data:
             continue
-        mime_type = part["inline_data"]["mime_type"]
+        mime_type = blob.mime_type or "image/png"
         ext = mime_type.split("/")[-1]
-        data = part["inline_data"]["data"]
         save_path = Path(DOWNLOAD_DIR) / f"{rand_string(10)}.{ext}"
         async with await anyio.open_file(save_path, "wb") as f:
-            await f.write(data)
+            await f.write(blob.data)
         images.append({"photo": save_path.as_posix()})
     return images