main
  1#!/usr/bin/env python
  2# -*- coding: utf-8 -*-
  3import asyncio
  4import base64
  5from pathlib import Path
  6from typing import Literal
  7
  8import anyio
  9from glom import glom
 10from loguru import logger
 11from pyrogram.client import Client
 12from pyrogram.types import Message
 13
 14from ai.images.utils import aspect_ratio_to_size, extract_aspect_ratio
 15from ai.texts.contexts import base64_media
 16from ai.utils import EMOJI_IMG_BOT, clean_cmd_prefix, prettify, replace_placeholder
 17from config import DOWNLOAD_DIR, PROXY
 18from messages.modify import message_modify
 19from messages.progress import modify_progress
 20from messages.sender import send2tg
 21from messages.utils import delete_message
 22from networking import download_file, hx_req
 23from utils import rand_string
 24
 25
 26async def http_post_image_generation(
 27    client: Client,
 28    message: Message,
 29    *,
 30    base_url: str = "",
 31    model_name: str = "",
 32    api_paths: dict | None = None,
 33    headers: dict | None = None,
 34    body: dict | None = None,
 35    extra_params: dict | None = None,
 36    proxy: str | None = PROXY.AI_POST,
 37    max_reference_img: int = 0,
 38    resolution: Literal["1K", "2K", "4K"] = "1K",
 39    max_width: int = int(1e16),
 40    max_height: int = int(1e16),
 41    max_size: int = int(1e32),
 42    **kwargs,
 43) -> dict[str, Message | bool]:
 44    """Get HTTP Post Image Generation.
 45
 46    Return:
 47        dict[str, Message | bool]:
 48            {"success": True}  # Generated image successfully
 49            {"progress": Message}  # Failed to generate image
 50    """
 51    status_msg = kwargs.get("progress") or await message.reply(f"{EMOJI_IMG_BOT}**{model_name}**:\n正在生成图像...", quote=True)
 52    try:
 53        prompt, reference_images = await get_image_contexts(client, message, max_reference_img=max_reference_img)
 54        if not prompt:
 55            await modify_progress(status_msg, text=f"❌**{model_name}**:\n请提供提示词", force_update=True, **kwargs)
 56            return {"progress": status_msg} if isinstance(status_msg, Message) else {}
 57        params = {}
 58        api_paths = api_paths or {}
 59        if headers:
 60            params |= {"headers": headers}
 61        if proxy:
 62            params |= {"proxy": proxy}
 63        url = base_url + api_paths.get("img_gen", "") if not reference_images else base_url + api_paths.get("img_edit", "")
 64        params |= {"url": url, "method": "POST"}
 65        if body:
 66            aspect_ratio, _ = extract_aspect_ratio(message.content)
 67            if aspect_ratio:
 68                width, height = aspect_ratio_to_size(aspect_ratio, resolution, max_width, max_height, max_size)
 69                body |= {"size": f"{width}x{height}"}
 70            params |= {"json_data": replace_placeholder(body, pairs={"%PROMPT%": prompt})}
 71        if extra_params:
 72            params |= extra_params
 73        logger.debug(f"hx_req(**{prettify(params)})")
 74        resp = await hx_req(**params)
 75        if error := resp.get("hx_error"):
 76            await modify_progress(status_msg, text=f"❌**{model_name}**:\n{error}", force_update=True, **kwargs)
 77            return {"progress": status_msg} if isinstance(status_msg, Message) else {}
 78
 79        image_urls: list[str] = []
 80        metadata = ""
 81        if task_id := resp.get("task_id"):  # ModelScope
 82            image_urls, metadata = await waiting_modelscope_task(task_id, params | {"base_url": base_url, "api_paths": api_paths})
 83        if glom(resp, "output.choices.0.message.content.0.image", default=""):  # DashScope Multimodal Generation
 84            image_urls.extend(glom(resp, "output.choices.0.message.content.*.image", default=[]))
 85            metadata = extract_metadata(resp)
 86        if image_urls:
 87            images = await download_generated_images(image_urls, proxy=proxy)
 88            caption = f"{EMOJI_IMG_BOT}**{model_name}**\n{metadata}\n"
 89            for idx, img in enumerate([x for x in images if x.get("url")]):
 90                caption += f"[P{idx + 1}原图]({img['url']})"
 91            await modify_progress(status_msg, text=caption, force_update=True, **kwargs)
 92            await send2tg(client, message, texts=caption, media=[{"photo": img["path"]} for img in images], **kwargs)
 93            await delete_message(status_msg)
 94            return {"success": True}
 95    except Exception as e:
 96        logger.error(f"HTTP Post Image Generation error: {e}")
 97        await modify_progress(status_msg, text=f"{e}", force_update=True, **kwargs)
 98    return {"progress": status_msg} if isinstance(status_msg, Message) else {}
 99
100
def extract_metadata(response: dict) -> str:
    """Build a short, human-readable metadata string for the Telegram caption.

    * ModelScope task results (``input.prompt`` present): the ``input`` mapping
      without its ``prompt`` entry.
    * DashScope multimodal-generation results: the ``usage`` section.
    * Anything else: an empty string.

    ``response`` is left unmodified.
    """
    if glom(response, "input.prompt", default=""):  # ModelScope
        # Filter into a new dict; popping from response["input"] would mutate the caller's data.
        metadata = {key: value for key, value in response.get("input", {}).items() if key != "prompt"}
        return prettify(metadata)
    if glom(response, "output.choices.0.message.content.0.image", default=""):  # DashScope Multimodal Generation
        return prettify(response.get("usage", ""))
    return ""
113
114
async def get_image_contexts(
    client: Client,
    message: Message,
    *,
    max_reference_img: int = 0,
    img_type: Literal["base64", "str"] = "base64",
) -> tuple[str, list[str]]:
    """Collect the prompt and reference images for an image-generation request.

    Args:
        client: Pyrogram client used to fetch media groups and download media.
        message: The triggering message; its reply chain is walked for reference photos.
        max_reference_img: Maximum number of reference photos to collect. When 0, only
            the cleaned content of ``message`` is returned.
        img_type: ``"base64"`` returns ``data:image/<ext>;base64,<data>`` strings;
            ``"str"`` downloads each photo and returns its local POSIX path.

    Returns:
        tuple: ``(prompt, images)``. The prompt is the cleaned text (command prefix and
        aspect-ratio token removed) of the newest message in the chain; for a media
        group, the first item with a non-empty caption is used. Photos are collected
        oldest first, up to ``max_reference_img``.
    """

    def clean(text: str) -> str:
        # Strip the command prefix and the aspect-ratio token, leaving the bare prompt.
        if not text:
            return ""
        text = clean_cmd_prefix(str(text))
        _, text = extract_aspect_ratio(text)
        return text.strip()

    if not max_reference_img:
        return clean(message.content), []

    chain = [message]
    current = message
    while current.reply_to_message:
        current = current.reply_to_message
        if not current.service:  # ignore service messages
            chain.append(message_modify(current))
    chain.reverse()  # old to new, so the newest message's text ends up as the prompt

    images: list[str] = []
    prompt = ""
    for m in chain:
        group = await client.get_media_group(m.chat.id, m.id) if m.media_group_id else [m]
        # In an album usually only one item carries the caption; taking the last item's
        # text would often yield an empty prompt, so use the first non-empty one.
        prompt = next((text for text in (clean(item.content) for item in group) if text), "")
        for msg in group:
            if not msg.photo or len(images) >= max_reference_img:
                continue
            if img_type == "base64":
                res = await base64_media(client, msg)
                images.append(f"data:image/{res['ext']};base64,{res['base64']}")
            else:
                path = await client.download_media(msg)
                # download_media may return None; str() keeps Path() from raising.
                if Path(str(path)).is_file():
                    images.append(Path(str(path)).as_posix())
    return prompt, images
159
160
async def download_generated_images(image_urls: list[str], proxy: str | None) -> list[dict]:
    """Download or decode generated images into local files.

    Each entry of ``image_urls`` is either an ``http(s)`` URL, fetched with
    ``download_file``, or base64 image data (raw, or a ``data:`` URI), decoded and
    written to ``DOWNLOAD_DIR`` as a ``.png`` file. Entries that fail to download or
    decode are skipped with a warning, so one bad item does not discard the rest.

    Return:
    [
        {
            "path": "/path/to/image.png",
            "url": "https://...",  # only present for downloaded (http) images
        }
    ]
    """
    results: list[dict] = []
    for url in image_urls:
        if url.startswith("http"):
            img_path = await download_file(url, impersonate=None, proxy=proxy)
            # Guard a falsy return on failure: Path(None) would raise TypeError.
            if img_path and Path(img_path).is_file():
                results.append({"path": img_path, "url": url})
            else:
                logger.warning(f"Failed to download generated image: {url}")
            continue
        # base64 json. b64decode silently drops non-alphabet characters but keeps the
        # letters of a "data:image/png;base64," prefix, so strip a data-URI header first.
        payload = url.partition(",")[2] if url.startswith("data:") else url
        try:
            image_bytes = base64.b64decode(payload)
        except ValueError as e:  # binascii.Error is a subclass of ValueError
            logger.warning(f"Failed to decode base64 generated image: {e}")
            continue
        save_path = Path(DOWNLOAD_DIR) / f"{rand_string(10)}.png"
        async with await anyio.open_file(save_path, "wb") as f:
            await f.write(image_bytes)
        results.append({"path": save_path.as_posix()})
    return results
185
186
async def waiting_modelscope_task(
    task_id: str,
    params: dict,
    *,
    interval: float = 5,
    timeout: float = 600,
) -> tuple[list[str], str]:
    """Poll a ModelScope async image-generation task until it reaches a final state.

    Task Submitted Response:
    {
        "task_status": "SUCCEED",
        "task_id": "5054288",
        "request_id": "68b91aef-e00e-40dd-a4e8-8611e4826b7a"
    }

    Task Check Response:
    {
      "input": {
        "guidanceScale": 7.5,
        "height": 1280,
        "negativePrompt": "",
        "numInferenceSteps": 9,
        "outputs": {},
        "prompt": "<prompt text>",
        "sampler": "Euler a",
        "seed": 391608538,
        "timeTaken": 6764.005661010742,
        "weight": 0
      },
      "output_images": [
        "https://muse-ai.oss-cn-hangzhou.aliyuncs.com/img/8cac52d9b76a41238f02733ef3709fba.png"
      ],
      "request_id": "b96147a2-d9ea-4640-b726-4d097235c5a1",
      "task_id": "",
      "task_status": "SUCCEED",
      "time_taken": 6764.005661010742
    }

    Args:
        task_id: ID returned by the task-submission request.
        params: The submission request params plus ``base_url`` and ``api_paths``;
            ``headers`` and ``proxy`` are reused for polling, and the check URL is
            ``base_url + api_paths["task_check"]`` with ``%TASK_ID%`` substituted.
        interval: Seconds to wait between polls.
        timeout: Stop polling after roughly this many seconds.

    Returns:
        tuple: (image URLs from ``output_images``, metadata string). ``([], "")`` when a
        poll returns ``hx_error``, the task ends FAILED / CANCELLED / UNKNOWN, or
        ``timeout`` elapses.
    """
    headers = {k.lower(): v for k, v in (params.get("headers") or {}).items()} | {"x-modelscope-task-type": "image_generation"}
    base_url = params["base_url"]
    if base_url.startswith("https://gateway.helicone.ai"):
        # Poll the upstream target directly and drop the Helicone-only headers.
        helicone_target_url = headers.get("helicone-target-url", "").rstrip("/")
        base_url = base_url.replace("https://gateway.helicone.ai", helicone_target_url)
        headers.pop("helicone-target-url", None)
        headers.pop("helicone-auth", None)
    task_url = base_url + glom(params, "api_paths.task_check", default="")
    url = replace_placeholder(task_url, {"%TASK_ID%": task_id})

    loop = asyncio.get_running_loop()
    deadline = loop.time() + timeout
    while True:
        resp = await hx_req(url, headers=headers, proxy=params.get("proxy"), check_keys=["task_status"])
        # Normalize once so success and failure checks share the same case handling.
        status = "" if "hx_error" in resp else str(resp.get("task_status") or "").upper()
        if "hx_error" in resp or status in {"FAILED", "CANCELLED", "UNKNOWN"}:
            logger.error(f"Image Generation Task {task_id} error: {resp}")
            return [], ""
        if status == "SUCCEED":
            return glom(resp, "output_images", default=[]), extract_metadata(resp)
        if loop.time() >= deadline:
            # Without a bound, a task stuck in a pending state would be polled forever.
            logger.error(f"Image Generation Task {task_id} timed out after {timeout}s: {resp}")
            return [], ""
        await asyncio.sleep(interval)