Commit 0b773fa
Changed files (6)
src
llm
src/llm/ali/text2img.py
@@ -12,6 +12,7 @@ from pyrogram.client import Client
from pyrogram.types import Message
from config import TEXT2IMG
+from llm.utils import parse_as_dict
from messages.progress import modify_progress
from messages.sender import send2tg
from networking import download_file, hx_req
@@ -40,20 +41,21 @@ async def ali_text2img(client: Client, message: Message, model_id: str, prompt:
return {}
model_name = model_id.split("/")[-1].title()
if not silent and kwargs.get("show_progress"):
- kwargs["progress"] = (await send2tg(client, message, texts=f"π **{model_name}**:\nπ¬ζη€Ίθ―: {prompt}", **kwargs))[0]
+ kwargs["progress"] = (await send2tg(client, message, texts=f"π **{model_name}**:\n{prompt}", **kwargs))[0]
error = ""
succ = False
+ parsed = parse_as_dict(prompt, need_prefix="config=")
payload = {
"model": model_id,
- "input": {"prompt": prompt},
+ "input": {"prompt": glom(parsed, "input.prompt", default=prompt)},
"parameters": {
- "size": "1024*1024",
- "steps": 50,
- "seed": randint(0, 2147483647),
+ "size": glom(parsed, "parameters.size", default="1024*1024"),
+ "steps": glom(parsed, "parameters.steps", default=50),
+ "seed": glom(parsed, "parameters.seed", default=randint(0, 2147483647)),
},
}
if "stable-diffusion" in model_id:
- payload |= {"parameters": {"n": 4}}
+ payload |= {"parameters": {"n": glom(parsed, "parameters.n", default=4)}}
for api_key in strings_list(TEXT2IMG.ALI_API_KEY, shuffle=True):
headers = {
"X-DashScope-Async": "enable",
@@ -75,6 +77,10 @@ async def ali_text2img(client: Client, message: Message, model_id: str, prompt:
await send2tg(client, message, texts=json.dumps(payload, ensure_ascii=False, indent=2), media=media, **kwargs)
succ = True
break
+ if finished.get("error"):
+ error = finished["error"]
+ logger.error(error)
+ continue
except Exception as e:
logger.error(e)
if error and not succ:
src/llm/ali/zimage.py
@@ -37,7 +37,7 @@ async def zimage_text2img(client: Client, message: Message, prompt: str, *, sile
TEXT2IMG.ZIMAGE_API_URL,
"POST",
headers={"Content-Type": "application/json"},
- json_data=parse_as_dict(prompt) or {"prompt": prompt},
+ json_data=parse_as_dict(prompt, need_prefix="config=") or {"prompt": prompt},
proxy=TEXT2IMG.ZIMAGE_PROXY,
check_kv={"mime_type": "image/png"},
timeout=600,
@@ -49,5 +49,5 @@ async def zimage_text2img(client: Client, message: Message, prompt: str, *, sile
async with await anyio.open_file(save_path, "wb") as f:
await f.write(image_bytes)
media = [{"photo": save_path.as_posix()}]
- await send2tg(client, message, texts=json.dumps(resp["params"], ensure_ascii=False, indent=2), media=media, **kwargs)
+ await send2tg(client, message, texts="π **Z-Image**:\n" + json.dumps(resp["params"], ensure_ascii=False, indent=2), media=media, **kwargs)
await modify_progress(del_status=True, **kwargs)
src/llm/cloudflare/text2img.py
@@ -33,7 +33,7 @@ async def cloudflare_text2img(client: Client, message: Message, model_id: str, p
model_name = model_id.split("/")[-1].title()
if not silent and kwargs.get("show_progress"):
- kwargs["progress"] = (await send2tg(client, message, texts=f"π **{model_name}**:\nπ¬ζη€Ίθ―: {prompt}", **kwargs))[0]
+ kwargs["progress"] = (await send2tg(client, message, texts=f"π **{model_name}**:\n{prompt}", **kwargs))[0]
for api_key in strings_list(TEXT2IMG.CF_API_KEY, shuffle=True):
account_id, token = api_key.split(":")
resp = await hx_req(
@@ -45,11 +45,14 @@ async def cloudflare_text2img(client: Client, message: Message, model_id: str, p
proxy=TEXT2IMG.CF_PROXY,
rformat="content",
)
+ if error := resp.get("hx_raw"):
+ await modify_progress(text="βηζε€±θ΄₯\n" + json.dumps(error, ensure_ascii=False, indent=2), force_update=True, **kwargs)
+ continue
path = save_img(resp["content"])
if path.is_file():
- await send2tg(client, message, texts=f"{prompt}\n(By **{model_name}**)", media=[{"photo": path}], **kwargs)
+ await send2tg(client, message, texts=f"π **{model_name}**:\n{prompt}", media=[{"photo": path}], **kwargs)
+ await modify_progress(del_status=True, **kwargs)
break
- await modify_progress(del_status=True, **kwargs)
return {}
src/llm/doubao/text2img.py
@@ -11,6 +11,7 @@ from pyrogram.types import Message
from config import TEXT2IMG
from llm.contexts import base64_media
+from llm.utils import parse_as_dict
from messages.progress import modify_progress
from messages.sender import send2tg
from networking import download_file, hx_req
@@ -36,25 +37,32 @@ async def doubao_genimg(client: Client, message: Message, model_id: str, prompt:
return {}
model_name = model_id.split("/")[-1].title()
if not silent and kwargs.get("show_progress"):
- kwargs["progress"] = (await send2tg(client, message, texts=f"π **{model_name}**:\nπ¬ζη€Ίθ―: {prompt}", **kwargs))[0]
+ kwargs["progress"] = (await send2tg(client, message, texts=f"π **{model_name}**:\n{prompt}", **kwargs))[0]
error = ""
succ = False
- config = {"model": model_id, "prompt": prompt, "size": "4K", "watermark": False, "seed": randint(0, 2147483647)}
+ parsed = parse_as_dict(prompt, need_prefix="config=")
+ config = {
+ "model": model_id,
+ "prompt": parsed.get("prompt", prompt),
+ "size": parsed.get("size", "4K"),
+ "watermark": parsed.get("watermark", False),
+ "seed": parsed.get("seed", randint(0, 2147483647)),
+ }
images = await get_ctx_images(client, message)
payload = config | {"image": images} if images else config
for api_key in strings_list(TEXT2IMG.DOUBAO_API_KEY, shuffle=True):
headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
api_url = "https://ark.cn-beijing.volces.com/api/v3/images/generations"
- resp = await hx_req(api_url, "POST", json_data=payload, headers=headers, proxy=TEXT2IMG.DOUBAO_PROXY, check_keys=["data"])
+ resp = await hx_req(api_url, "POST", json_data=payload, headers=headers, proxy=TEXT2IMG.DOUBAO_PROXY, max_retry=0)
if url := glom(resp, "data.0.url", default=""):
img_path = await download_file(url, proxy=TEXT2IMG.DOUBAO_PROXY)
if Path(img_path).is_file():
- caption = f"[δΈθ½½εεΎ]({url}) (24hε
ζζ)\n{json.dumps(config, ensure_ascii=False, indent=2)}"
+ caption = f"π **{model_name}**:\n{json.dumps(config, ensure_ascii=False, indent=2)}\n[δΈθ½½εεΎ]({url}) (24hε
ζζ)"
await send2tg(client, message, texts=caption, media=[{"photo": img_path}], **kwargs)
succ = True
break
- elif error_msg := glom(resp, "data.error.message", default=""):
- await modify_progress(text=f"β{error_msg}", force_update=True, **kwargs)
+ elif error := resp.get("hx_raw"):
+ await modify_progress(text=f"βηζε€±θ΄₯\n{json.dumps(error, ensure_ascii=False, indent=2)}", force_update=True, **kwargs)
logger.error(error)
continue
if succ:
src/llm/gemini/text2img.py
@@ -60,6 +60,9 @@ async def gemini_text2img(
await app.aio.aclose()
caption = ""
media = []
+ if glom(response, "candidates.0.finish_reason.name", default="STOP") != "STOP":
+ await modify_progress(text="βηζε€±θ΄₯: " + glom(response, "candidates.0.finish_reason.name"), **kwargs)
+ continue
for part in flatten(glom(response, "candidates.*.content.parts", default=[])):
if part.text:
caption += part.text
src/llm/utils.py
@@ -105,8 +105,15 @@ def count_tokens(string: str, encoding_name: str | None = None) -> int:
return 0
-def parse_as_dict(s: str) -> dict:
- """Parse the given string as a dictionary."""
+def parse_as_dict(s: str, need_prefix: str | None = None) -> dict:
+ """Parse the given string as a dictionary.
+
+ If `need_prefix` is provided, only parse the string if it starts with `need_prefix`.
+ """
+ if need_prefix is not None:
+ if not s.startswith(need_prefix):
+ return {}
+ s = s[len(need_prefix) :]
s = re.sub(r"\btrue\b", "True", s)
s = re.sub(r"\bfalse\b", "False", s)
s = re.sub(r"\bnull\b", "None", s)