main
  1#!/usr/bin/env python
  2# -*- coding: utf-8 -*-
  3import base64
  4from pathlib import Path
  5from typing import Literal
  6
  7import anyio
  8from glom import glom
  9from loguru import logger
 10from openai import AsyncOpenAI, DefaultAsyncHttpxClient
 11from pyrogram.client import Client
 12from pyrogram.types import Message
 13
 14from ai.images.post import get_image_contexts
 15from ai.images.utils import aspect_ratio_to_size, extract_aspect_ratio
 16from ai.utils import EMOJI_IMG_BOT, literal_eval, prettify, trim_none
 17from config import DOWNLOAD_DIR, PROXY
 18from messages.progress import modify_progress
 19from messages.sender import send2tg
 20from messages.utils import delete_message
 21from networking import download_file
 22from utils import rand_string, strings_list
 23
 24
 25async def openai_image_generation(
 26    client: Client,
 27    message: Message,
 28    *,
 29    model_id: str = "",
 30    model_name: str = "",
 31    api_keys: str = "",
 32    client_config: str | dict = "",
 33    generate_config: str | dict = "",
 34    proxy: str | None = PROXY.OPENAI,
 35    max_reference_img: int = 0,
 36    resolution: Literal["1K", "2K", "4K"] = "1K",
 37    max_width: int = int(1e16),
 38    max_height: int = int(1e16),
 39    max_size: int = int(1e32),
 40    api_provider: Literal["openai", "seedream"] = "openai",
 41    **kwargs,
 42) -> dict[str, Message | bool]:
 43    """Get OpenAI Image Generation.
 44
 45    Return:
 46    dict[str, Message | bool]:
 47        {"success": True}  # Generated image successfully
 48        {"progress": Message}  # Failed to generate image
 49    """
 50    status_msg = kwargs.get("progress") or await message.reply(f"{EMOJI_IMG_BOT}**{model_name}**:\n正在生成图像...", quote=True)
 51    try:
 52        openai_client = {}
 53        if literal_eval(client_config):
 54            openai_client |= literal_eval(client_config)
 55        if proxy:
 56            openai_client |= {"http_client": DefaultAsyncHttpxClient(proxy=proxy)}
 57        prompt, reference_images = await get_image_contexts(
 58            client,
 59            message,
 60            max_reference_img=max_reference_img,
 61            img_type="str" if api_provider == "openai" else "base64",
 62        )
 63        if not prompt:
 64            await modify_progress(status_msg, text=f"❌**{model_name}**:\n请提供提示词", force_update=True, **kwargs)
 65            return {"progress": status_msg} if isinstance(status_msg, Message) else {}
 66        params = {}
 67        if literal_eval(generate_config):
 68            params |= literal_eval(generate_config)
 69        aspect_ratio, _ = extract_aspect_ratio(message.content)
 70        if aspect_ratio:
 71            width, height = aspect_ratio_to_size(aspect_ratio, resolution, max_width, max_height, max_size)
 72            params |= {"size": f"{width}x{height}"}
 73        if reference_images:
 74            if api_provider == "seedream":
 75                params["extra_body"] = params.get("extra_body", {}) | {"image": reference_images}
 76            else:
 77                params["image"] = [Path(x) for x in reference_images]
 78        params |= {"model": model_id, "prompt": prompt}
 79        logger.debug(f"openai.images.generate(**{prettify(params)})")
 80    except Exception as e:
 81        logger.error(f"OpenAI client setup error: {e}")
 82        await modify_progress(status_msg, text=f"{e}", force_update=True, **kwargs)
 83        return {"progress": status_msg} if isinstance(status_msg, Message) else {}
 84    resp = {}
 85    for api_key in strings_list(api_keys, shuffle=True):
 86        try:
 87            openai_client["api_key"] = api_key
 88            logger.trace(f"AsyncOpenAI(**{openai_client})")
 89            openai = AsyncOpenAI(**openai_client)
 90            if reference_images and api_provider == "openai":
 91                params.pop("moderation", None)
 92                resp = await openai.images.edit(**params)
 93            else:
 94                resp = await openai.images.generate(**params)
 95            resp = trim_none(resp.model_dump())
 96            if images := await download_generated_images(resp, proxy=proxy):
 97                resp.pop("data", None)
 98                caption = f"{EMOJI_IMG_BOT}**{model_name}**\n{prettify(resp)}\n"
 99                for idx, img in enumerate([x for x in images if x.get("url")]):
100                    caption += f"[P{idx + 1}原图]({img['url']})"
101                await modify_progress(status_msg, text=prettify(resp), force_update=True, **kwargs)
102                await send2tg(client, message, texts=caption, media=[{"photo": img["path"]} for img in images], **kwargs)
103                await delete_message(status_msg)
104                return {"success": True}
105        except Exception as e:
106            logger.error(f"OpenAI Image Generation error: {e}\n\n{prettify(resp)}")
107            await modify_progress(status_msg, text=f"{e}\n\n{prettify(resp)}", force_update=True, **kwargs)
108    return {"progress": status_msg} if isinstance(status_msg, Message) else {}
109
110
111async def download_generated_images(response: dict, proxy: str | None = None) -> list[dict]:
112    """Download generated images.
113
114    Response: {
115        "model": "doubao-seedream-4-5-251128",
116        "created": 1757321139,
117        "data": [
118            {
119                "url": "https://...",
120                "size": "3104x1312"
121            },
122            {
123                "b64_json": "/9j/4AAQSkZJRgABA...",
124                "size": "3104x1312"
125            }
126        ],
127        "usage": {
128            "generated_images": 2,
129            "output_tokens": xxx,
130            "total_tokens": xxx
131        }
132    }
133
134    Return:
135    [
136        {
137            "path": "/path/to/image.png"
138            "url": "https://...",
139        }
140    ]
141    """
142    results = []
143    data = glom(response, "data", default=[]) or []
144    for item in data:
145        if url := item.get("url"):
146            img_path = await download_file(url, proxy=proxy)
147            if Path(img_path).is_file():
148                results.append({"path": img_path, "url": url})
149        if b64_json := item.get("b64_json"):
150            image_bytes = base64.b64decode(b64_json)
151            save_path = Path(DOWNLOAD_DIR) / f"{rand_string(10)}.png"
152            async with await anyio.open_file(save_path, "wb") as f:
153                await f.write(image_bytes)
154            results.append({"path": save_path.as_posix()})
155    return results