Commit 5a65136

benny-dou <60535774+benny-dou@users.noreply.github.com>
2026-04-27 05:04:01
feat(ai): support openai image generation
1 parent fa4093b
Changed files (3)
src/ai/images/openai_img.py
@@ -5,20 +5,16 @@ from pathlib import Path
 from typing import Literal
 
 import anyio
-from glom import glom
 from loguru import logger
 from openai import AsyncOpenAI, DefaultAsyncHttpxClient
 from pyrogram.client import Client
-from pyrogram.types import Message
+from pyrogram.types import InputMediaPhoto, Message
 
 from ai.images.post import get_image_contexts
 from ai.images.utils import aspect_ratio_to_size, extract_aspect_ratio
 from ai.utils import EMOJI_IMG_BOT, literal_eval, prettify, trim_none
 from config import DOWNLOAD_DIR, PROXY
 from messages.progress import modify_progress
-from messages.sender import send2tg
-from messages.utils import delete_message
-from networking import download_file
 from utils import rand_string, strings_list
 
 
@@ -37,6 +33,7 @@ async def openai_image_generation(
     max_width: int = int(1e16),
     max_height: int = int(1e16),
     max_size: int = int(1e32),
+    api_provider: Literal["openai", "seedream"] = "openai",
     **kwargs,
 ) -> dict[str, Message | bool]:
     """Get OpenAI Image Generation.
@@ -52,12 +49,17 @@ async def openai_image_generation(
         if literal_eval(client_config):
             openai_client |= literal_eval(client_config)
         if proxy:
-            openai_client |= {"http_client": DefaultAsyncHttpxClient(proxy=proxy)}
-        prompt, reference_images = await get_image_contexts(client, message, max_reference_img=max_reference_img)
+            openai_client |= {"http_client": DefaultAsyncHttpxClient(proxy=proxy, timeout=600)}
+        prompt, reference_images = await get_image_contexts(
+            client,
+            message,
+            max_reference_img=max_reference_img,
+            img_type="str" if api_provider == "openai" else "base64",
+        )
         if not prompt:
             await modify_progress(status_msg, text=f"❌**{model_name}**:\n请提供提示词", force_update=True, **kwargs)
             return {"progress": status_msg} if isinstance(status_msg, Message) else {}
-        params = {}
+        params: dict = {"stream": True, "timeout": 600}
         if literal_eval(generate_config):
             params |= literal_eval(generate_config)
         aspect_ratio, _ = extract_aspect_ratio(message.content)
@@ -65,7 +67,10 @@ async def openai_image_generation(
             width, height = aspect_ratio_to_size(aspect_ratio, resolution, max_width, max_height, max_size)
             params |= {"size": f"{width}x{height}"}
         if reference_images:
-            params["extra_body"] = params.get("extra_body", {}) | {"image": reference_images}
+            if api_provider == "seedream":
+                params["extra_body"] = params.get("extra_body", {}) | {"image": reference_images}
+            else:
+                params["image"] = [Path(x) for x in reference_images]
         params |= {"model": model_id, "prompt": prompt}
         logger.debug(f"openai.images.generate(**{prettify(params)})")
     except Exception as e:
@@ -78,65 +83,28 @@ async def openai_image_generation(
             openai_client["api_key"] = api_key
             logger.trace(f"AsyncOpenAI(**{openai_client})")
             openai = AsyncOpenAI(**openai_client)
-            resp = await openai.images.generate(**params)
-            resp = trim_none(resp.model_dump())
-            if images := await download_generated_images(resp, proxy=proxy):
-                resp.pop("data", None)
-                caption = f"{EMOJI_IMG_BOT}**{model_name}**\n{prettify(resp)}\n"
-                for idx, img in enumerate([x for x in images if x.get("url")]):
-                    caption += f"[P{idx + 1}原图]({img['url']})"
-                await modify_progress(status_msg, text=prettify(resp), force_update=True, **kwargs)
-                await send2tg(client, message, texts=caption, media=[{"photo": img["path"]} for img in images], **kwargs)
-                await delete_message(status_msg)
+            if reference_images and api_provider == "openai":
+                params.pop("moderation", None)
+                func = openai.images.edit
+            else:
+                func = openai.images.generate
+            async for chunk in await func(**params):
+                resp = trim_none(chunk.model_dump())
+                if b64_json := resp.pop("b64_json", None):
+                    image_bytes = base64.b64decode(b64_json)
+                    save_path = Path(DOWNLOAD_DIR) / f"{rand_string(10)}.{params.get('output_format', 'jpeg')}"
+                    async with await anyio.open_file(save_path, "wb") as f:
+                        await f.write(image_bytes)
+                    rtype = resp.pop("type", None)
+                    caption = f"{EMOJI_IMG_BOT}**{model_name}**\n" if rtype == "image_generation.completed" else f"⌛️**{model_name}** 中间结果...\n"
+                    status_msg = await status_msg.edit_media(InputMediaPhoto(str(save_path), caption=caption + prettify(resp)))
+                    save_path.unlink(missing_ok=True)
+                elif resp.pop("type", None) == "image_generation.completed":
+                    await status_msg.edit_caption(f"{EMOJI_IMG_BOT}**{model_name}**\n{prettify(resp)}")
+                logger.trace(resp)
+            if status_msg.photo:
                 return {"success": True}
         except Exception as e:
             logger.error(f"OpenAI Image Generation error: {e}\n\n{prettify(resp)}")
             await modify_progress(status_msg, text=f"❌{e}\n\n{prettify(resp)}", force_update=True, **kwargs)
     return {"progress": status_msg} if isinstance(status_msg, Message) else {}
-
-
-async def download_generated_images(response: dict, proxy: str | None = None) -> list[dict]:
-    """Download generated images.
-
-    Response: {
-        "model": "doubao-seedream-4-5-251128",
-        "created": 1757321139,
-        "data": [
-            {
-                "url": "https://...",
-                "size": "3104x1312"
-            },
-            {
-                "b64_json": "/9j/4AAQSkZJRgABA...",
-                "size": "3104x1312"
-            }
-        ],
-        "usage": {
-            "generated_images": 2,
-            "output_tokens": xxx,
-            "total_tokens": xxx
-        }
-    }
-
-    Return:
-    [
-        {
-            "path": "/path/to/image.png"
-            "url": "https://...",
-        }
-    ]
-    """
-    results = []
-    data = glom(response, "data", default=[]) or []
-    for item in data:
-        if url := item.get("url"):
-            img_path = await download_file(url, proxy=proxy)
-            if Path(img_path).is_file():
-                results.append({"path": img_path, "url": url})
-        if b64_json := item.get("b64_json"):
-            image_bytes = base64.b64decode(b64_json)
-            save_path = Path(DOWNLOAD_DIR) / f"{rand_string(10)}.png"
-            async with await anyio.open_file(save_path, "wb") as f:
-                await f.write(image_bytes)
-            results.append({"path": save_path.as_posix()})
-    return results
src/ai/images/post.py
@@ -112,7 +112,13 @@ def extract_metadata(response: dict) -> str:
     return ""
 
 
-async def get_image_contexts(client: Client, message: Message, *, max_reference_img: int = 0) -> tuple[str, list[str]]:
+async def get_image_contexts(
+    client: Client,
+    message: Message,
+    *,
+    max_reference_img: int = 0,
+    img_type: Literal["base64", "str"] = "base64",
+) -> tuple[str, list[str]]:
     """Get image generation contexts.
 
     Returns:
@@ -142,8 +148,13 @@ async def get_image_contexts(client: Client, message: Message, *, max_reference_
             prompt = clean(msg.content)
             if not msg.photo or len(images) >= max_reference_img:
                 continue
-            res = await base64_media(client, msg)
-            images.append(f"data:image/{res['ext']};base64,{res['base64']}")
+            if img_type == "base64":
+                res = await base64_media(client, msg)
+                images.append(f"data:image/{res['ext']};base64,{res['base64']}")
+            else:
+                path = await client.download_media(msg)
+                if Path(str(path)).is_file():
+                    images.append(Path(str(path)).as_posix())
     return prompt, images
 
 
src/ai/utils.py
@@ -18,7 +18,7 @@ from pyrogram.parser.markdown import BLOCKQUOTE_EXPANDABLE_DELIM
 
 from config import AI, PREFIX, PROXY, cache
 from database.kv import get_cf_kv
-from utils import nowdt, remove_consecutive_newlines, remove_dash, remove_pound, strings_list, zhcn
+from utils import nowdt, remove_consecutive_newlines, remove_dash, remove_pound, strings_list, ts_to_dt, zhcn
 
 # ruff: noqa: RUF001
 EMOJI_TEXT_BOT = "🤖"
@@ -66,6 +66,10 @@ def trim_none(obj: dict) -> dict:
 def prettify(data: dict) -> str:
     with contextlib.suppress(Exception):
         data = trim_none(data)
+        if isinstance(data.get("created"), int):
+            data["created"] = ts_to_dt(data["created"]).strftime("%Y-%m-%d %H:%M:%S")  # ty:ignore[unresolved-attribute]
+        if isinstance(data.get("created_at"), int):
+            data["created_at"] = ts_to_dt(data["created_at"]).strftime("%Y-%m-%d %H:%M:%S")  # ty:ignore[unresolved-attribute]
         return json.dumps(data, ensure_ascii=False, indent=2)
     return str(data)