Commit 9a02896
src/ai/images/gemini.py
@@ -1,6 +1,5 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
-import re
from pathlib import Path
from typing import Any
@@ -12,8 +11,9 @@ from loguru import logger
from pyrogram.client import Client
from pyrogram.types import Message
+from ai.images.utils import extract_aspect_ratio
from ai.texts.contexts import base64_media
-from ai.utils import EMOJI_IMG_BOT, clean_cmd_prefix, literal_eval
+from ai.utils import clean_cmd_prefix, literal_eval
from config import AI, DOWNLOAD_DIR, PROXY
from messages.modify import message_modify
from messages.progress import modify_progress
@@ -33,11 +33,11 @@ async def gemini_image_generation(
gemini_default_headers: str | dict = AI.GEMINI_DEFAULT_HEADERS,
gemini_generate_content_config: str | dict = "",
gemini_proxy: str | None = PROXY.GOOGLE,
- support_reference_images: int = 0,
+ max_reference_img: int = 0,
**kwargs,
) -> bool:
"""Get Gemini Image Generation."""
- status_msg = kwargs.get("progress") or await message.reply(f"{EMOJI_IMG_BOT}**{model_name}**:\n正在生成图像...", quote=True)
+ status_msg = kwargs.get("progress") or await message.reply(f"🍌**{model_name}**:\n正在生成图像...", quote=True)
for api_key in strings_list(gemini_api_keys, shuffle=True):
try:
@@ -45,14 +45,15 @@ async def gemini_image_generation(
gemini = genai.Client(api_key=api_key, http_options=http_options)
params: dict[str, Any] = {
"model": model_id,
- "contents": await get_gemini_contexts(client, message, model_name, support_reference_images=support_reference_images),
+ "contents": await get_gemini_contexts(client, message, model_name, max_reference_img=max_reference_img),
"config": {"response_modalities": ["TEXT", "IMAGE"]},
}
if conf := literal_eval(gemini_generate_content_config):
params["config"] |= conf
image_config = glom(params, "config.image_config", default={})
if not image_config.get("aspect_ratio"):
- image_config["aspect_ratio"] = infer_aspect_ratio(message.content)
+ aspect_ratio, _ = extract_aspect_ratio(message.content)
+ image_config["aspect_ratio"] = aspect_ratio or "16:9"
params["config"]["image_config"] = image_config
logger.debug(f"genai.Client().models.generate_content_stream(**{params})")
response = await gemini.aio.models.generate_content(**params)
@@ -60,7 +61,7 @@ async def gemini_image_generation(
texts = "".join([p.get("text") or "" for p in parts if not p.get("thought")])
media = await download_generated_images(parts)
if media:
- await send2tg(client, message, texts=f"{EMOJI_IMG_BOT}**{model_name}**:\n{texts}", media=media, caption_above=True, **kwargs)
+ await send2tg(client, message, texts=f"🍌**{model_name}**:\n{texts}", media=media, caption_above=True, **kwargs)
await delete_message(status_msg)
return True
except Exception as e:
@@ -69,7 +70,7 @@ async def gemini_image_generation(
return False
-async def get_gemini_contexts(client: Client, message: Message, model_name: str, *, support_reference_images: int = 0) -> list[dict]:
+async def get_gemini_contexts(client: Client, message: Message, model_name: str, *, max_reference_img: int = 0) -> list[dict]:
"""Generate Gemini image generation contexts.
Returns:
@@ -80,10 +81,13 @@ async def get_gemini_contexts(client: Client, message: Message, model_name: str,
"""
def clean(text: str) -> str:
- text = clean_cmd_prefix(text)
- return text.removeprefix(f"{EMOJI_IMG_BOT}{model_name}:").lstrip()
+ if not text:
+ return ""
+ text = clean_cmd_prefix(str(text))
+ _, text = extract_aspect_ratio(text)
+ return text.removeprefix(f"🍌{model_name}:").lstrip()
- if not support_reference_images:
+ if not max_reference_img:
return [{"text": clean(message.content)}]
messages = [message]
while message.reply_to_message:
@@ -94,12 +98,12 @@ async def get_gemini_contexts(client: Client, message: Message, model_name: str,
num_img = 0
for m in messages: # new to old
group_messages = await client.get_media_group(m.chat.id, m.id) if m.media_group_id else [m]
- role = "model" if any(startswith_prefix(msg.content, f"{EMOJI_IMG_BOT}{model_name}:") for msg in group_messages) else "user"
+ role = "model" if any(startswith_prefix(msg.content, f"🍌{model_name}:") for msg in group_messages) else "user"
parts = []
for msg in group_messages[::-1]: # new to old
if prompt := clean(msg.content):
parts.append({"text": prompt})
- if msg.photo and num_img < support_reference_images:
+ if msg.photo and num_img < max_reference_img:
res = await base64_media(client, msg)
parts.append({"inline_data": {"mime_type": f"image/{res['ext']}", "data": res["base64"]}})
num_img += 1
@@ -135,36 +139,3 @@ async def download_generated_images(parts: list[dict]) -> list[dict]:
await f.write(data)
images.append({"photo": save_path.as_posix()})
return images
-
-
-def infer_aspect_ratio(text: str) -> str:
- """Infer aspect ratio from text.
-
- Args:
- text (str): Text.
-
- Returns:
- str: Aspect ratio. (width:height)
- """
- r"""
- (?i): 表示不区分大小写匹配。这意味着 Aspect_Ratio 也能被匹配到。
- (aspect_ratio|aspect ratio|长宽比) : 匹配 "aspect_ratio", "aspect ratio" 或 "长宽比"
- \s* : 匹配零个或多个空格
- [=::] : 匹配等号 `=` 或冒号 `:` 或全角冒号 `:`
- \s* : 匹配零个或多个空格
- "? : 匹配一个可选的双引号 (0 或 1 次)
- (\d+:\d+) : 捕获组, 匹配一个或多个数字, 接着一个冒号, 再接着一个或多个数字 (例如 "5:4", "16:9")
- "? : 匹配一个可选的双引号 (0 或 1 次)
- """ # noqa: RUF001
- pattern = r"(?i)(aspect_ratio|aspect ratio|长宽比)(.*?)[=::]\s*\"?(\d+:\d+)\"?" # noqa: RUF001
- if match := re.search(pattern, text):
- return match.group(3)
-
- text = text.lower()
- if "portrait" in text:
- return "9:16"
- if "landscape" in text:
- return "16:9"
- if "square" in text:
- return "1:1"
- return "9:16" # default
src/ai/images/models.py
@@ -24,7 +24,7 @@ async def get_image_model_configs(message: Message) -> list[dict]:
"model_name": "Seedream-4.5",
"api_type": "openai",
"api_keys": "key1,key2,key3,...",
- "support_reference_images": true,
+ "max_reference_img": 14,
"client_config": { "base_url": "https://ark.cn-beijing.volces.com/api/v3" },
"generate_config": {
"size": "4K",
src/ai/images/openai_img.py
@@ -3,6 +3,7 @@
import asyncio
import base64
from pathlib import Path
+from typing import Literal
import anyio
from glom import glom
@@ -11,8 +12,9 @@ from openai import AsyncOpenAI, DefaultAsyncHttpxClient
from pyrogram.client import Client
from pyrogram.types import Message
-from ai.texts.contexts import base64_media
-from ai.utils import EMOJI_IMG_BOT, clean_cmd_prefix, literal_eval, prettify, trim_none
+from ai.images.post import get_image_contexts
+from ai.images.utils import aspect_ratio_to_size, extract_aspect_ratio
+from ai.utils import EMOJI_IMG_BOT, literal_eval, prettify, trim_none
from config import DOWNLOAD_DIR, PROXY
from messages.progress import modify_progress
from messages.sender import send2tg
@@ -31,7 +33,11 @@ async def openai_image_generation(
client_config: str | dict = "",
generate_config: str | dict = "",
proxy: str | None = PROXY.OPENAI,
- support_reference_images: bool = False,
+ max_reference_img: int = 0,
+ resolution: Literal["1K", "2K", "4K"] = "1K",
+ max_width: int = int(1e16),
+ max_height: int = int(1e16),
+ max_size: int = int(1e32),
**kwargs,
) -> bool:
"""Get OpenAI Image Generation."""
@@ -42,13 +48,17 @@ async def openai_image_generation(
openai_client |= literal_eval(client_config)
if proxy:
openai_client |= {"http_client": DefaultAsyncHttpxClient(proxy=proxy)}
- prompt, reference_images = await get_openai_image_contexts(client, message, support_reference_images=support_reference_images)
+ prompt, reference_images = await get_image_contexts(client, message, max_reference_img=max_reference_img)
if not prompt:
await modify_progress(status_msg, text=f"❌**{model_name}**:\n请提供提示词", force_update=True, **kwargs)
return False
params = {}
if literal_eval(generate_config):
params |= literal_eval(generate_config)
+ aspect_ratio, _ = extract_aspect_ratio(message.content)
+ if aspect_ratio:
+ width, height = aspect_ratio_to_size(aspect_ratio, resolution, max_width, max_height, max_size)
+ params |= {"size": f"{width}x{height}"}
if reference_images:
params["extra_body"] = params.get("extra_body", {}) | {"image": reference_images}
params |= {"model": model_id, "prompt": prompt}
@@ -83,32 +93,6 @@ async def openai_image_generation(
return False
-async def get_openai_image_contexts(client: Client, message: Message, *, support_reference_images: bool = False) -> tuple[str, list[str]]:
- """Generate OpenAI image generation contexts.
-
- Returns:
- tuple: prompt, list_of_images
- """
- if not support_reference_images:
- return clean_cmd_prefix(message.content), []
- messages = [message]
- while message.reply_to_message:
- message = message.reply_to_message
- messages.append(message)
- messages.reverse() # old to new
- images = []
- prompt = ""
- for m in messages:
- group_messages = await client.get_media_group(m.chat.id, m.id) if m.media_group_id else [m]
- for msg in group_messages:
- prompt = clean_cmd_prefix(msg.content)
- if not msg.photo:
- continue
- res = await base64_media(client, msg)
- images.append(f"data:image/{res['ext']};base64,{res['base64']}")
- return prompt, images
-
-
async def download_generated_images(response: dict, proxy: str | None = None) -> list[dict]:
"""Download generated images.
src/ai/images/post.py
@@ -3,6 +3,7 @@
import asyncio
import base64
from pathlib import Path
+from typing import Literal
import anyio
from glom import glom
@@ -10,9 +11,11 @@ from loguru import logger
from pyrogram.client import Client
from pyrogram.types import Message
+from ai.images.utils import aspect_ratio_to_size, extract_aspect_ratio
from ai.texts.contexts import base64_media
from ai.utils import EMOJI_IMG_BOT, clean_cmd_prefix, prettify, replace_placeholder
from config import DOWNLOAD_DIR, PROXY
+from messages.modify import message_modify
from messages.progress import modify_progress
from messages.sender import send2tg
from messages.utils import delete_message
@@ -31,13 +34,17 @@ async def http_post_image_generation(
body: dict | None = None,
extra_params: dict | None = None,
proxy: str | None = PROXY.AI_POST,
- support_reference_images: bool = False,
+ max_reference_img: int = 0,
+ resolution: Literal["1K", "2K", "4K"] = "1K",
+ max_width: int = int(1e16),
+ max_height: int = int(1e16),
+ max_size: int = int(1e32),
**kwargs,
) -> bool:
"""Get HTTP Post Image Generation."""
status_msg = kwargs.get("progress") or await message.reply(f"{EMOJI_IMG_BOT}**{model_name}**:\n正在生成图像...", quote=True)
try:
- prompt, reference_images = await get_image_contexts(client, message, support_reference_images=support_reference_images)
+ prompt, reference_images = await get_image_contexts(client, message, max_reference_img=max_reference_img)
if not prompt:
await modify_progress(status_msg, text=f"❌**{model_name}**:\n请提供提示词", force_update=True, **kwargs)
return False
@@ -50,6 +57,10 @@ async def http_post_image_generation(
url = base_url + api_paths.get("img_gen", "") if not reference_images else base_url + api_paths.get("img_edit", "")
params |= {"url": url, "method": "POST"}
if body:
+ aspect_ratio, _ = extract_aspect_ratio(message.content)
+ if aspect_ratio:
+ width, height = aspect_ratio_to_size(aspect_ratio, resolution, max_width, max_height, max_size)
+ body |= {"size": f"{width}x{height}"}
params |= {"json_data": replace_placeholder(body, pairs={"%PROMPT%": prompt})}
if extra_params:
params |= extra_params
@@ -96,26 +107,35 @@ def extract_metadata(response: dict) -> str:
return ""
-async def get_image_contexts(client: Client, message: Message, *, support_reference_images: bool = False) -> tuple[str, list[str]]:
+async def get_image_contexts(client: Client, message: Message, *, max_reference_img: int = 0) -> tuple[str, list[str]]:
"""Get image generation contexts.
Returns:
tuple: prompt, list_of_images
"""
- if not support_reference_images:
- return clean_cmd_prefix(message.content), []
+
+ def clean(text: str) -> str:
+ if not text:
+ return ""
+ text = clean_cmd_prefix(str(text))
+ _, text = extract_aspect_ratio(text)
+ return text.strip()
+
+ if not max_reference_img:
+ return clean(message.content), []
messages = [message]
while message.reply_to_message:
message = message.reply_to_message
- messages.append(message)
+ if not message.service: # ignore service messages
+ messages.append(message_modify(message))
messages.reverse() # old to new
images = []
prompt = ""
for m in messages:
group_messages = await client.get_media_group(m.chat.id, m.id) if m.media_group_id else [m]
for msg in group_messages:
- prompt = clean_cmd_prefix(msg.content)
- if not msg.photo:
+ prompt = clean(msg.content)
+ if not msg.photo or len(images) >= max_reference_img:
continue
res = await base64_media(client, msg)
images.append(f"data:image/{res['ext']};base64,{res['base64']}")
src/ai/images/utils.py
@@ -0,0 +1,165 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+import math
+import re
+from typing import Literal
+
+from ai.utils import clean_cmd_prefix
+
+
def extract_aspect_ratio(text: str) -> tuple[str, str]:
    """Extract an aspect-ratio hint from a prompt.

    Recognized forms, checked in order:
      1. Prompt starting with "width:height" (e.g. "16:9 a cat ...") --
         the ratio prefix is stripped from the returned prompt.
      2. An explicit declaration anywhere in the text (case-insensitive):
         "aspect_ratio = 9:16", '"aspect_ratio" : "16:9"', "ar 9:16",
         "宽高比: 16:9".
      3. Orientation keywords: "portrait" (9:16), "landscape" (16:9),
         "square" (1:1).

    Args:
        text (str): Prompt text; may be empty or None.

    Returns:
        tuple[str, str]: 1. Aspect ratio ("width:height", or "" if no hint found).
                         2. Prompt. Only modified in case 1 (prefix stripped);
                            otherwise returned with its original casing intact.
    """
    # ruff: noqa: RUF001
    if not text:
        return "", ""
    text = clean_cmd_prefix(str(text))

    # Case 1: prompt starts with "width:height" (ASCII or fullwidth colon)
    # -> use it as the ratio and remove it from the prompt.
    if match := re.match(r"(\d+\s*[::]\s*\d+)", text):
        return match.group(1), text.removeprefix(match.group(1)).lstrip()

    # Case 2: explicit "aspect_ratio ... W:H" declaration anywhere in the text.
    # (?i)                                  case-insensitive, so Aspect_Ratio also matches
    # (aspect_ratio|aspect ratio|ar|宽高比)  the declaration keyword
    # [\s=::\"]*                            any run of spaces, '=', ':', '：' or '"' separators
    # (\d+:\d+)                             the ratio itself, e.g. "5:4", "16:9"
    pattern = r"(?i)(aspect_ratio|aspect ratio|ar|宽高比)[\s=::\"]*(\d+:\d+)"
    if match := re.search(pattern, text):
        return match.group(2), text

    # Case 3: orientation keywords. Match on a lowered *copy* so the
    # returned prompt keeps its original casing (lowering the prompt itself
    # would silently alter what gets sent to the model).
    lowered = text.lower()
    if "portrait" in lowered:
        return "9:16", text
    if "landscape" in lowered:
        return "16:9", text
    if "square" in lowered:
        return "1:1", text
    return "", text  # no hint found
+
+
# Preset (width, height) per resolution tier and aspect ratio.
# NOTE: a per-tier table is required -- a few 4K entries (4:5 / 5:4) are
# deliberately NOT exact multiples of the 1K sizes.
_AR_SIZES: dict[str, dict[str, tuple[int, int]]] = {
    "1K": {
        "1:1": (1024, 1024), "2:3": (832, 1248), "3:2": (1248, 832),
        "3:4": (864, 1152), "4:3": (1152, 864), "4:5": (928, 1152),
        "5:4": (1152, 928), "9:16": (720, 1280), "16:9": (1280, 720),
        "21:9": (1512, 648),
    },
    "2K": {
        "1:1": (2048, 2048), "2:3": (1664, 2496), "3:2": (2496, 1664),
        "3:4": (1728, 2304), "4:3": (2304, 1728), "4:5": (1856, 2304),
        "5:4": (2304, 1856), "9:16": (1440, 2560), "16:9": (2560, 1440),
        "21:9": (3024, 1296),
    },
    "4K": {
        "1:1": (4096, 4096), "2:3": (3328, 4992), "3:2": (4992, 3328),
        "3:4": (3456, 4608), "4:3": (4608, 3456), "4:5": (3648, 4560),
        "5:4": (4560, 3648), "9:16": (2880, 5120), "16:9": (5120, 2880),
        "21:9": (6048, 2592),
    },
}


def aspect_ratio_to_size(
    aspect_ratio: str,
    resolution: Literal["1K", "2K", "4K"] = "1K",
    max_width: int = int(1e16),
    max_height: int = int(1e16),
    max_size: int = int(1e32),
) -> tuple[int, int]:
    """Convert an aspect ratio to a preset image size (width, height).

    Args:
        aspect_ratio (str): Ratio as "width:height", e.g. "16:9".
        resolution (Literal["1K", "2K", "4K"], optional): Resolution tier
            (case-insensitive). Defaults to "1K".
        max_width (int, optional): Max width cap. Defaults to int(1e16).
        max_height (int, optional): Max height cap. Defaults to int(1e16).
        max_size (int, optional): Max area (width * height) cap.
            Defaults to int(1e32).

    Returns:
        tuple[int, int]: Preset size for the tier/ratio, clamped by
            :func:`adjust_size`. Unknown ratios fall back to the tier's
            square size; an unknown tier falls back to 1024x1024.
    """
    tier = _AR_SIZES.get(resolution.upper())
    if tier is None:
        # Unknown resolution tier: keep the historical 1024x1024 fallback.
        width, height = 1024, 1024
    else:
        # Unknown ratio falls back to the tier's square ("1:1") preset.
        width, height = tier.get(aspect_ratio, tier["1:1"])
    return adjust_size(width, height, max_width, max_height, max_size)


def adjust_size(width: int, height: int, max_width: int = int(1e16), max_height: int = int(1e16), max_size: int = int(1e32)) -> tuple[int, int]:
    """Adjust image size to fit within max_width, max_height, max_size.

    The aspect ratio is preserved (both dimensions are scaled by the same
    factor at each step).

    Args:
        width (int): Image width.
        height (int): Image height.
        max_width (int, optional): Max width. Defaults to int(1E16).
        max_height (int, optional): Max height. Defaults to int(1E16).
        max_size (int, optional): Max size (width * height). Defaults to int(1E32).

    Returns:
        tuple[int, int]: Adjusted width, height (each at least 1).
    """
    # 1. Scale down to fit within max_size. Scaling each side by `s`
    #    shrinks the area by s**2, so the linear factor is the square root
    #    of the area ratio (a plain ratio would over-shrink: a 2x-oversize
    #    image would end up at 1/4 of the cap instead of at the cap).
    area = width * height
    if area > max_size:
        scale = math.sqrt(max_size / area)
        width = math.floor(width * scale)
        height = math.floor(height * scale)

    # 2. Scale down to fit within max_width, max_height (guard against a
    #    zero dimension produced by flooring under extreme caps).
    scale = min(max_width / max(width, 1), max_height / max(height, 1), 1.0)
    width = math.floor(width * scale)
    height = math.floor(height * scale)

    return max(width, 1), max(height, 1)
src/ai/utils.py
@@ -15,12 +15,13 @@ from config import AI, PREFIX, PROXY
from database.kv import get_cf_kv
from utils import nowdt, remove_consecutive_newlines, remove_dash, remove_pound, strings_list, zhcn
+# ruff: noqa: RUF001
EMOJI_TEXT_BOT = "🤖"
EMOJI_IMG_BOT = "🌠"
EMOJI_VIDEO_BOT = "📽"
EMOJI_REASONING_BEGIN = "🤔" # use emoji to separate model reasoning and content
EMOJI_REASONING_END = "💡"
-BOT_TIPS = "(回复以继续)" # noqa: RUF001
+BOT_TIPS = "(回复以继续)"
async def text_generation_docs() -> str:
@@ -53,7 +54,7 @@ def trim_none(obj: dict) -> dict:
if isinstance(obj, dict):
return {k: trim_none(v) for k, v in obj.items() if v is not None}
if isinstance(obj, list):
- return [trim_none(item) for item in obj if item is not None]
+ return [trim_none(item) for item in obj if item is not None] # ty:ignore[invalid-return-type]
return obj