Commit 3973859
Changed files (5)
src
src/llm/doubao/text2img.py
@@ -0,0 +1,77 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+import json
+from pathlib import Path
+from random import randint
+
+from glom import glom
+from loguru import logger
+from pyrogram.client import Client
+from pyrogram.types import Message
+
+from config import TEXT2IMG
+from llm.contexts import base64_media
+from messages.progress import modify_progress
+from messages.sender import send2tg
+from networking import download_file, hx_req
+from utils import strings_list
+
+
async def doubao_genimg(client: Client, message: Message, model_id: str, prompt: str, *, silent: bool = False, **kwargs) -> dict:
    """Doubao (Volcengine Ark Seedream) image generation.

    https://www.volcengine.com/docs/82379/1541523

    Args:
        client (Client): The Pyrogram client.
        message (Message): The trigger message object.
        model_id (str): Full model identifier; the segment after the last "/" is shown to the user.
        prompt (str): Text prompt. Required — an empty prompt replies with a hint and aborts.
        silent (bool, optional): Whether to suppress the progress message. Defaults to False.
        **kwargs: Extra options forwarded to send2tg/modify_progress (e.g. show_progress, progress).

    Returns:
        dict: {} on success (or when no prompt was given); {"error": str} when the API
        reported an error and no key succeeded.
    """
    if not prompt:
        await message.reply(text="缺少提示词", quote=True)
        return {}
    model_name = model_id.split("/")[-1].title()
    if not silent and kwargs.get("show_progress"):
        kwargs["progress"] = (await send2tg(client, message, texts=f"🌠**{model_name}**:\n💬提示词: {prompt}", **kwargs))[0]
    error = ""
    config = {"model": model_id, "prompt": prompt, "size": "4K", "watermark": False, "seed": randint(0, 2147483647)}
    images = await get_ctx_images(client, message)
    # Conditional expression binds looser than "|": images are attached only when present.
    payload = config | {"image": images} if images else config
    for api_key in strings_list(TEXT2IMG.DOUBAO_API_KEY, shuffle=True):
        headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
        api_url = "https://ark.cn-beijing.volces.com/api/v3/images/generations"
        resp = await hx_req(api_url, "POST", json_data=payload, headers=headers, proxy=TEXT2IMG.DOUBAO_PROXY, check_keys=["data"])
        if url := glom(resp, "data.0.url", default=""):
            img_path = await download_file(url, proxy=TEXT2IMG.DOUBAO_PROXY)
            if Path(img_path).is_file():
                caption = f"[下载原图]({url}) (24h内有效)\n{json.dumps(config, ensure_ascii=False, indent=2)}"
                await send2tg(client, message, texts=caption, media=[{"photo": img_path}], **kwargs)
                # Success: drop the progress indicator and clear any error from earlier keys.
                await modify_progress(del_status=True, **kwargs)
                return {}
        elif error_msg := glom(resp, "data.error.message", default=""):
            # Keep the API's error so the caller receives it; previous code logged the
            # not-yet-assigned `error` variable (always "") and never propagated the message.
            error = error_msg
            await modify_progress(text=f"❌{error_msg}", force_update=True, **kwargs)
            logger.error(error_msg)
            continue
    return {"error": error} if error else {}
+
+
async def get_ctx_images(client: Client, message: Message) -> str | list[str]:
    """Collect photo context from the replied-to message (if any) and the trigger message.

    Each photo is converted to a base64 data URI. Returns the lone string when exactly
    one photo is found, otherwise a (possibly empty) list of data-URI strings.
    """

    async def expand(msg: Message) -> list[Message]:
        # An album member expands to its full media group; a lone message stays as-is.
        if msg.media_group_id:
            return await client.get_media_group(msg.chat.id, msg.id)
        return [msg]

    candidates: list[Message] = []
    if message.reply_to_message:
        candidates += await expand(message.reply_to_message)
    candidates += await expand(message)

    data_uris: list[str] = []
    for msg in candidates:
        if msg.photo:
            meta = await base64_media(client, msg)
            data_uris.append(f"data:image/{meta['ext']};base64,{meta['base64']}")
    return data_uris[0] if len(data_uris) == 1 else data_uris
src/llm/models.py
@@ -6,7 +6,7 @@ import re
from openai import DefaultAsyncHttpxClient
from pyrogram.types import Message
-from config import GEMINI, GPT, PREFIX, PROXY, TID
+from config import GEMINI, GPT, PREFIX, PROXY, TEXT2IMG, TID
from llm.contexts import get_conversations
from llm.utils import BOT_TIPS, enabled_providers, sample_key
from messages.parser import parse_msg
@@ -150,8 +150,8 @@ def get_model_id_from_prefix(minfo: dict) -> tuple[str, bool, str]:
model_id = GPT.KIMI_MODEL
elif startswith_prefix(minfo["text"], prefix="/grok") and "grok" in text_providers:
model_id = GPT.GROK_MODEL
- elif startswith_prefix(minfo["text"], prefix=PREFIX.GENIMG) and "gemini" in img_providers:
- model_id = GEMINI.IMG_MODEL
+ elif startswith_prefix(minfo["text"], prefix=PREFIX.GENIMG):
+ model_id = TEXT2IMG.DEFAULT_MODEL
resp_modality = "image"
# start with /ai, auto detect model_id
elif startswith_prefix(minfo["text"], prefix="/ai") and text_providers:
src/llm/text2img.py
@@ -1,6 +1,5 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
-
from collections import defaultdict
from loguru import logger
@@ -10,6 +9,7 @@ from pyrogram.types import Message
from config import GEMINI, PREFIX, TEXT2IMG
from llm.ali.text2img import ali_text2img
from llm.cloudflare.text2img import cloudflare_text2img
+from llm.doubao.text2img import doubao_genimg
from llm.gemini.text2img import gemini_text2img
from llm.utils import enabled_providers
from utils import strings_list
@@ -18,11 +18,16 @@ TEXT2IMG_HELP = f"""🌠**AI生图**
`{PREFIX.GENIMG}` 后接提示词即可生成
⚙️模型配置:
-- `{PREFIX.GENIMG}`: 默认模型 (**{GEMINI.IMG_MODEL}**)
+- `{PREFIX.GENIMG}`: {TEXT2IMG.DEFAULT_MODEL}模型
+- `{PREFIX.GENIMG} @doubao`: Seedream模型
- `{PREFIX.GENIMG} @flux`: Flux模型
- `{PREFIX.GENIMG} @sd`: Stable Diffusion模型
+- `{PREFIX.GENIMG} @gemini`: Gemini模型
-对于Gemini模型可通过回复消息把历史图片加入上下文, 继续对话以重新修改生成结果
+上下文说明:
+- Gemini模型会把整个回复消息链上的所有消息加入上下文
+- 豆包模型仅会把当前消息及回复的最近一条消息加入上下文
+- 其余模型不会把历史消息加入上下文
"""
@@ -59,6 +64,8 @@ async def text2img(client: Client, message: Message, **kwargs) -> dict:
return await ali_text2img(client, message, model_id, prompt, **kwargs)
if provider == "cloudflare":
return await cloudflare_text2img(client, message, model_id, prompt, **kwargs)
+ if provider == "doubao":
+ return await doubao_genimg(client, message, model_id, prompt, **kwargs)
except Exception as e:
logger.error(e)
return {}
@@ -88,4 +95,6 @@ def enabled_models() -> dict[str, list]:
models["flux"].extend([f"cloudflare/{model}" for model in strings_list(TEXT2IMG.CF_FLUX_MODEL)])
if provider == "cloudflare" and TEXT2IMG.CF_STABLE_DIFFUSION_MODEL and "cloudflare" in strings_list(TEXT2IMG.STABLE_DIFFUSION_PROVIDER):
models["sd"].extend([f"cloudflare/{model}" for model in strings_list(TEXT2IMG.CF_STABLE_DIFFUSION_MODEL)])
+ if provider == "doubao" and TEXT2IMG.DOUBAO_SEEDREAM_MODEL:
+ models["doubao"].extend([f"doubao/{model}" for model in strings_list(TEXT2IMG.DOUBAO_SEEDREAM_MODEL)])
return models
src/llm/utils.py
@@ -51,6 +51,8 @@ def enabled_providers() -> tuple[list[str], list[str]]:
img_providers.append("ali")
if all([TEXT2IMG.CF_API_KEY]):
img_providers.append("cloudflare")
+ if all([TEXT2IMG.DOUBAO_API_KEY, TEXT2IMG.DOUBAO_SEEDREAM_MODEL]):
+ img_providers.append("doubao")
return text_providers, img_providers
src/config.py
@@ -494,3 +494,6 @@ class TEXT2IMG:
CF_FLUX_MODEL = os.getenv("TEXT2IMG_CF_FLUX_MODEL", "@cf/black-forest-labs/flux-1-schnell")
CF_STABLE_DIFFUSION_MODEL = os.getenv("TEXT2IMG_CF_STABLE_DIFFUSION_MODEL", "@cf/bytedance/stable-diffusion-xl-lightning")
CF_PROXY = os.getenv("TEXT2IMG_CF_PROXY", None)
+ DOUBAO_API_KEY = os.getenv("TEXT2IMG_DOUBAO_API_KEY", "") # comma separated keys
+ DOUBAO_SEEDREAM_MODEL = os.getenv("TEXT2IMG_DOUBAO_SEEDREAM_MODEL", "doubao-seedream-4-0-250828")
+ DOUBAO_PROXY = os.getenv("TEXT2IMG_DOUBAO_PROXY", None)