Commit 3973859

benny-dou <60535774+benny-dou@users.noreply.github.com>
2025-11-26 12:04:13
feat(doubao): add doubao seedream text-to-image generation
1 parent ae92406
Changed files (5)
src/llm/doubao/text2img.py
@@ -0,0 +1,77 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+import json
+from pathlib import Path
+from random import randint
+
+from glom import glom
+from loguru import logger
+from pyrogram.client import Client
+from pyrogram.types import Message
+
+from config import TEXT2IMG
+from llm.contexts import base64_media
+from messages.progress import modify_progress
+from messages.sender import send2tg
+from networking import download_file, hx_req
+from utils import strings_list
+
+
async def doubao_genimg(client: Client, message: Message, model_id: str, prompt: str, *, silent: bool = False, **kwargs) -> dict:
    """Doubao (Seedream) text-to-image generation.

    https://www.volcengine.com/docs/82379/1541523

    Args:
        client (Client): The Pyrogram client.
        message (Message): The trigger message object.
        model_id (str): Model identifier, e.g. "doubao/doubao-seedream-4-0-250828".
        prompt (str): The generation prompt; replies with a hint and aborts when empty.
        silent (bool, optional): Whether to disable progress messages. Defaults to False.

    Returns:
        dict: ``{}`` on success (or missing prompt); ``{"error": str}`` with the last
        API error message when every configured API key failed.
    """
    if not prompt:
        await message.reply(text="缺少提示词", quote=True)
        return {}
    model_name = model_id.split("/")[-1].title()
    if not silent and kwargs.get("show_progress"):
        kwargs["progress"] = (await send2tg(client, message, texts=f"🌠**{model_name}**:\n💬提示词: {prompt}", **kwargs))[0]
    error = ""
    succ = False
    # Fresh random seed per request so repeated prompts produce different images.
    config = {"model": model_id, "prompt": prompt, "size": "4K", "watermark": False, "seed": randint(0, 2147483647)}
    images = await get_ctx_images(client, message)
    payload = config | {"image": images} if images else config
    # Loop-invariant endpoint hoisted out of the key-rotation loop.
    api_url = "https://ark.cn-beijing.volces.com/api/v3/images/generations"
    for api_key in strings_list(TEXT2IMG.DOUBAO_API_KEY, shuffle=True):
        headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
        resp = await hx_req(api_url, "POST", json_data=payload, headers=headers, proxy=TEXT2IMG.DOUBAO_PROXY, check_keys=["data"])
        if url := glom(resp, "data.0.url", default=""):
            img_path = await download_file(url, proxy=TEXT2IMG.DOUBAO_PROXY)
            if Path(img_path).is_file():
                # Dump `config` (not `payload`) so base64 context images are not echoed into the caption.
                caption = f"[下载原图]({url}) (24h内有效)\n{json.dumps(config, ensure_ascii=False, indent=2)}"
                await send2tg(client, message, texts=caption, media=[{"photo": img_path}], **kwargs)
                succ = True
                break
        # NOTE(review): assumes hx_req nests the upstream error under "data" — confirm against hx_req.
        elif error_msg := glom(resp, "data.error.message", default=""):
            error = error_msg  # remember the latest failure so the caller can see it
            await modify_progress(text=f"❌{error_msg}", force_update=True, **kwargs)
            logger.error(error_msg)
            continue
    if succ:
        await modify_progress(del_status=True, **kwargs)
        return {}
    return {"error": error} if error else {}
+
+
async def get_ctx_images(client: Client, message: Message) -> str | list[str]:
    """Collect photos from *message* (and its replied-to message) as base64 data URIs.

    Media-group (album) messages are expanded to all of their members. Returns a
    single data-URI string when exactly one photo is found, otherwise a
    (possibly empty) list of data-URI strings.
    """

    async def expand(msg: Message) -> list[Message]:
        # An album message only carries one photo; fetch every member of its group.
        if msg.media_group_id:
            return await client.get_media_group(msg.chat.id, msg.id)
        return [msg]

    candidates: list[Message] = []
    if replied := message.reply_to_message:
        candidates += await expand(replied)
    candidates += await expand(message)

    data_uris: list[str] = []
    for candidate in candidates:
        if not candidate.photo:
            continue
        info = await base64_media(client, candidate)
        data_uris.append(f"data:image/{info['ext']};base64,{info['base64']}")
    return data_uris[0] if len(data_uris) == 1 else data_uris
src/llm/models.py
@@ -6,7 +6,7 @@ import re
 from openai import DefaultAsyncHttpxClient
 from pyrogram.types import Message
 
-from config import GEMINI, GPT, PREFIX, PROXY, TID
+from config import GEMINI, GPT, PREFIX, PROXY, TEXT2IMG, TID
 from llm.contexts import get_conversations
 from llm.utils import BOT_TIPS, enabled_providers, sample_key
 from messages.parser import parse_msg
@@ -150,8 +150,8 @@ def get_model_id_from_prefix(minfo: dict) -> tuple[str, bool, str]:
         model_id = GPT.KIMI_MODEL
     elif startswith_prefix(minfo["text"], prefix="/grok") and "grok" in text_providers:
         model_id = GPT.GROK_MODEL
-    elif startswith_prefix(minfo["text"], prefix=PREFIX.GENIMG) and "gemini" in img_providers:
-        model_id = GEMINI.IMG_MODEL
+    elif startswith_prefix(minfo["text"], prefix=PREFIX.GENIMG):
+        model_id = TEXT2IMG.DEFAULT_MODEL
         resp_modality = "image"
     # start with /ai, auto detect model_id
     elif startswith_prefix(minfo["text"], prefix="/ai") and text_providers:
src/llm/text2img.py
@@ -1,6 +1,5 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
-
 from collections import defaultdict
 
 from loguru import logger
@@ -10,6 +9,7 @@ from pyrogram.types import Message
 from config import GEMINI, PREFIX, TEXT2IMG
 from llm.ali.text2img import ali_text2img
 from llm.cloudflare.text2img import cloudflare_text2img
+from llm.doubao.text2img import doubao_genimg
 from llm.gemini.text2img import gemini_text2img
 from llm.utils import enabled_providers
 from utils import strings_list
@@ -18,11 +18,16 @@ TEXT2IMG_HELP = f"""🌠**AI生图**
 `{PREFIX.GENIMG}` 后接提示词即可生成
 
 ⚙️模型配置:
-- `{PREFIX.GENIMG}`: 默认模型 (**{GEMINI.IMG_MODEL}**)
+- `{PREFIX.GENIMG}`: {TEXT2IMG.DEFAULT_MODEL}模型
+- `{PREFIX.GENIMG} @doubao`: Seedream模型
 - `{PREFIX.GENIMG} @flux`: Flux模型
 - `{PREFIX.GENIMG} @sd`: Stable Diffusion模型
+- `{PREFIX.GENIMG} @gemini`: Gemini模型
 
-对于Gemini模型可通过回复消息把历史图片加入上下文, 继续对话以重新修改生成结果
+上下文说明:
+- Gemini模型会把整个回复消息链上的所有消息加入上下文
+- 豆包模型仅会把当前消息及回复的最近一条消息加入上下文
+- 其余模型不会把历史消息加入上下文
 """
 
 
@@ -59,6 +64,8 @@ async def text2img(client: Client, message: Message, **kwargs) -> dict:
                 return await ali_text2img(client, message, model_id, prompt, **kwargs)
             if provider == "cloudflare":
                 return await cloudflare_text2img(client, message, model_id, prompt, **kwargs)
+            if provider == "doubao":
+                return await doubao_genimg(client, message, model_id, prompt, **kwargs)
         except Exception as e:
             logger.error(e)
     return {}
@@ -88,4 +95,6 @@ def enabled_models() -> dict[str, list]:
             models["flux"].extend([f"cloudflare/{model}" for model in strings_list(TEXT2IMG.CF_FLUX_MODEL)])
         if provider == "cloudflare" and TEXT2IMG.CF_STABLE_DIFFUSION_MODEL and "cloudflare" in strings_list(TEXT2IMG.STABLE_DIFFUSION_PROVIDER):
             models["sd"].extend([f"cloudflare/{model}" for model in strings_list(TEXT2IMG.CF_STABLE_DIFFUSION_MODEL)])
+        if provider == "doubao" and TEXT2IMG.DOUBAO_SEEDREAM_MODEL:
+            models["doubao"].extend([f"doubao/{model}" for model in strings_list(TEXT2IMG.DOUBAO_SEEDREAM_MODEL)])
     return models
src/llm/utils.py
@@ -51,6 +51,8 @@ def enabled_providers() -> tuple[list[str], list[str]]:
         img_providers.append("ali")
     if all([TEXT2IMG.CF_API_KEY]):
         img_providers.append("cloudflare")
+    if all([TEXT2IMG.DOUBAO_API_KEY, TEXT2IMG.DOUBAO_SEEDREAM_MODEL]):
+        img_providers.append("doubao")
     return text_providers, img_providers
 
 
src/config.py
@@ -494,3 +494,6 @@ class TEXT2IMG:
     CF_FLUX_MODEL = os.getenv("TEXT2IMG_CF_FLUX_MODEL", "@cf/black-forest-labs/flux-1-schnell")
     CF_STABLE_DIFFUSION_MODEL = os.getenv("TEXT2IMG_CF_STABLE_DIFFUSION_MODEL", "@cf/bytedance/stable-diffusion-xl-lightning")
     CF_PROXY = os.getenv("TEXT2IMG_CF_PROXY", None)
+    DOUBAO_API_KEY = os.getenv("TEXT2IMG_DOUBAO_API_KEY", "")  # comma separated keys
+    DOUBAO_SEEDREAM_MODEL = os.getenv("TEXT2IMG_DOUBAO_SEEDREAM_MODEL", "doubao-seedream-4-0-250828")
+    DOUBAO_PROXY = os.getenv("TEXT2IMG_DOUBAO_PROXY", None)