main
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3import base64
4from pathlib import Path
5from typing import Literal
6
7import anyio
8from loguru import logger
9from openai import AsyncOpenAI, DefaultAsyncHttpxClient
10from pyrogram.client import Client
11from pyrogram.types import InputMediaPhoto, Message
12
13from ai.images.post import get_image_contexts
14from ai.images.utils import aspect_ratio_to_size, extract_aspect_ratio
15from ai.utils import EMOJI_IMG_BOT, literal_eval, prettify, trim_none
16from config import DOWNLOAD_DIR, PROXY
17from messages.progress import modify_progress
18from utils import rand_string, strings_list
19
20
def _mask_secret(secret: str, visible: int = 4) -> str:
    """Return a log-safe rendering of ``secret`` that reveals at most its last ``visible`` characters.

    Secrets no longer than ``2 * visible`` are masked entirely, so the visible
    tail never exposes a large fraction of a short key.
    """
    if len(secret) <= 2 * visible:
        return "*" * len(secret)
    return f"***{secret[-visible:]}"


async def openai_image_generation(
    client: Client,
    message: Message,
    *,
    model_id: str = "",
    model_name: str = "",
    api_keys: str = "",
    client_config: str | dict = "",
    generate_config: str | dict = "",
    proxy: str | None = PROXY.OPENAI,
    max_reference_img: int = 0,
    resolution: Literal["1K", "2K", "4K"] = "1K",
    max_width: int = int(1e16),
    max_height: int = int(1e16),
    max_size: int = int(1e32),
    api_provider: Literal["openai", "seedream"] = "openai",
    **kwargs,
) -> dict[str, Message | bool]:
    """Generate an image through an OpenAI-compatible Images API and show it in Telegram.

    A progress message is sent as a reply to ``message``, unless
    ``kwargs["progress"]`` already holds one. The prompt and up to
    ``max_reference_img`` reference images are collected with
    ``get_image_contexts``.

    The request is streamed. Each chunk that carries ``b64_json`` is written to
    ``DOWNLOAD_DIR``, uploaded by editing the progress message into a photo, and
    then deleted. API keys from ``api_keys`` are tried in shuffled order until
    one run leaves the progress message holding a photo.

    Args:
        client: Pyrogram client, passed to ``get_image_contexts``.
        message: Message that triggered the request. Its ``content`` is scanned
            for an aspect ratio.
        model_id: Value sent as ``model`` in the request.
        model_name: Display name used in status captions.
        api_keys: Key list, parsed by ``strings_list``.
        client_config: Extra ``AsyncOpenAI`` keyword arguments, parsed with
            ``literal_eval``.
        generate_config: Extra request parameters, parsed with ``literal_eval``.
            They override the ``stream``/``timeout`` defaults but are themselves
            overridden by ``model`` and ``prompt``.
        proxy: When set, requests go through a ``DefaultAsyncHttpxClient`` that
            uses this proxy with a 600 s timeout. That client is closed before
            this coroutine returns.
        max_reference_img: Passed to ``get_image_contexts``.
        resolution, max_width, max_height, max_size: Passed to
            ``aspect_ratio_to_size`` when an aspect ratio is found.
        api_provider: ``"openai"`` sends reference images as file paths to
            ``images.edit``. Any other value (``"seedream"``) sends them as
            base64 in ``extra_body["image"]`` to ``images.generate``.
        **kwargs: Forwarded to ``modify_progress``. ``progress`` may hold an
            existing status message to reuse.

    Returns:
        dict[str, Message | bool]:
            {"success": True}      # an image was delivered
            {"progress": Message}  # generation failed; the latest status message
            {}                     # generation failed and the status object is not a Message
    """
    status_msg = kwargs.get("progress") or await message.reply(f"{EMOJI_IMG_BOT}**{model_name}**:\n正在生成图像...", quote=True)
    # Created only when a proxy is configured. It is owned by this call and
    # closed in the outer `finally`, so it never outlives the request.
    proxy_http_client = None
    try:
        try:
            openai_client: dict = {}
            if extra_client_config := literal_eval(client_config):
                openai_client |= extra_client_config
            if proxy:
                proxy_http_client = DefaultAsyncHttpxClient(proxy=proxy, timeout=600)
                openai_client |= {"http_client": proxy_http_client}
            prompt, reference_images = await get_image_contexts(
                client,
                message,
                max_reference_img=max_reference_img,
                img_type="str" if api_provider == "openai" else "base64",
            )
            if not prompt:
                await modify_progress(status_msg, text=f"❌**{model_name}**:\n请提供提示词", force_update=True, **kwargs)
                return {"progress": status_msg} if isinstance(status_msg, Message) else {}
            params: dict = {"stream": True, "timeout": 600}
            if extra_generate_config := literal_eval(generate_config):
                params |= extra_generate_config
            aspect_ratio, _ = extract_aspect_ratio(message.content)
            if aspect_ratio:
                width, height = aspect_ratio_to_size(aspect_ratio, resolution, max_width, max_height, max_size)
                params |= {"size": f"{width}x{height}"}
            if reference_images:
                if api_provider == "seedream":
                    params["extra_body"] = params.get("extra_body", {}) | {"image": reference_images}
                else:
                    params["image"] = [Path(x) for x in reference_images]
            params |= {"model": model_id, "prompt": prompt}
            logger.debug(f"openai.images.generate(**{prettify(params)})")
        except Exception as e:
            logger.error(f"OpenAI client setup error: {e}")
            await modify_progress(status_msg, text=f"❌{e}", force_update=True, **kwargs)
            return {"progress": status_msg} if isinstance(status_msg, Message) else {}

        # The edit-vs-generate choice does not depend on the key, so decide it once.
        use_edit = bool(reference_images) and api_provider == "openai"
        if use_edit:
            # The edit endpoint is called without the `moderation` parameter.
            params.pop("moderation", None)

        resp = {}
        for api_key in strings_list(api_keys, shuffle=True):
            try:
                openai_client["api_key"] = api_key
                # Never write the raw API key to the logs.
                logger.trace(f"AsyncOpenAI(**{openai_client | {'api_key': _mask_secret(api_key)}})")
                openai_api = AsyncOpenAI(**openai_client)
                func = openai_api.images.edit if use_edit else openai_api.images.generate
                async for chunk in await func(**params):
                    resp = trim_none(chunk.model_dump())
                    if b64_json := resp.pop("b64_json", None):
                        image_bytes = base64.b64decode(b64_json)
                        save_path = Path(DOWNLOAD_DIR) / f"{rand_string(10)}.{params.get('output_format', 'jpeg')}"
                        try:
                            async with await anyio.open_file(save_path, "wb") as f:
                                await f.write(image_bytes)
                            rtype = resp.pop("type", None)
                            caption = f"{EMOJI_IMG_BOT}**{model_name}**\n" if rtype == "image_generation.completed" else f"⌛️**{model_name}** 中间结果...\n"
                            status_msg = await status_msg.edit_media(InputMediaPhoto(str(save_path), caption=caption + prettify(resp)))
                        finally:
                            # Remove the temporary image even if the write or upload fails.
                            save_path.unlink(missing_ok=True)
                    elif resp.pop("type", None) == "image_generation.completed":
                        await status_msg.edit_caption(f"{EMOJI_IMG_BOT}**{model_name}**\n{prettify(resp)}")
                    logger.trace(resp)
                # Success means some chunk turned the status message into a photo.
                if status_msg.photo:
                    return {"success": True}
            except Exception as e:
                logger.error(f"OpenAI Image Generation error: {e}\n\n{prettify(resp)}")
                await modify_progress(status_msg, text=f"❌{e}\n\n{prettify(resp)}", force_update=True, **kwargs)
        return {"progress": status_msg} if isinstance(status_msg, Message) else {}
    finally:
        if proxy_http_client is not None:
            await proxy_http_client.aclose()