Commit 417b977
Changed files (1)
src
llm
ali
src/llm/ali/text2img.py
@@ -4,6 +4,7 @@ import asyncio
from pathlib import Path
from glom import glom
+from httpx import AsyncClient
from loguru import logger
from pyrogram.client import Client
from pyrogram.types import Message
@@ -37,6 +38,7 @@ async def ali_text2img(client: Client, message: Message, model_id: str, prompt:
if not silent and kwargs.get("show_progress"):
kwargs["progress"] = (await send2tg(client, message, texts=f"🌠**{model_name}**:\n💬提示词: {prompt}", **kwargs))[0]
error = ""
+ succ = False
payload = {"model": model_id, "input": {"prompt": prompt}}
if "stable-diffusion" in model_id:
payload |= {"parameters": {"n": 4}}
@@ -46,26 +48,27 @@ async def ali_text2img(client: Client, message: Message, model_id: str, prompt:
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
}
- resp = await hx_req(
- "https://dashscope.aliyuncs.com/api/v1/services/aigc/text2image/image-synthesis",
- method="POST",
- headers=headers,
- json_data=payload,
- timeout=10,
- check_kv={"output.task_status": "PENDING"},
- check_keys=["output.task_id"],
- proxy=TEXT2IMG.ALI_PROXY,
- )
- if resp.get("message"):
- error = resp["message"]
- logger.error(error)
- continue
- finished = await wait_for_response(resp["output"]["task_id"], api_key)
- if images := finished.get("images"):
- media = [{"photo": img} for img in images]
- await send2tg(client, message, texts=f"{prompt}\n(By **{model_name}**)", media=media, **kwargs)
- break
- await modify_progress(del_status=True, **kwargs)
+ api_url = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text2image/image-synthesis"
+ httpx_client = AsyncClient(proxy=TEXT2IMG.ALI_PROXY, headers=headers, timeout=10)
+ try:
+ response = await httpx_client.post(api_url, json=payload)
+ resp = response.json()
+ if resp.get("message"):
+ error = resp["message"]
+ logger.error(error)
+ continue
+ finished = await wait_for_response(resp["output"]["task_id"], api_key)
+ if images := finished.get("images"):
+ media = [{"photo": img} for img in images]
+ await send2tg(client, message, texts=f"{prompt}\n(By **{model_name}**)", media=media, **kwargs)
+ succ = True
+ break
+ except Exception as e:
+ logger.error(e)
+ if error and not succ:
+ await modify_progress(text=f"❌{error}", force_update=True, **kwargs)
+ else:
+ await modify_progress(del_status=True, **kwargs)
return {"error": error} if error else {}
@@ -80,7 +83,7 @@ async def wait_for_response(task_id: str, api_key: str) -> dict:
if task_status == "SUCCESS":
resp = await hx_req(api, headers=headers, silent=True, proxy=TEXT2IMG.ALI_PROXY, check_keys=["output.result"])
return resp["output"]["result"]
- while task_status == "RUNNING":
+ while task_status in ["PENDING", "RUNNING"]:
await asyncio.sleep(1)
logger.trace(f"Waiting for Ali Text2IMG, TaskID: {task_id}")
resp = await hx_req(api, headers=headers, silent=True, proxy=TEXT2IMG.ALI_PROXY, check_keys=["output.task_status"])