main
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3import base64
4from pathlib import Path
5from typing import Literal
6
7import anyio
8from glom import glom
9from loguru import logger
10from openai import AsyncOpenAI, DefaultAsyncHttpxClient
11from pyrogram.client import Client
12from pyrogram.types import Message
13
14from ai.images.post import get_image_contexts
15from ai.images.utils import aspect_ratio_to_size, extract_aspect_ratio
16from ai.utils import EMOJI_IMG_BOT, literal_eval, prettify, trim_none
17from config import DOWNLOAD_DIR, PROXY
18from messages.progress import modify_progress
19from messages.sender import send2tg
20from messages.utils import delete_message
21from networking import download_file
22from utils import rand_string, strings_list
23
24
def _build_image_params(
    message: Message,
    prompt: str,
    reference_images: list,
    *,
    model_id: str,
    generate_config: str | dict,
    resolution: str,
    max_width: int,
    max_height: int,
    max_size: int,
    api_provider: str,
) -> dict:
    """Assemble the keyword arguments for ``images.generate`` / ``images.edit``.

    ``generate_config`` (dict or literal string) is the base. ``size`` is added
    when the message text specifies an aspect ratio. Reference images are
    attached in the provider-specific way. ``model`` and ``prompt`` are always
    set last, so they override anything in ``generate_config``.
    """
    params = {}
    if extra_params := literal_eval(generate_config):
        params |= extra_params
    aspect_ratio, _ = extract_aspect_ratio(message.content)
    if aspect_ratio:
        width, height = aspect_ratio_to_size(aspect_ratio, resolution, max_width, max_height, max_size)
        params["size"] = f"{width}x{height}"
    if reference_images:
        if api_provider == "seedream":
            # Seedream receives (base64) reference images through the request body.
            params["extra_body"] = params.get("extra_body", {}) | {"image": reference_images}
        else:
            # images.edit uploads files; each reference is passed as a Path.
            params["image"] = [Path(x) for x in reference_images]
    params |= {"model": model_id, "prompt": prompt}
    return params


async def openai_image_generation(
    client: Client,
    message: Message,
    *,
    model_id: str = "",
    model_name: str = "",
    api_keys: str = "",
    client_config: str | dict = "",
    generate_config: str | dict = "",
    proxy: str | None = PROXY.OPENAI,
    max_reference_img: int = 0,
    resolution: Literal["1K", "2K", "4K"] = "1K",
    max_width: int = int(1e16),
    max_height: int = int(1e16),
    max_size: int = int(1e32),
    api_provider: Literal["openai", "seedream"] = "openai",
    **kwargs,
) -> dict[str, Message | bool]:
    """Generate images with an OpenAI-compatible Images API and send them to Telegram.

    The prompt and optional reference images are collected from ``message`` via
    ``get_image_contexts``. Routing of reference images:

    - For ``api_provider="openai"``, they are sent through ``images.edit``.
    - Otherwise, they go in ``extra_body["image"]`` of ``images.generate``.

    API keys are tried in random order. A key whose response contains no
    downloadable image is skipped in favour of the next key. Any exception
    stops immediately and is reported on the status message.

    Args:
        client: Pyrogram client used to send the result.
        message: Triggering message (prompt, aspect ratio, reference images).
        model_id: Model identifier sent as ``model``.
        model_name: Display name used in status messages and captions.
        api_keys: One or more API keys, split by ``strings_list``.
        client_config: Extra ``AsyncOpenAI`` kwargs (dict or literal string).
        generate_config: Extra request params (dict or literal string).
        proxy: Proxy for the API client and for downloading image URLs.
        max_reference_img: Maximum number of reference images to collect.
        resolution, max_width, max_height, max_size: Bounds passed to
            ``aspect_ratio_to_size`` when the message specifies an aspect ratio.
        api_provider: ``"openai"`` or ``"seedream"``.
        **kwargs: Forwarded to ``modify_progress`` / ``send2tg``. ``progress``
            may hold an existing status message to reuse.

    Returns:
        ``{"success": True}`` when images were generated and sent. Otherwise
        ``{"progress": status_msg}``, or ``{}`` if the status message is not
        a ``Message``.
    """
    status_msg = kwargs.get("progress") or await message.reply(f"{EMOJI_IMG_BOT}**{model_name}**:\n正在生成图像...", quote=True)
    failure = {"progress": status_msg} if isinstance(status_msg, Message) else {}
    # Only the http client created here for `proxy` is owned (and closed) by this call.
    http_client = None
    try:
        try:
            openai_client = {}
            if extra_client_kwargs := literal_eval(client_config):
                openai_client |= extra_client_kwargs
            if proxy:
                http_client = DefaultAsyncHttpxClient(proxy=proxy)
                openai_client["http_client"] = http_client
            prompt, reference_images = await get_image_contexts(
                client,
                message,
                max_reference_img=max_reference_img,
                img_type="str" if api_provider == "openai" else "base64",
            )
            if not prompt:
                await modify_progress(status_msg, text=f"❌**{model_name}**:\n请提供提示词", force_update=True, **kwargs)
                return failure
            params = _build_image_params(
                message,
                prompt,
                reference_images,
                model_id=model_id,
                generate_config=generate_config,
                resolution=resolution,
                max_width=max_width,
                max_height=max_height,
                max_size=max_size,
                api_provider=api_provider,
            )
            logger.debug(f"openai.images.generate(**{prettify(params)})")
        except Exception as e:
            logger.error(f"OpenAI client setup error: {e}")
            await modify_progress(status_msg, text=f"❌{e}", force_update=True, **kwargs)
            return failure

        resp = {}
        for api_key in strings_list(api_keys, shuffle=True):
            resp = {}  # reset so an error never reports a previous key's response
            try:
                client_kwargs = {**openai_client, "api_key": api_key}
                # Never write the secret key to the log.
                redacted = {k: v for k, v in client_kwargs.items() if k != "api_key"}
                logger.trace(f"AsyncOpenAI(**{redacted})")
                openai = AsyncOpenAI(**client_kwargs)
                if reference_images and api_provider == "openai":
                    # `moderation` is presumably a generate-only option, so drop it for edit.
                    params.pop("moderation", None)
                    resp = await openai.images.edit(**params)
                else:
                    resp = await openai.images.generate(**params)
                resp = trim_none(resp.model_dump())
                if images := await download_generated_images(resp, proxy=proxy):
                    resp.pop("data", None)
                    links = "".join(
                        f"[P{idx}原图]({img['url']})"
                        for idx, img in enumerate((x for x in images if x.get("url")), start=1)
                    )
                    caption = f"{EMOJI_IMG_BOT}**{model_name}**\n{prettify(resp)}\n" + links
                    await modify_progress(status_msg, text=prettify(resp), force_update=True, **kwargs)
                    await send2tg(client, message, texts=caption, media=[{"photo": img["path"]} for img in images], **kwargs)
                    await delete_message(status_msg)
                    return {"success": True}
                # No downloadable image in this response: try the next key.
            except Exception as e:
                logger.error(f"OpenAI Image Generation error: {e}\n\n{prettify(resp)}")
                await modify_progress(status_msg, text=f"❌{e}\n\n{prettify(resp)}", force_update=True, **kwargs)
                return failure

        # Reached when `api_keys` is empty or no key produced a downloadable image.
        # Previously the function fell off the end here and returned None.
        text = f"❌**{model_name}**:\n未能生成图像"
        if resp:
            text += f"\n\n{prettify(resp)}"
        logger.error(f"OpenAI Image Generation error: no image returned\n\n{prettify(resp)}")
        await modify_progress(status_msg, text=text, force_update=True, **kwargs)
        return failure
    finally:
        if http_client is not None:
            await http_client.aclose()
109
110
def _image_suffix(image_bytes: bytes) -> str:
    """Return a file suffix matching the magic bytes of ``image_bytes``.

    Recognises JPEG and WebP; anything else falls back to ``.png``, which was
    the previous hard-coded suffix.
    """
    if image_bytes.startswith(b"\xff\xd8\xff"):
        return ".jpg"
    if image_bytes.startswith(b"RIFF") and image_bytes[8:12] == b"WEBP":
        return ".webp"
    return ".png"


async def download_generated_images(response: dict, proxy: str | None = None) -> list[dict]:
    """Save every image in an Images API response to local files.

    Each item of ``response["data"]`` is handled as follows:

    - An item with a ``url`` is downloaded via ``download_file``.
    - An item with ``b64_json`` is decoded and written under ``DOWNLOAD_DIR``.
      The file suffix is taken from the image's magic bytes, because providers
      may return JPEG/WebP rather than PNG.

    Example response::

        {
            "model": "doubao-seedream-4-5-251128",
            "created": 1757321139,
            "data": [
                {"url": "https://...", "size": "3104x1312"},
                {"b64_json": "/9j/4AAQSkZJRgABA...", "size": "3104x1312"}
            ],
            "usage": {"generated_images": 2, "output_tokens": ..., "total_tokens": ...}
        }

    Args:
        response: The response dict (e.g. a ``model_dump()`` of the API result).
        proxy: Proxy passed to ``download_file`` for URL items.

    Returns:
        A list of ``{"path": ..., "url": ...}`` dicts. ``url`` is present only
        for downloaded URL items.
    """
    results = []
    data = glom(response, "data", default=[]) or []
    for item in data:
        if url := item.get("url"):
            img_path = await download_file(url, proxy=proxy)
            # NOTE(review): guard in case download_file returns None/"" on failure.
            if img_path and Path(img_path).is_file():
                results.append({"path": img_path, "url": url})
        if b64_json := item.get("b64_json"):
            image_bytes = base64.b64decode(b64_json)
            save_path = Path(DOWNLOAD_DIR) / f"{rand_string(10)}{_image_suffix(image_bytes)}"
            save_path.parent.mkdir(parents=True, exist_ok=True)
            async with await anyio.open_file(save_path, "wb") as f:
                await f.write(image_bytes)
            results.append({"path": save_path.as_posix()})
    return results