Commit 5a65136
Changed files (3)
src
ai
src/ai/images/openai_img.py
@@ -5,20 +5,16 @@ from pathlib import Path
from typing import Literal
import anyio
-from glom import glom
from loguru import logger
from openai import AsyncOpenAI, DefaultAsyncHttpxClient
from pyrogram.client import Client
-from pyrogram.types import Message
+from pyrogram.types import InputMediaPhoto, Message
from ai.images.post import get_image_contexts
from ai.images.utils import aspect_ratio_to_size, extract_aspect_ratio
from ai.utils import EMOJI_IMG_BOT, literal_eval, prettify, trim_none
from config import DOWNLOAD_DIR, PROXY
from messages.progress import modify_progress
-from messages.sender import send2tg
-from messages.utils import delete_message
-from networking import download_file
from utils import rand_string, strings_list
@@ -37,6 +33,7 @@ async def openai_image_generation(
max_width: int = int(1e16),
max_height: int = int(1e16),
max_size: int = int(1e32),
+ api_provider: Literal["openai", "seedream"] = "openai",
**kwargs,
) -> dict[str, Message | bool]:
"""Get OpenAI Image Generation.
@@ -52,12 +49,17 @@ async def openai_image_generation(
if literal_eval(client_config):
openai_client |= literal_eval(client_config)
if proxy:
- openai_client |= {"http_client": DefaultAsyncHttpxClient(proxy=proxy)}
- prompt, reference_images = await get_image_contexts(client, message, max_reference_img=max_reference_img)
+ openai_client |= {"http_client": DefaultAsyncHttpxClient(proxy=proxy, timeout=600)}
+ prompt, reference_images = await get_image_contexts(
+ client,
+ message,
+ max_reference_img=max_reference_img,
+ img_type="str" if api_provider == "openai" else "base64",
+ )
if not prompt:
await modify_progress(status_msg, text=f"❌**{model_name}**:\n请提供提示词", force_update=True, **kwargs)
return {"progress": status_msg} if isinstance(status_msg, Message) else {}
- params = {}
+ params: dict = {"stream": True, "timeout": 600}
if literal_eval(generate_config):
params |= literal_eval(generate_config)
aspect_ratio, _ = extract_aspect_ratio(message.content)
@@ -65,7 +67,10 @@ async def openai_image_generation(
width, height = aspect_ratio_to_size(aspect_ratio, resolution, max_width, max_height, max_size)
params |= {"size": f"{width}x{height}"}
if reference_images:
- params["extra_body"] = params.get("extra_body", {}) | {"image": reference_images}
+ if api_provider == "seedream":
+ params["extra_body"] = params.get("extra_body", {}) | {"image": reference_images}
+ else:
+ params["image"] = [Path(x) for x in reference_images]
params |= {"model": model_id, "prompt": prompt}
logger.debug(f"openai.images.generate(**{prettify(params)})")
except Exception as e:
@@ -78,65 +83,28 @@ async def openai_image_generation(
openai_client["api_key"] = api_key
logger.trace(f"AsyncOpenAI(**{openai_client})")
openai = AsyncOpenAI(**openai_client)
- resp = await openai.images.generate(**params)
- resp = trim_none(resp.model_dump())
- if images := await download_generated_images(resp, proxy=proxy):
- resp.pop("data", None)
- caption = f"{EMOJI_IMG_BOT}**{model_name}**\n{prettify(resp)}\n"
- for idx, img in enumerate([x for x in images if x.get("url")]):
- caption += f"[P{idx + 1}原图]({img['url']})"
- await modify_progress(status_msg, text=prettify(resp), force_update=True, **kwargs)
- await send2tg(client, message, texts=caption, media=[{"photo": img["path"]} for img in images], **kwargs)
- await delete_message(status_msg)
+ if reference_images and api_provider == "openai":
+ params.pop("moderation", None)
+ func = openai.images.edit
+ else:
+ func = openai.images.generate
+ async for chunk in await func(**params):
+ resp = trim_none(chunk.model_dump())
+ if b64_json := resp.pop("b64_json", None):
+ image_bytes = base64.b64decode(b64_json)
+ save_path = Path(DOWNLOAD_DIR) / f"{rand_string(10)}.{params.get('output_format', 'jpeg')}"
+ async with await anyio.open_file(save_path, "wb") as f:
+ await f.write(image_bytes)
+ rtype = resp.pop("type", None)
+ caption = f"{EMOJI_IMG_BOT}**{model_name}**\n" if rtype == "image_generation.completed" else f"⌛️**{model_name}** 中间结果...\n"
+ status_msg = await status_msg.edit_media(InputMediaPhoto(str(save_path), caption=caption + prettify(resp)))
+ save_path.unlink(missing_ok=True)
+ elif resp.pop("type", None) == "image_generation.completed":
+ await status_msg.edit_caption(f"{EMOJI_IMG_BOT}**{model_name}**\n{prettify(resp)}")
+ logger.trace(resp)
+ if status_msg.photo:
return {"success": True}
except Exception as e:
logger.error(f"OpenAI Image Generation error: {e}\n\n{prettify(resp)}")
await modify_progress(status_msg, text=f"❌{e}\n\n{prettify(resp)}", force_update=True, **kwargs)
return {"progress": status_msg} if isinstance(status_msg, Message) else {}
-
-
-async def download_generated_images(response: dict, proxy: str | None = None) -> list[dict]:
- """Download generated images.
-
- Response: {
- "model": "doubao-seedream-4-5-251128",
- "created": 1757321139,
- "data": [
- {
- "url": "https://...",
- "size": "3104x1312"
- },
- {
- "b64_json": "/9j/4AAQSkZJRgABA...",
- "size": "3104x1312"
- }
- ],
- "usage": {
- "generated_images": 2,
- "output_tokens": xxx,
- "total_tokens": xxx
- }
- }
-
- Return:
- [
- {
- "path": "/path/to/image.png"
- "url": "https://...",
- }
- ]
- """
- results = []
- data = glom(response, "data", default=[]) or []
- for item in data:
- if url := item.get("url"):
- img_path = await download_file(url, proxy=proxy)
- if Path(img_path).is_file():
- results.append({"path": img_path, "url": url})
- if b64_json := item.get("b64_json"):
- image_bytes = base64.b64decode(b64_json)
- save_path = Path(DOWNLOAD_DIR) / f"{rand_string(10)}.png"
- async with await anyio.open_file(save_path, "wb") as f:
- await f.write(image_bytes)
- results.append({"path": save_path.as_posix()})
- return results
src/ai/images/post.py
@@ -112,7 +112,13 @@ def extract_metadata(response: dict) -> str:
return ""
-async def get_image_contexts(client: Client, message: Message, *, max_reference_img: int = 0) -> tuple[str, list[str]]:
+async def get_image_contexts(
+ client: Client,
+ message: Message,
+ *,
+ max_reference_img: int = 0,
+ img_type: Literal["base64", "str"] = "base64",
+) -> tuple[str, list[str]]:
"""Get image generation contexts.
Returns:
@@ -142,8 +148,13 @@ async def get_image_contexts(client: Client, message: Message, *, max_reference_
prompt = clean(msg.content)
if not msg.photo or len(images) >= max_reference_img:
continue
- res = await base64_media(client, msg)
- images.append(f"data:image/{res['ext']};base64,{res['base64']}")
+ if img_type == "base64":
+ res = await base64_media(client, msg)
+ images.append(f"data:image/{res['ext']};base64,{res['base64']}")
+ else:
+ path = await client.download_media(msg)
+ if Path(str(path)).is_file():
+ images.append(Path(str(path)).as_posix())
return prompt, images
src/ai/utils.py
@@ -18,7 +18,7 @@ from pyrogram.parser.markdown import BLOCKQUOTE_EXPANDABLE_DELIM
from config import AI, PREFIX, PROXY, cache
from database.kv import get_cf_kv
-from utils import nowdt, remove_consecutive_newlines, remove_dash, remove_pound, strings_list, zhcn
+from utils import nowdt, remove_consecutive_newlines, remove_dash, remove_pound, strings_list, ts_to_dt, zhcn
# ruff: noqa: RUF001
EMOJI_TEXT_BOT = "🤖"
@@ -66,6 +66,10 @@ def trim_none(obj: dict) -> dict:
def prettify(data: dict) -> str:
with contextlib.suppress(Exception):
data = trim_none(data)
+ if isinstance(data.get("created"), int):
+ data["created"] = ts_to_dt(data["created"]).strftime("%Y-%m-%d %H:%M:%S") # ty:ignore[unresolved-attribute]
+ if isinstance(data.get("created_at"), int):
+ data["created_at"] = ts_to_dt(data["created_at"]).strftime("%Y-%m-%d %H:%M:%S") # ty:ignore[unresolved-attribute]
return json.dumps(data, ensure_ascii=False, indent=2)
return str(data)