main
  1#!/usr/bin/env python
  2# -*- coding: utf-8 -*-
  3import base64
  4from pathlib import Path
  5from typing import Literal
  6
  7import anyio
  8from loguru import logger
  9from openai import AsyncOpenAI, DefaultAsyncHttpxClient
 10from pyrogram.client import Client
 11from pyrogram.types import InputMediaPhoto, Message
 12
 13from ai.images.post import get_image_contexts
 14from ai.images.utils import aspect_ratio_to_size, extract_aspect_ratio
 15from ai.utils import EMOJI_IMG_BOT, literal_eval, prettify, trim_none
 16from config import DOWNLOAD_DIR, PROXY
 17from messages.progress import modify_progress
 18from utils import rand_string, strings_list
 19
 20
 21async def openai_image_generation(
 22    client: Client,
 23    message: Message,
 24    *,
 25    model_id: str = "",
 26    model_name: str = "",
 27    api_keys: str = "",
 28    client_config: str | dict = "",
 29    generate_config: str | dict = "",
 30    proxy: str | None = PROXY.OPENAI,
 31    max_reference_img: int = 0,
 32    resolution: Literal["1K", "2K", "4K"] = "1K",
 33    max_width: int = int(1e16),
 34    max_height: int = int(1e16),
 35    max_size: int = int(1e32),
 36    api_provider: Literal["openai", "seedream"] = "openai",
 37    **kwargs,
 38) -> dict[str, Message | bool]:
 39    """Get OpenAI Image Generation.
 40
 41    Return:
 42    dict[str, Message | bool]:
 43        {"success": True}  # Generated image successfully
 44        {"progress": Message}  # Failed to generate image
 45    """
 46    status_msg = kwargs.get("progress") or await message.reply(f"{EMOJI_IMG_BOT}**{model_name}**:\n正在生成图像...", quote=True)
 47    try:
 48        openai_client = {}
 49        if literal_eval(client_config):
 50            openai_client |= literal_eval(client_config)
 51        if proxy:
 52            openai_client |= {"http_client": DefaultAsyncHttpxClient(proxy=proxy, timeout=600)}
 53        prompt, reference_images = await get_image_contexts(
 54            client,
 55            message,
 56            max_reference_img=max_reference_img,
 57            img_type="str" if api_provider == "openai" else "base64",
 58        )
 59        if not prompt:
 60            await modify_progress(status_msg, text=f"❌**{model_name}**:\n请提供提示词", force_update=True, **kwargs)
 61            return {"progress": status_msg} if isinstance(status_msg, Message) else {}
 62        params: dict = {"stream": True, "timeout": 600}
 63        if literal_eval(generate_config):
 64            params |= literal_eval(generate_config)
 65        aspect_ratio, _ = extract_aspect_ratio(message.content)
 66        if aspect_ratio:
 67            width, height = aspect_ratio_to_size(aspect_ratio, resolution, max_width, max_height, max_size)
 68            params |= {"size": f"{width}x{height}"}
 69        if reference_images:
 70            if api_provider == "seedream":
 71                params["extra_body"] = params.get("extra_body", {}) | {"image": reference_images}
 72            else:
 73                params["image"] = [Path(x) for x in reference_images]
 74        params |= {"model": model_id, "prompt": prompt}
 75        logger.debug(f"openai.images.generate(**{prettify(params)})")
 76    except Exception as e:
 77        logger.error(f"OpenAI client setup error: {e}")
 78        await modify_progress(status_msg, text=f"{e}", force_update=True, **kwargs)
 79        return {"progress": status_msg} if isinstance(status_msg, Message) else {}
 80    resp = {}
 81    for api_key in strings_list(api_keys, shuffle=True):
 82        try:
 83            openai_client["api_key"] = api_key
 84            logger.trace(f"AsyncOpenAI(**{openai_client})")
 85            openai = AsyncOpenAI(**openai_client)
 86            if reference_images and api_provider == "openai":
 87                params.pop("moderation", None)
 88                func = openai.images.edit
 89            else:
 90                func = openai.images.generate
 91            async for chunk in await func(**params):
 92                resp = trim_none(chunk.model_dump())
 93                if b64_json := resp.pop("b64_json", None):
 94                    image_bytes = base64.b64decode(b64_json)
 95                    save_path = Path(DOWNLOAD_DIR) / f"{rand_string(10)}.{params.get('output_format', 'jpeg')}"
 96                    async with await anyio.open_file(save_path, "wb") as f:
 97                        await f.write(image_bytes)
 98                    rtype = resp.pop("type", None)
 99                    caption = f"{EMOJI_IMG_BOT}**{model_name}**\n" if rtype == "image_generation.completed" else f"⌛️**{model_name}** 中间结果...\n"
100                    status_msg = await status_msg.edit_media(InputMediaPhoto(str(save_path), caption=caption + prettify(resp)))
101                    save_path.unlink(missing_ok=True)
102                elif resp.pop("type", None) == "image_generation.completed":
103                    await status_msg.edit_caption(f"{EMOJI_IMG_BOT}**{model_name}**\n{prettify(resp)}")
104                logger.trace(resp)
105            if status_msg.photo:
106                return {"success": True}
107        except Exception as e:
108            logger.error(f"OpenAI Image Generation error: {e}\n\n{prettify(resp)}")
109            await modify_progress(status_msg, text=f"{e}\n\n{prettify(resp)}", force_update=True, **kwargs)
110    return {"progress": status_msg} if isinstance(status_msg, Message) else {}