Commit 0b773fa

benny-dou <60535774+benny-dou@users.noreply.github.com>
2025-11-29 12:51:16
feat(ai): allow advanced prompt configuration for text2img APIs
1 parent 391e0f5
Changed files (6)
src/llm/ali/text2img.py
@@ -12,6 +12,7 @@ from pyrogram.client import Client
 from pyrogram.types import Message
 
 from config import TEXT2IMG
+from llm.utils import parse_as_dict
 from messages.progress import modify_progress
 from messages.sender import send2tg
 from networking import download_file, hx_req
@@ -40,20 +41,21 @@ async def ali_text2img(client: Client, message: Message, model_id: str, prompt:
             return {}
     model_name = model_id.split("/")[-1].title()
     if not silent and kwargs.get("show_progress"):
-        kwargs["progress"] = (await send2tg(client, message, texts=f"🌠**{model_name}**:\nπŸ’¬ζη€Ίθ―: {prompt}", **kwargs))[0]
+        kwargs["progress"] = (await send2tg(client, message, texts=f"🌠**{model_name}**:\n{prompt}", **kwargs))[0]
     error = ""
     succ = False
+    parsed = parse_as_dict(prompt, need_prefix="config=")
     payload = {
         "model": model_id,
-        "input": {"prompt": prompt},
+        "input": {"prompt": glom(parsed, "input.prompt", default=prompt)},
         "parameters": {
-            "size": "1024*1024",
-            "steps": 50,
-            "seed": randint(0, 2147483647),
+            "size": glom(parsed, "parameters.size", default="1024*1024"),
+            "steps": glom(parsed, "parameters.steps", default=50),
+            "seed": glom(parsed, "parameters.seed", default=randint(0, 2147483647)),
         },
     }
     if "stable-diffusion" in model_id:
-        payload |= {"parameters": {"n": 4}}
+        payload |= {"parameters": {"n": glom(parsed, "parameters.n", default=4)}}
     for api_key in strings_list(TEXT2IMG.ALI_API_KEY, shuffle=True):
         headers = {
             "X-DashScope-Async": "enable",
@@ -75,6 +77,10 @@ async def ali_text2img(client: Client, message: Message, model_id: str, prompt:
                 await send2tg(client, message, texts=json.dumps(payload, ensure_ascii=False, indent=2), media=media, **kwargs)
                 succ = True
                 break
+            if finished.get("error"):
+                error = finished["error"]
+                logger.error(error)
+                continue
         except Exception as e:
             logger.error(e)
     if error and not succ:
src/llm/ali/zimage.py
@@ -37,7 +37,7 @@ async def zimage_text2img(client: Client, message: Message, prompt: str, *, sile
         TEXT2IMG.ZIMAGE_API_URL,
         "POST",
         headers={"Content-Type": "application/json"},
-        json_data=parse_as_dict(prompt) or {"prompt": prompt},
+        json_data=parse_as_dict(prompt, need_prefix="config=") or {"prompt": prompt},
         proxy=TEXT2IMG.ZIMAGE_PROXY,
         check_kv={"mime_type": "image/png"},
         timeout=600,
@@ -49,5 +49,5 @@ async def zimage_text2img(client: Client, message: Message, prompt: str, *, sile
         async with await anyio.open_file(save_path, "wb") as f:
             await f.write(image_bytes)
         media = [{"photo": save_path.as_posix()}]
-        await send2tg(client, message, texts=json.dumps(resp["params"], ensure_ascii=False, indent=2), media=media, **kwargs)
+        await send2tg(client, message, texts="🌠**Z-Image**:\n" + json.dumps(resp["params"], ensure_ascii=False, indent=2), media=media, **kwargs)
     await modify_progress(del_status=True, **kwargs)
src/llm/cloudflare/text2img.py
@@ -33,7 +33,7 @@ async def cloudflare_text2img(client: Client, message: Message, model_id: str, p
 
     model_name = model_id.split("/")[-1].title()
     if not silent and kwargs.get("show_progress"):
-        kwargs["progress"] = (await send2tg(client, message, texts=f"🌠**{model_name}**:\nπŸ’¬ζη€Ίθ―: {prompt}", **kwargs))[0]
+        kwargs["progress"] = (await send2tg(client, message, texts=f"🌠**{model_name}**:\n{prompt}", **kwargs))[0]
     for api_key in strings_list(TEXT2IMG.CF_API_KEY, shuffle=True):
         account_id, token = api_key.split(":")
         resp = await hx_req(
@@ -45,11 +45,14 @@ async def cloudflare_text2img(client: Client, message: Message, model_id: str, p
             proxy=TEXT2IMG.CF_PROXY,
             rformat="content",
         )
+        if error := resp.get("hx_raw"):
+            await modify_progress(text="βŒη”Ÿζˆε€±θ΄₯\n" + json.dumps(error, ensure_ascii=False, indent=2), force_update=True, **kwargs)
+            continue
         path = save_img(resp["content"])
         if path.is_file():
-            await send2tg(client, message, texts=f"{prompt}\n(By **{model_name}**)", media=[{"photo": path}], **kwargs)
+            await send2tg(client, message, texts=f"🌠**{model_name}**:\n{prompt}", media=[{"photo": path}], **kwargs)
+            await modify_progress(del_status=True, **kwargs)
             break
-    await modify_progress(del_status=True, **kwargs)
     return {}
 
 
src/llm/doubao/text2img.py
@@ -11,6 +11,7 @@ from pyrogram.types import Message
 
 from config import TEXT2IMG
 from llm.contexts import base64_media
+from llm.utils import parse_as_dict
 from messages.progress import modify_progress
 from messages.sender import send2tg
 from networking import download_file, hx_req
@@ -36,25 +37,32 @@ async def doubao_genimg(client: Client, message: Message, model_id: str, prompt:
         return {}
     model_name = model_id.split("/")[-1].title()
     if not silent and kwargs.get("show_progress"):
-        kwargs["progress"] = (await send2tg(client, message, texts=f"🌠**{model_name}**:\nπŸ’¬ζη€Ίθ―: {prompt}", **kwargs))[0]
+        kwargs["progress"] = (await send2tg(client, message, texts=f"🌠**{model_name}**:\n{prompt}", **kwargs))[0]
     error = ""
     succ = False
-    config = {"model": model_id, "prompt": prompt, "size": "4K", "watermark": False, "seed": randint(0, 2147483647)}
+    parsed = parse_as_dict(prompt, need_prefix="config=")
+    config = {
+        "model": model_id,
+        "prompt": parsed.get("prompt", prompt),
+        "size": parsed.get("size", "4K"),
+        "watermark": parsed.get("watermark", False),
+        "seed": parsed.get("seed", randint(0, 2147483647)),
+    }
     images = await get_ctx_images(client, message)
     payload = config | {"image": images} if images else config
     for api_key in strings_list(TEXT2IMG.DOUBAO_API_KEY, shuffle=True):
         headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
         api_url = "https://ark.cn-beijing.volces.com/api/v3/images/generations"
-        resp = await hx_req(api_url, "POST", json_data=payload, headers=headers, proxy=TEXT2IMG.DOUBAO_PROXY, check_keys=["data"])
+        resp = await hx_req(api_url, "POST", json_data=payload, headers=headers, proxy=TEXT2IMG.DOUBAO_PROXY, max_retry=0)
         if url := glom(resp, "data.0.url", default=""):
             img_path = await download_file(url, proxy=TEXT2IMG.DOUBAO_PROXY)
             if Path(img_path).is_file():
-                caption = f"[δΈ‹θ½½εŽŸε›Ύ]({url}) (24hε†…ζœ‰ζ•ˆ)\n{json.dumps(config, ensure_ascii=False, indent=2)}"
+                caption = f"🌠**{model_name}**:\n{json.dumps(config, ensure_ascii=False, indent=2)}\n[δΈ‹θ½½εŽŸε›Ύ]({url}) (24hε†…ζœ‰ζ•ˆ)"
                 await send2tg(client, message, texts=caption, media=[{"photo": img_path}], **kwargs)
                 succ = True
                 break
-        elif error_msg := glom(resp, "data.error.message", default=""):
-            await modify_progress(text=f"❌{error_msg}", force_update=True, **kwargs)
+        elif error := resp.get("hx_raw"):
+            await modify_progress(text=f"βŒη”Ÿζˆε€±θ΄₯\n{json.dumps(error, ensure_ascii=False, indent=2)}", force_update=True, **kwargs)
             logger.error(error)
             continue
     if succ:
src/llm/gemini/text2img.py
@@ -60,6 +60,9 @@ async def gemini_text2img(
             await app.aio.aclose()
             caption = ""
             media = []
+            if glom(response, "candidates.0.finish_reason.name", default="STOP") != "STOP":
+                await modify_progress(text="βŒη”Ÿζˆε€±θ΄₯: " + glom(response, "candidates.0.finish_reason.name"), **kwargs)
+                continue
             for part in flatten(glom(response, "candidates.*.content.parts", default=[])):
                 if part.text:
                     caption += part.text
src/llm/utils.py
@@ -105,8 +105,15 @@ def count_tokens(string: str, encoding_name: str | None = None) -> int:
         return 0
 
 
-def parse_as_dict(s: str) -> dict:
-    """Parse the given string as a dictionary."""
+def parse_as_dict(s: str, need_prefix: str | None = None) -> dict:
+    """Parse the given string as a dictionary.
+
+    If `need_prefix` is provided, only parse the string if it starts with `need_prefix`.
+    """
+    if need_prefix is not None:
+        if not s.startswith(need_prefix):
+            return {}
+        s = s[len(need_prefix) :]
     s = re.sub(r"\btrue\b", "True", s)
     s = re.sub(r"\bfalse\b", "False", s)
     s = re.sub(r"\bnull\b", "None", s)