main
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3import asyncio
4import base64
5from pathlib import Path
6from typing import Literal
7
8import anyio
9from glom import glom
10from loguru import logger
11from pyrogram.client import Client
12from pyrogram.types import Message
13
14from ai.images.utils import aspect_ratio_to_size, extract_aspect_ratio
15from ai.texts.contexts import base64_media
16from ai.utils import EMOJI_IMG_BOT, clean_cmd_prefix, prettify, replace_placeholder
17from config import DOWNLOAD_DIR, PROXY
18from messages.modify import message_modify
19from messages.progress import modify_progress
20from messages.sender import send2tg
21from messages.utils import delete_message
22from networking import download_file, hx_req
23from utils import rand_string
24
25
async def http_post_image_generation(
    client: Client,
    message: Message,
    *,
    base_url: str = "",
    model_name: str = "",
    api_paths: dict | None = None,
    headers: dict | None = None,
    body: dict | None = None,
    extra_params: dict | None = None,
    proxy: str | None = PROXY.AI_POST,
    max_reference_img: int = 0,
    resolution: Literal["1K", "2K", "4K"] = "1K",
    max_width: int = int(1e16),
    max_height: int = int(1e16),
    max_size: int = int(1e32),
    **kwargs,
) -> dict[str, Message | bool]:
    """Generate an image through a generic HTTP POST API and send it to Telegram.

    The prompt (and optional reference images) are extracted from *message*
    and its reply chain; the request is built from ``base_url``/``api_paths``/
    ``body`` and dispatched via ``hx_req``.

    Args:
        client: Pyrogram client used for context extraction and sending.
        message: Triggering Telegram message.
        base_url: API host, combined with ``api_paths["img_gen"]`` or
            ``api_paths["img_edit"]`` to form the endpoint URL.
        model_name: Display name shown in progress texts and captions.
        api_paths: Endpoint suffixes ("img_gen", "img_edit", "task_check").
        headers / body / extra_params / proxy: Forwarded to ``hx_req``.
        max_reference_img: Max photos pulled from the reply chain.
        resolution / max_width / max_height / max_size: Constraints applied
            when the prompt contains an aspect-ratio token.

    Return:
        dict[str, Message | bool]:
            {"success": True}       # Generated image successfully
            {"progress": Message}   # Failed to generate image
    """
    status_msg = kwargs.get("progress") or await message.reply(f"{EMOJI_IMG_BOT}**{model_name}**:\n正在生成图像...", quote=True)

    def fail() -> dict[str, Message | bool]:
        # Hand the progress message back to the caller for reuse, when we have one.
        return {"progress": status_msg} if isinstance(status_msg, Message) else {}

    try:
        prompt, reference_images = await get_image_contexts(client, message, max_reference_img=max_reference_img)
        if not prompt:
            await modify_progress(status_msg, text=f"❌**{model_name}**:\n请提供提示词", force_update=True, **kwargs)
            return fail()
        params = {}
        api_paths = api_paths or {}
        if headers:
            params |= {"headers": headers}
        if proxy:
            params |= {"proxy": proxy}
        # Use the edit endpoint when reference images exist, otherwise plain generation.
        # NOTE(review): reference_images only selects the endpoint here; they are never
        # attached to the request body — presumably the caller's `body` template covers
        # that. Confirm against the callers.
        url = base_url + api_paths.get("img_gen", "") if not reference_images else base_url + api_paths.get("img_edit", "")
        params |= {"url": url, "method": "POST"}
        if body:
            aspect_ratio, _ = extract_aspect_ratio(message.content)
            if aspect_ratio:
                width, height = aspect_ratio_to_size(aspect_ratio, resolution, max_width, max_height, max_size)
                body |= {"size": f"{width}x{height}"}
            params |= {"json_data": replace_placeholder(body, pairs={"%PROMPT%": prompt})}
        if extra_params:
            params |= extra_params
        logger.debug(f"hx_req(**{prettify(params)})")
        resp = await hx_req(**params)
        if error := resp.get("hx_error"):
            await modify_progress(status_msg, text=f"❌**{model_name}**:\n{error}", force_update=True, **kwargs)
            return fail()

        image_urls: list[str] = []
        metadata = ""
        if task_id := resp.get("task_id"):  # ModelScope: async task, poll until it finishes
            image_urls, metadata = await waiting_modelscope_task(task_id, params | {"base_url": base_url, "api_paths": api_paths})
        if glom(resp, "output.choices.0.message.content.0.image", default=""):  # DashScope Multimodal Generation
            image_urls.extend(glom(resp, "output.choices.0.message.content.*.image", default=[]))
            metadata = extract_metadata(resp)
        if image_urls:
            images = await download_generated_images(image_urls, proxy=proxy)
            caption = f"{EMOJI_IMG_BOT}**{model_name}**\n{metadata}\n"
            for idx, img in enumerate([x for x in images if x.get("url")]):
                caption += f"[P{idx + 1}原图]({img['url']})"
            await modify_progress(status_msg, text=caption, force_update=True, **kwargs)
            await send2tg(client, message, texts=caption, media=[{"photo": img["path"]} for img in images], **kwargs)
            await delete_message(status_msg)
            return {"success": True}
        # Fix: previously execution fell off the end here (returning None) when the
        # API answered without an error but produced no images. Report it explicitly
        # so the declared return contract always holds.
        await modify_progress(status_msg, text=f"❌**{model_name}**:\n未返回图像", force_update=True, **kwargs)
        return fail()
    except Exception as e:
        logger.error(f"HTTP Post Image Generation error: {e}")
        await modify_progress(status_msg, text=f"❌{e}", force_update=True, **kwargs)
        return fail()
99
100
def extract_metadata(response: dict) -> str:
    """Extract display-worthy metadata from an image-generation API response.

    The returned string is appended to the Telegram caption.

    Args:
        response: Raw JSON response dict from the provider.

    Returns:
        Prettified metadata string, or "" when nothing useful is present.
    """
    if glom(response, "input.prompt", default=""):  # ModelScope task-check response
        # Fix: copy before popping — the original popped from the dict held inside
        # `response`, silently mutating the caller's data.
        metadata = dict(response.get("input", {}))
        metadata.pop("prompt", None)  # the prompt is already shown elsewhere
        return prettify(metadata)
    if glom(response, "output.choices.0.message.content.0.image", default=""):  # DashScope Multimodal Generation
        return prettify(response.get("usage", ""))
    return ""
113
114
async def get_image_contexts(
    client: Client,
    message: Message,
    *,
    max_reference_img: int = 0,
    img_type: Literal["base64", "str"] = "base64",
) -> tuple[str, list[str]]:
    """Get image generation contexts: the prompt plus reference images.

    Walks the reply chain of *message* (collected newest-first, then reversed
    to old-to-new) and gathers up to ``max_reference_img`` photos, either as
    data-URI base64 strings or as local file paths.

    Args:
        client: Pyrogram client used for media groups and downloads.
        message: Triggering message; its reply chain is traversed.
        max_reference_img: 0 disables reference-image collection entirely.
        img_type: "base64" -> "data:image/<ext>;base64,..." strings;
            "str" -> POSIX paths of downloaded media files.

    Returns:
        tuple: prompt, list_of_images
    """

    def clean(text: str) -> str:
        # Strip the bot-command prefix and any aspect-ratio token from the text,
        # leaving only the user's prompt.
        if not text:
            return ""
        text = clean_cmd_prefix(str(text))
        _, text = extract_aspect_ratio(text)
        return text.strip()

    # Fast path: no reference images wanted, so only the prompt matters.
    if not max_reference_img:
        return clean(message.content), []
    messages = [message]
    while message.reply_to_message:
        message = message.reply_to_message
        if not message.service:  # ignore service messages
            messages.append(message_modify(message))
    messages.reverse()  # old to new
    images = []
    prompt = ""
    for m in messages:
        # Expand albums so every photo in a media group is considered.
        group_messages = await client.get_media_group(m.chat.id, m.id) if m.media_group_id else [m]
        for msg in group_messages:
            # NOTE(review): prompt is overwritten on every message, so the newest
            # message's (possibly empty) text wins — confirm this is intended.
            prompt = clean(msg.content)
            if not msg.photo or len(images) >= max_reference_img:
                continue
            if img_type == "base64":
                res = await base64_media(client, msg)
                images.append(f"data:image/{res['ext']};base64,{res['base64']}")
            else:
                path = await client.download_media(msg)
                if Path(str(path)).is_file():
                    images.append(Path(str(path)).as_posix())
    return prompt, images
159
160
async def download_generated_images(image_urls: list[str], proxy: str | None) -> list[dict]:
    """Fetch generated images onto local disk.

    Each entry of *image_urls* is either an HTTP(S) URL or a raw base64
    payload (for providers that return inline image data).

    Return:
        [
            {
                "path": "/path/to/image.png"
                "url": "https://...",
            }
        ]
        Entries decoded from base64 carry only the "path" key.
    """
    downloaded: list[dict] = []
    for item in image_urls:
        if not item.startswith("http"):
            # Inline base64 payload: decode and persist under a random name.
            target = Path(DOWNLOAD_DIR) / f"{rand_string(10)}.png"
            raw = base64.b64decode(item)
            async with await anyio.open_file(target, "wb") as fh:
                await fh.write(raw)
            downloaded.append({"path": target.as_posix()})
            continue
        local_path = await download_file(item, impersonate=None, proxy=proxy)
        if Path(local_path).is_file():
            downloaded.append({"path": local_path, "url": item})
    return downloaded
185
186
async def waiting_modelscope_task(task_id: str, params: dict, *, max_wait: float = 1800.0) -> tuple[list[str], str]:
    """Poll a ModelScope async image-generation task until it reaches SUCCEED.

    Task Submitted Response:
        {
            "task_status": "SUCCEED",
            "task_id": "5054288",
            "request_id": "68b91aef-e00e-40dd-a4e8-8611e4826b7a"
        }

    Task Check Response:
        {
            "input": {
                "guidanceScale": 7.5,
                "height": 1280,
                "negativePrompt": "",
                "numInferenceSteps": 9,
                "outputs": {},
                "prompt": "充满活力的特写编辑肖像, 模特眼神犀利, 头戴雕塑感帽子, 色彩拼接丰富, 眼部焦点锐利, 景深较浅, 具有Vogue杂志封面的美学风格, 采用中画幅拍摄, 工作室灯光效果强烈.",
                "sampler": "Euler a",
                "seed": 391608538,
                "timeTaken": 6764.005661010742,
                "weight": 0
            },
            "output_images": [
                "https://muse-ai.oss-cn-hangzhou.aliyuncs.com/img/8cac52d9b76a41238f02733ef3709fba.png"
            ],
            "request_id": "b96147a2-d9ea-4640-b726-4d097235c5a1",
            "task_id": "",
            "task_status": "SUCCEED",
            "time_taken": 6764.005661010742
        }

    Args:
        task_id: Task identifier from the submission response.
        params: Request params; must contain "base_url" and may contain
            "headers", "proxy" and "api_paths.task_check".
        max_wait: Give up after this many seconds. (Fix: the original loop
            polled forever when a task never left the pending state.)

    Returns:
        tuple: list of images, metadata; ([], "") on failure or timeout.
    """
    headers = {k.lower(): v for k, v in params.get("headers", {}).items()} | {"x-modelscope-task-type": "image_generation"}
    base_url = params["base_url"]
    if base_url.startswith("https://gateway.helicone.ai"):
        # The Helicone gateway cannot be polled directly: rewrite the URL to the
        # real target and drop the gateway-only headers.
        helicone_target_url = headers.get("helicone-target-url", "").rstrip("/")
        base_url = base_url.replace("https://gateway.helicone.ai", helicone_target_url)
        headers.pop("helicone-target-url", None)
        headers.pop("helicone-auth", None)
    task_url = base_url + glom(params, "api_paths.task_check", default="")
    url = replace_placeholder(task_url, {"%TASK_ID%": task_id})
    deadline = asyncio.get_running_loop().time() + max_wait
    while True:  # single poll site (the original duplicated the hx_req call)
        resp = await hx_req(url, headers=headers, proxy=params.get("proxy"), check_keys=["task_status"])
        if "hx_error" in resp or resp["task_status"].upper() in {"FAILED", "CANCELLED", "UNKNOWN"}:
            logger.error(f"Image Generation Task {task_id} error: {resp}")
            return [], ""
        if resp["task_status"] == "SUCCEED":
            return glom(resp, "output_images", default=[]), extract_metadata(resp)
        if asyncio.get_running_loop().time() >= deadline:
            # Fix: bail out instead of hanging the caller forever.
            logger.error(f"Image Generation Task {task_id} timed out after {max_wait}s")
            return [], ""
        await asyncio.sleep(5)