Commit 9a02896

benny-dou <60535774+benny-dou@users.noreply.github.com>
2026-02-12 07:32:07
feat(ai): support specifying aspect ratio in image generation prompts
1 parent 4d440ce
src/ai/images/gemini.py
@@ -1,6 +1,5 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
-import re
 from pathlib import Path
 from typing import Any
 
@@ -12,8 +11,9 @@ from loguru import logger
 from pyrogram.client import Client
 from pyrogram.types import Message
 
+from ai.images.utils import extract_aspect_ratio
 from ai.texts.contexts import base64_media
-from ai.utils import EMOJI_IMG_BOT, clean_cmd_prefix, literal_eval
+from ai.utils import clean_cmd_prefix, literal_eval
 from config import AI, DOWNLOAD_DIR, PROXY
 from messages.modify import message_modify
 from messages.progress import modify_progress
@@ -33,11 +33,11 @@ async def gemini_image_generation(
     gemini_default_headers: str | dict = AI.GEMINI_DEFAULT_HEADERS,
     gemini_generate_content_config: str | dict = "",
     gemini_proxy: str | None = PROXY.GOOGLE,
-    support_reference_images: int = 0,
+    max_reference_img: int = 0,
     **kwargs,
 ) -> bool:
     """Get Gemini Image Generation."""
-    status_msg = kwargs.get("progress") or await message.reply(f"{EMOJI_IMG_BOT}**{model_name}**:\n正在生成图像...", quote=True)
+    status_msg = kwargs.get("progress") or await message.reply(f"🍌**{model_name}**:\n正在生成图像...", quote=True)
 
     for api_key in strings_list(gemini_api_keys, shuffle=True):
         try:
@@ -45,14 +45,15 @@ async def gemini_image_generation(
             gemini = genai.Client(api_key=api_key, http_options=http_options)
             params: dict[str, Any] = {
                 "model": model_id,
-                "contents": await get_gemini_contexts(client, message, model_name, support_reference_images=support_reference_images),
+                "contents": await get_gemini_contexts(client, message, model_name, max_reference_img=max_reference_img),
                 "config": {"response_modalities": ["TEXT", "IMAGE"]},
             }
             if conf := literal_eval(gemini_generate_content_config):
                 params["config"] |= conf
             image_config = glom(params, "config.image_config", default={})
             if not image_config.get("aspect_ratio"):
-                image_config["aspect_ratio"] = infer_aspect_ratio(message.content)
+                aspect_ratio, _ = extract_aspect_ratio(message.content)
+                image_config["aspect_ratio"] = aspect_ratio or "16:9"
                 params["config"]["image_config"] = image_config
             logger.debug(f"genai.Client().models.generate_content_stream(**{params})")
             response = await gemini.aio.models.generate_content(**params)
@@ -60,7 +61,7 @@ async def gemini_image_generation(
             texts = "".join([p.get("text") or "" for p in parts if not p.get("thought")])
             media = await download_generated_images(parts)
             if media:
-                await send2tg(client, message, texts=f"{EMOJI_IMG_BOT}**{model_name}**:\n{texts}", media=media, caption_above=True, **kwargs)
+                await send2tg(client, message, texts=f"🍌**{model_name}**:\n{texts}", media=media, caption_above=True, **kwargs)
                 await delete_message(status_msg)
                 return True
         except Exception as e:
@@ -69,7 +70,7 @@ async def gemini_image_generation(
     return False
 
 
-async def get_gemini_contexts(client: Client, message: Message, model_name: str, *, support_reference_images: int = 0) -> list[dict]:
+async def get_gemini_contexts(client: Client, message: Message, model_name: str, *, max_reference_img: int = 0) -> list[dict]:
     """Generate Gemini image generation contexts.
 
     Returns:
@@ -80,10 +81,13 @@ async def get_gemini_contexts(client: Client, message: Message, model_name: str,
     """
 
     def clean(text: str) -> str:
-        text = clean_cmd_prefix(text)
-        return text.removeprefix(f"{EMOJI_IMG_BOT}{model_name}:").lstrip()
+        if not text:
+            return ""
+        text = clean_cmd_prefix(str(text))
+        _, text = extract_aspect_ratio(text)
+        return text.removeprefix(f"🍌{model_name}:").lstrip()
 
-    if not support_reference_images:
+    if not max_reference_img:
         return [{"text": clean(message.content)}]
     messages = [message]
     while message.reply_to_message:
@@ -94,12 +98,12 @@ async def get_gemini_contexts(client: Client, message: Message, model_name: str,
     num_img = 0
     for m in messages:  # new to old
         group_messages = await client.get_media_group(m.chat.id, m.id) if m.media_group_id else [m]
-        role = "model" if any(startswith_prefix(msg.content, f"{EMOJI_IMG_BOT}{model_name}:") for msg in group_messages) else "user"
+        role = "model" if any(startswith_prefix(msg.content, f"🍌{model_name}:") for msg in group_messages) else "user"
         parts = []
         for msg in group_messages[::-1]:  # new to old
             if prompt := clean(msg.content):
                 parts.append({"text": prompt})
-            if msg.photo and num_img < support_reference_images:
+            if msg.photo and num_img < max_reference_img:
                 res = await base64_media(client, msg)
                 parts.append({"inline_data": {"mime_type": f"image/{res['ext']}", "data": res["base64"]}})
                 num_img += 1
@@ -135,36 +139,3 @@ async def download_generated_images(parts: list[dict]) -> list[dict]:
             await f.write(data)
         images.append({"photo": save_path.as_posix()})
     return images
-
-
-def infer_aspect_ratio(text: str) -> str:
-    """Infer aspect ratio from text.
-
-    Args:
-        text (str): Text.
-
-    Returns:
-        str: Aspect ratio. (width:height)
-    """
-    r"""
-    (?i): 表示不区分大小写匹配。这意味着 Aspect_Ratio 也能被匹配到。
-    (aspect_ratio|aspect ratio|长宽比) : 匹配 "aspect_ratio", "aspect ratio" 或 "长宽比"
-    \s*                   : 匹配零个或多个空格
-    [=::]                  : 匹配等号 `=` 或冒号 `:` 或全角冒号 `:`
-    \s*                   : 匹配零个或多个空格
-    "?                    : 匹配一个可选的双引号 (0 或 1 次)
-    (\d+:\d+)             : 捕获组, 匹配一个或多个数字, 接着一个冒号, 再接着一个或多个数字 (例如 "5:4", "16:9")
-    "?                    : 匹配一个可选的双引号 (0 或 1 次)
-    """  # noqa: RUF001
-    pattern = r"(?i)(aspect_ratio|aspect ratio|长宽比)(.*?)[=::]\s*\"?(\d+:\d+)\"?"  # noqa: RUF001
-    if match := re.search(pattern, text):
-        return match.group(3)
-
-    text = text.lower()
-    if "portrait" in text:
-        return "9:16"
-    if "landscape" in text:
-        return "16:9"
-    if "square" in text:
-        return "1:1"
-    return "9:16"  # default
src/ai/images/models.py
@@ -24,7 +24,7 @@ async def get_image_model_configs(message: Message) -> list[dict]:
                 "model_name": "Seedream-4.5",
                 "api_type": "openai",
                 "api_keys": "key1,key2,key3,...",
-                "support_reference_images": true,
+                "max_reference_img": 14,
                 "client_config": { "base_url": "https://ark.cn-beijing.volces.com/api/v3" },
                 "generate_config": {
                     "size": "4K",
src/ai/images/openai_img.py
@@ -3,6 +3,7 @@
 import asyncio
 import base64
 from pathlib import Path
+from typing import Literal
 
 import anyio
 from glom import glom
@@ -11,8 +12,9 @@ from openai import AsyncOpenAI, DefaultAsyncHttpxClient
 from pyrogram.client import Client
 from pyrogram.types import Message
 
-from ai.texts.contexts import base64_media
-from ai.utils import EMOJI_IMG_BOT, clean_cmd_prefix, literal_eval, prettify, trim_none
+from ai.images.post import get_image_contexts
+from ai.images.utils import aspect_ratio_to_size, extract_aspect_ratio
+from ai.utils import EMOJI_IMG_BOT, literal_eval, prettify, trim_none
 from config import DOWNLOAD_DIR, PROXY
 from messages.progress import modify_progress
 from messages.sender import send2tg
@@ -31,7 +33,11 @@ async def openai_image_generation(
     client_config: str | dict = "",
     generate_config: str | dict = "",
     proxy: str | None = PROXY.OPENAI,
-    support_reference_images: bool = False,
+    max_reference_img: int = 0,
+    resolution: Literal["1K", "2K", "4K"] = "1K",
+    max_width: int = int(1e16),
+    max_height: int = int(1e16),
+    max_size: int = int(1e32),
     **kwargs,
 ) -> bool:
     """Get OpenAI Image Generation."""
@@ -42,13 +48,17 @@ async def openai_image_generation(
             openai_client |= literal_eval(client_config)
         if proxy:
             openai_client |= {"http_client": DefaultAsyncHttpxClient(proxy=proxy)}
-        prompt, reference_images = await get_openai_image_contexts(client, message, support_reference_images=support_reference_images)
+        prompt, reference_images = await get_image_contexts(client, message, max_reference_img=max_reference_img)
         if not prompt:
             await modify_progress(status_msg, text=f"❌**{model_name}**:\n请提供提示词", force_update=True, **kwargs)
             return False
         params = {}
         if literal_eval(generate_config):
             params |= literal_eval(generate_config)
+        aspect_ratio, _ = extract_aspect_ratio(message.content)
+        if aspect_ratio:
+            width, height = aspect_ratio_to_size(aspect_ratio, resolution, max_width, max_height, max_size)
+            params |= {"size": f"{width}x{height}"}
         if reference_images:
             params["extra_body"] = params.get("extra_body", {}) | {"image": reference_images}
         params |= {"model": model_id, "prompt": prompt}
@@ -83,32 +93,6 @@ async def openai_image_generation(
     return False
 
 
-async def get_openai_image_contexts(client: Client, message: Message, *, support_reference_images: bool = False) -> tuple[str, list[str]]:
-    """Generate OpenAI image generation contexts.
-
-    Returns:
-        tuple: prompt, list_of_images
-    """
-    if not support_reference_images:
-        return clean_cmd_prefix(message.content), []
-    messages = [message]
-    while message.reply_to_message:
-        message = message.reply_to_message
-        messages.append(message)
-    messages.reverse()  # old to new
-    images = []
-    prompt = ""
-    for m in messages:
-        group_messages = await client.get_media_group(m.chat.id, m.id) if m.media_group_id else [m]
-        for msg in group_messages:
-            prompt = clean_cmd_prefix(msg.content)
-            if not msg.photo:
-                continue
-            res = await base64_media(client, msg)
-            images.append(f"data:image/{res['ext']};base64,{res['base64']}")
-    return prompt, images
-
-
 async def download_generated_images(response: dict, proxy: str | None = None) -> list[dict]:
     """Download generated images.
 
src/ai/images/post.py
@@ -3,6 +3,7 @@
 import asyncio
 import base64
 from pathlib import Path
+from typing import Literal
 
 import anyio
 from glom import glom
@@ -10,9 +11,11 @@ from loguru import logger
 from pyrogram.client import Client
 from pyrogram.types import Message
 
+from ai.images.utils import aspect_ratio_to_size, extract_aspect_ratio
 from ai.texts.contexts import base64_media
 from ai.utils import EMOJI_IMG_BOT, clean_cmd_prefix, prettify, replace_placeholder
 from config import DOWNLOAD_DIR, PROXY
+from messages.modify import message_modify
 from messages.progress import modify_progress
 from messages.sender import send2tg
 from messages.utils import delete_message
@@ -31,13 +34,17 @@ async def http_post_image_generation(
     body: dict | None = None,
     extra_params: dict | None = None,
     proxy: str | None = PROXY.AI_POST,
-    support_reference_images: bool = False,
+    max_reference_img: int = 0,
+    resolution: Literal["1K", "2K", "4K"] = "1K",
+    max_width: int = int(1e16),
+    max_height: int = int(1e16),
+    max_size: int = int(1e32),
     **kwargs,
 ) -> bool:
     """Get HTTP Post Image Generation."""
     status_msg = kwargs.get("progress") or await message.reply(f"{EMOJI_IMG_BOT}**{model_name}**:\n正在生成图像...", quote=True)
     try:
-        prompt, reference_images = await get_image_contexts(client, message, support_reference_images=support_reference_images)
+        prompt, reference_images = await get_image_contexts(client, message, max_reference_img=max_reference_img)
         if not prompt:
             await modify_progress(status_msg, text=f"❌**{model_name}**:\n请提供提示词", force_update=True, **kwargs)
             return False
@@ -50,6 +57,10 @@ async def http_post_image_generation(
         url = base_url + api_paths.get("img_gen", "") if not reference_images else base_url + api_paths.get("img_edit", "")
         params |= {"url": url, "method": "POST"}
         if body:
+            aspect_ratio, _ = extract_aspect_ratio(message.content)
+            if aspect_ratio:
+                width, height = aspect_ratio_to_size(aspect_ratio, resolution, max_width, max_height, max_size)
+                body |= {"size": f"{width}x{height}"}
             params |= {"json_data": replace_placeholder(body, pairs={"%PROMPT%": prompt})}
         if extra_params:
             params |= extra_params
@@ -96,26 +107,35 @@ def extract_metadata(response: dict) -> str:
     return ""
 
 
-async def get_image_contexts(client: Client, message: Message, *, support_reference_images: bool = False) -> tuple[str, list[str]]:
+async def get_image_contexts(client: Client, message: Message, *, max_reference_img: int = 0) -> tuple[str, list[str]]:
     """Get image generation contexts.
 
     Returns:
         tuple: prompt, list_of_images
     """
-    if not support_reference_images:
-        return clean_cmd_prefix(message.content), []
+
+    def clean(text: str) -> str:
+        if not text:
+            return ""
+        text = clean_cmd_prefix(str(text))
+        _, text = extract_aspect_ratio(text)
+        return text.strip()
+
+    if not max_reference_img:
+        return clean(message.content), []
     messages = [message]
     while message.reply_to_message:
         message = message.reply_to_message
-        messages.append(message)
+        if not message.service:  # ignore service messages
+            messages.append(message_modify(message))
     messages.reverse()  # old to new
     images = []
     prompt = ""
     for m in messages:
         group_messages = await client.get_media_group(m.chat.id, m.id) if m.media_group_id else [m]
         for msg in group_messages:
-            prompt = clean_cmd_prefix(msg.content)
-            if not msg.photo:
+            prompt = clean(msg.content)
+            if not msg.photo or len(images) >= max_reference_img:
                 continue
             res = await base64_media(client, msg)
             images.append(f"data:image/{res['ext']};base64,{res['base64']}")
src/ai/images/utils.py
@@ -0,0 +1,165 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+import math
+import re
+from typing import Literal
+
+from ai.utils import clean_cmd_prefix
+
+
def extract_aspect_ratio(text: str) -> tuple[str, str]:
    """Extract an aspect ratio directive from prompt text.

    If the prompt starts with "width:height", that prefix is used as the
    aspect ratio and removed from the returned prompt. Otherwise the text
    is scanned for an explicit "aspect_ratio"-style directive, then for
    the orientation keywords "portrait", "landscape" and "square".

    Match (examples):
    aspect_ratio = 9:16
    "aspect_ratio" : "16:9"
    aspect_ratio 9:16
    portrait
    landscape
    square
    16:9 prompt...

    Args:
        text (str): Prompt text (possibly carrying a command prefix).

    Returns:
        tuple[str, str]: 1. Aspect ratio, normalized to "width:height"
                            ("" when no ratio was found).
                         2. Prompt. (Lowercased when matched via an
                            orientation keyword, mirroring prior behavior.)
    """
    # ruff: noqa: RUF001
    text = clean_cmd_prefix(text)
    # Text starts with "width:height" (spaces and fullwidth colon tolerated).
    # Capture the digits separately and rebuild as "W:H" so downstream size
    # lookups (e.g. aspect_ratio_to_size) and the image APIs always receive
    # the canonical form, even for inputs like "16 : 9" or "16:9".
    if match := re.match(r"(\d+)\s*[::]\s*(\d+)", text):
        return f"{match.group(1)}:{match.group(2)}", text[match.end() :].lstrip()

    r"""
    (?i)                                 : case-insensitive match (so Aspect_Ratio also matches)
    (aspect_ratio|aspect ratio|ar|宽高比) : the directive keyword
    [\s=::\"]*                           : separators - spaces, '=', ':', fullwidth ':' or '"' (0 or more)
    (\d+:\d+)                            : capture group - digits, a colon, digits (e.g. "5:4", "16:9")
    """
    pattern = r"(?i)(aspect_ratio|aspect ratio|ar|宽高比)[\s=::\"]*(\d+:\d+)"
    if match := re.search(pattern, text):
        return match.group(2), text  # directive is intentionally left in the prompt

    text = text.lower()
    if "portrait" in text:
        return "9:16", text
    if "landscape" in text:
        return "16:9", text
    if "square" in text:
        return "1:1", text
    return "", text  # no ratio found
+
+
# Preset (width, height) per resolution tier and aspect ratio.
# Values are identical to the previous hand-written match/case arms.
_SIZE_PRESETS: dict[str, dict[str, tuple[int, int]]] = {
    "1K": {
        "1:1": (1024, 1024),
        "2:3": (832, 1248),
        "3:2": (1248, 832),
        "3:4": (864, 1152),
        "4:3": (1152, 864),
        "4:5": (928, 1152),
        "5:4": (1152, 928),
        "9:16": (720, 1280),
        "16:9": (1280, 720),
        "21:9": (1512, 648),
    },
    "2K": {
        "1:1": (2048, 2048),
        "2:3": (1664, 2496),
        "3:2": (2496, 1664),
        "3:4": (1728, 2304),
        "4:3": (2304, 1728),
        "4:5": (1856, 2304),
        "5:4": (2304, 1856),
        "9:16": (1440, 2560),
        "16:9": (2560, 1440),
        "21:9": (3024, 1296),
    },
    "4K": {
        "1:1": (4096, 4096),
        "2:3": (3328, 4992),
        "3:2": (4992, 3328),
        "3:4": (3456, 4608),
        "4:3": (4608, 3456),
        "4:5": (3648, 4560),
        "5:4": (4560, 3648),
        "9:16": (2880, 5120),
        "16:9": (5120, 2880),
        "21:9": (6048, 2592),
    },
}


def aspect_ratio_to_size(
    aspect_ratio: str,
    resolution: Literal["1K", "2K", "4K"] = "1K",
    max_width: int = int(1e16),
    max_height: int = int(1e16),
    max_size: int = int(1e32),
) -> tuple[int, int]:
    """Convert aspect ratio to image size (width, height).

    Looks up the preset size for the given resolution tier and aspect
    ratio, then clamps the result via ``adjust_size``.

    Args:
        aspect_ratio (str): Aspect ratio in "width:height" form (e.g. "16:9").
        resolution (Literal["1K", "2K", "4K"], optional): Resolution tier.
            Unknown tiers fall back to 1024x1024. Defaults to "1K".
        max_width (int, optional): Max width. Defaults to int(1E16).
        max_height (int, optional): Max height. Defaults to int(1E16).
        max_size (int, optional): Max size (width * height). Defaults to int(1E32).

    Returns:
        tuple[int, int]: Image size (width, height).
    """
    tier = _SIZE_PRESETS.get(resolution.upper())
    if tier is None:
        # Unrecognized resolution tier: keep the historical 1024x1024 fallback.
        width, height = 1024, 1024
    else:
        # Unknown aspect ratios fall back to the tier's square preset.
        width, height = tier.get(aspect_ratio, tier["1:1"])
    return adjust_size(width, height, max_width, max_height, max_size)
+
+
def adjust_size(width: int, height: int, max_width: int = int(1e16), max_height: int = int(1e16), max_size: int = int(1e32)) -> tuple[int, int]:
    """Adjust image size to fit within max_width, max_height, max_size.

    Args:
        width (int): Image width (values < 1 are clamped to 1).
        height (int): Image height (values < 1 are clamped to 1).
        max_width (int, optional): Max width. Defaults to int(1E16).
        max_height (int, optional): Max height. Defaults to int(1E16).
        max_size (int, optional): Max size (width * height). Defaults to int(1E32).

    Returns:
        tuple[int, int]: Adjusted width, height (each always >= 1).
    """
    # Guard against zero/negative dimensions, which would otherwise raise
    # ZeroDivisionError in the scaling steps below.
    width = max(width, 1)
    height = max(height, 1)

    # 1. Scale down to fit within max_size. max_size caps the AREA, so the
    #    linear scale factor is the square root of the area ratio (scaling
    #    both dimensions by s shrinks the area by s**2). Using the raw area
    #    ratio here would over-shrink the image quadratically.
    scale = min(math.sqrt(max_size / (width * height)), 1.0)
    width = max(math.floor(width * scale), 1)
    height = max(math.floor(height * scale), 1)

    # 2. Scale down to fit within max_width / max_height, preserving ratio.
    scale = min(max_width / width, max_height / height, 1.0)
    width = max(math.floor(width * scale), 1)
    height = max(math.floor(height * scale), 1)

    return width, height
src/ai/utils.py
@@ -15,12 +15,13 @@ from config import AI, PREFIX, PROXY
 from database.kv import get_cf_kv
 from utils import nowdt, remove_consecutive_newlines, remove_dash, remove_pound, strings_list, zhcn
 
+# ruff: noqa: RUF001
 EMOJI_TEXT_BOT = "🤖"
 EMOJI_IMG_BOT = "🌠"
 EMOJI_VIDEO_BOT = "📽"
 EMOJI_REASONING_BEGIN = "🤔"  # use emoji to separate model reasoning and content
 EMOJI_REASONING_END = "💡"
-BOT_TIPS = "(回复以继续)"  # noqa: RUF001
+BOT_TIPS = "(回复以继续)"
 
 
 async def text_generation_docs() -> str:
@@ -53,7 +54,7 @@ def trim_none(obj: dict) -> dict:
     if isinstance(obj, dict):
         return {k: trim_none(v) for k, v in obj.items() if v is not None}
     if isinstance(obj, list):
-        return [trim_none(item) for item in obj if item is not None]
+        return [trim_none(item) for item in obj if item is not None]  # ty:ignore[invalid-return-type]
     return obj