Commit f94f57b
Changed files (5)
src
src/llm/ali/zimage.py
@@ -0,0 +1,53 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+import base64
+import json
+from pathlib import Path
+
+import anyio
+from pyrogram.client import Client
+from pyrogram.types import Message
+
+from config import DOWNLOAD_DIR, TEXT2IMG
+from llm.utils import parse_as_dict
+from messages.progress import modify_progress
+from messages.sender import send2tg
+from networking import hx_req
+from utils import rand_string
+
+
+async def zimage_text2img(client: Client, message: Message, prompt: str, *, silent: bool = False, **kwargs):
+ """Z-Image text to image.
+
+ Args:
+ client (Client): The Pyrogram client.
+ message (Message): The trigger message object.
+ prompt (str): Prompt. Defaults to None.
+ silent (bool, optional): Whether to disable progressing. Defaults to False.
+ """
+ if not prompt:
+ if message.reply_to_message:
+ prompt = message.reply_to_message.content
+ else:
+ await message.reply(text="请输入图片描述。", quote=True)
+ return
+ if not silent and kwargs.get("show_progress"):
+ kwargs["progress"] = (await send2tg(client, message, texts=f"🌠**Z-Image**:\n{prompt}", **kwargs))[0]
+ resp = await hx_req(
+ TEXT2IMG.ZIMAGE_API_URL,
+ "POST",
+ headers={"Content-Type": "application/json"},
+ json_data=parse_as_dict(prompt) or {"prompt": prompt},
+ proxy=TEXT2IMG.ZIMAGE_PROXY,
+ check_kv={"mime_type": "image/png"},
+ timeout=600,
+ silent=True,
+ )
+ if b64_json := resp.get("b64_json"):
+ image_bytes = base64.b64decode(b64_json)
+ save_path = Path(DOWNLOAD_DIR) / f"{rand_string(10)}.png"
+ async with await anyio.open_file(save_path, "wb") as f:
+ await f.write(image_bytes)
+ media = [{"photo": save_path.as_posix()}]
+ await send2tg(client, message, texts=json.dumps(resp["params"], ensure_ascii=False, indent=2), media=media, **kwargs)
+ await modify_progress(del_status=True, **kwargs)
src/llm/text2img.py
@@ -9,6 +9,7 @@ from pyrogram.types import Message
from config import PREFIX, TEXT2IMG
from llm.ali.text2img import ali_text2img
+from llm.ali.zimage import zimage_text2img
from llm.cloudflare.text2img import cloudflare_text2img
from llm.doubao.text2img import doubao_genimg
from llm.gemini.text2img import gemini_text2img
@@ -20,9 +21,10 @@ TEXT2IMG_HELP = f"""🌠**AI生图**
⚙️模型配置:
- `{PREFIX.GENIMG}`: 默认模型 ({TEXT2IMG.DEFAULT_MODEL})
-- `/sd`: 豆包Seedream模型
-- `/flux`: Flux模型
-- `/stable`: Stable Diffusion模型
+- `/z`: 阿里Z-Image
+- `/sd`: 豆包Seedream
+- `/flux`: Flux
+- `/stable`: Stable Diffusion
- `/nano`: Gemini Nano Banana
上下文说明:
@@ -65,13 +67,15 @@ async def text2img(client: Client, message: Message, **kwargs) -> dict:
provider, model_id = model.split("/", 1)
try:
if provider == "gemini":
- return await gemini_text2img(client, message, model_id, prompt, **kwargs)
- if provider == "ali":
- return await ali_text2img(client, message, model_id, prompt, **kwargs)
- if provider == "cloudflare":
- return await cloudflare_text2img(client, message, model_id, prompt, **kwargs)
- if provider == "doubao":
- return await doubao_genimg(client, message, model_id, prompt, **kwargs)
+ await gemini_text2img(client, message, model_id, prompt, **kwargs)
+ elif provider == "ali":
+ await ali_text2img(client, message, model_id, prompt, **kwargs)
+ elif provider == "cloudflare":
+ await cloudflare_text2img(client, message, model_id, prompt, **kwargs)
+ elif provider == "doubao":
+ await doubao_genimg(client, message, model_id, prompt, **kwargs)
+ elif provider == "zimage":
+ await zimage_text2img(client, message, prompt, **kwargs)
except Exception as e:
logger.error(e)
return {}
@@ -103,4 +107,6 @@ def enabled_models() -> dict[str, list]:
models["sd"].extend([f"cloudflare/{model}" for model in strings_list(TEXT2IMG.CF_STABLE_DIFFUSION_MODEL)])
if provider == "doubao" and TEXT2IMG.DOUBAO_SEEDREAM_MODEL:
models["doubao"].extend([f"doubao/{model}" for model in strings_list(TEXT2IMG.DOUBAO_SEEDREAM_MODEL)])
+ if provider == "zimage" and TEXT2IMG.ZIMAGE_API_URL:
+ models["zimage"].extend(["zimage/Z-Image"])
return models
src/llm/utils.py
@@ -1,5 +1,6 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
+import ast
import random
import re
import tempfile
@@ -53,6 +54,8 @@ def enabled_providers() -> tuple[list[str], list[str]]:
img_providers.append("cloudflare")
if all([TEXT2IMG.DOUBAO_API_KEY, TEXT2IMG.DOUBAO_SEEDREAM_MODEL]):
img_providers.append("doubao")
+ if all([TEXT2IMG.ZIMAGE_API_URL]):
+ img_providers.append("zimage")
return text_providers, img_providers
@@ -102,6 +105,20 @@ def count_tokens(string: str, encoding_name: str | None = None) -> int:
return 0
+def parse_as_dict(s: str) -> dict:
+ """Parse the given string as a dictionary."""
+ s = re.sub(r"\btrue\b", "True", s)
+ s = re.sub(r"\bfalse\b", "False", s)
+ s = re.sub(r"\bnull\b", "None", s)
+ try:
+ data = ast.literal_eval(s)
+ if isinstance(data, dict):
+ return data
+ except (ValueError, SyntaxError):
+ return {}
+ return {}
+
+
def beautify_model_name(name: str) -> str:
"""Beautify model name.
src/others/alias.py
@@ -17,6 +17,8 @@ def command_alias(message: Message) -> Message:
texts = texts.replace("/stable", f"{PREFIX.GENIMG} @sd")
elif texts.startswith("/nano"):
texts = texts.replace("/nano", f"{PREFIX.GENIMG} @gemini")
+ elif texts.startswith("/z"):
+ texts = texts.replace("/z", f"{PREFIX.GENIMG} @zimage")
if message.text:
message.text = Str(texts)
src/config.py
@@ -485,6 +485,8 @@ class TEXT2IMG:
ALI_FLUX_MODEL = os.getenv("TEXT2IMG_ALI_FLUX_MODEL", "flux-dev")
ALI_STABLE_DIFFUSION_MODEL = os.getenv("TEXT2IMG_ALI_STABLE_DIFFUSION_MODEL", "stable-diffusion-3.5-large")
ALI_PROXY = os.getenv("TEXT2IMG_ALI_PROXY", None)
+ ZIMAGE_API_URL = os.getenv("TEXT2IMG_ZIMAGE_API_URL", "")
+ ZIMAGE_PROXY = os.getenv("TEXT2IMG_ZIMAGE_PROXY", None)
CF_API_KEY = os.getenv("TEXT2IMG_CF_API_KEY", "") # comma separated keys. e.g. "AccountID:API_TOKEN, AccountID:API_TOKEN, ..."
CF_FLUX_MODEL = os.getenv("TEXT2IMG_CF_FLUX_MODEL", "@cf/black-forest-labs/flux-1-schnell")
CF_STABLE_DIFFUSION_MODEL = os.getenv("TEXT2IMG_CF_STABLE_DIFFUSION_MODEL", "@cf/bytedance/stable-diffusion-xl-lightning")