Commit 53e5f44
Changed files (5)
src/llm/gemini/text2img.py
@@ -1,32 +1,33 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
-import contextlib
-import json
+from pathlib import Path
+from typing import TYPE_CHECKING
+from glom import flatten, glom
from google import genai
-from google.genai import types
+from google.genai.types import ContentListUnion, GenerateContentConfig, HttpOptions, Part
from loguru import logger
from pyrogram.client import Client
from pyrogram.types import Message
-from config import CAPTION_LENGTH, GEMINI, GPT, TEXT_LENGTH
-from llm.contexts import get_conversation_contexts, get_conversations
-from llm.gemini.utils import parse_response
-from llm.hooks import hook_gemini_httpoptions
-from llm.utils import BOT_TIPS, clean_cmd_prefix, clean_gemini_sourcemarks
-from messages.parser import parse_msg
+from config import DOWNLOAD_DIR, PREFIX, TEXT2IMG
+from llm.contexts import get_conversations
from messages.progress import modify_progress
from messages.sender import send2tg
-from messages.utils import blockquote, count_without_entities, smart_split
-from utils import strings_list
+from messages.utils import remove_prefix
+from others.alias import command_alias
+from utils import rand_number, strings_list
+
+if TYPE_CHECKING:
+ from io import BytesIO
async def gemini_text2img(
client: Client,
message: Message,
+ model_id: str,
+ prompt: str,
*,
- disable_thinking: bool = False,
- system_prompt: str | None = None,
silent: bool = False,
**kwargs,
) -> dict:
@@ -35,92 +36,73 @@ async def gemini_text2img(
Args:
client (Client): The Pyrogram client.
message (Message): The trigger message object.
- disable_thinking (bool, optional): Whether to disable thinking. Defaults to False.
- include_thoughts (bool, optional): Whether to include thoughts. Defaults to True.
- system_prompt (str | None, optional): System prompt. Defaults to None.
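+ model_id (str): The image model id, e.g. TEXT2IMG.GEMINI_MODEL.
+ prompt (str): The user prompt shown in the progress message.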
silent (bool, optional): Whether to disable progressing. Defaults to False.
+
+ Returns:
+ dict: {"texts": str, "prefix": str, "model_name": str, "sent_message": list[Message]}
"""
- info = parse_msg(message, silent=True, use_cache=False)
- # parse config from environment variable
- genconfig = {}
- with contextlib.suppress(Exception):
- extra_config_str = GEMINI.IMG_CONFIG
- genconfig = json.loads(extra_config_str)
- try:
- real_prompt = clean_cmd_prefix(info["text"]) or clean_cmd_prefix(info["reply_text"])
- msg = f"🤖**{GEMINI.IMG_MODEL_NAME}**: 思考中...\n👤**[{info['full_name'] or info['ctitle']}](tg://user?id={info['uid']})**: “{real_prompt}”"[:TEXT_LENGTH]
- if not silent and kwargs.get("show_progress"):
- kwargs["progress"] = (await send2tg(client, message, texts=msg, **kwargs))[0]
- genconfig |= {"response_modalities": ["TEXT", "IMAGE"]}
- if system_prompt is not None:
- genconfig |= {"system_instruction": system_prompt}
- if GEMINI.IMG_THINKING_BUDGET is not None and not disable_thinking:
- thinking_budget = min(round(float(GEMINI.IMG_THINKING_BUDGET)), GEMINI.MAX_THINKING_BUDGET)
- genconfig |= {"thinking_config": types.ThinkingConfig(include_thoughts=False, thinking_budget=thinking_budget)}
- params = {"model": GEMINI.IMG_MODEL, "conversations": get_conversations(message), "config": types.GenerateContentConfig(**genconfig)}
- logger.trace(params)
- return await gemini_non_stream(client, message, GEMINI.IMG_MODEL_NAME, params, **kwargs)
- except Exception as e:
- logger.error(e)
+ model_name = model_id.split("/")[-1].title()
+ if not silent and kwargs.get("show_progress"):
+ kwargs["progress"] = (await send2tg(client, message, texts=f"🍌**{model_name}**:\n{prompt}", **kwargs))[0]
+
+ for api_key in strings_list(TEXT2IMG.GEMINI_API_KEY, shuffle=True):
+ try:
+ http_options = HttpOptions(base_url=TEXT2IMG.GEMINI_BASE_URL, async_client_args={"proxy": TEXT2IMG.GEMINI_PROXY})
+ app = genai.Client(api_key=api_key, http_options=http_options)
+ contents = await gen_prompts(client, message)
+ logger.trace(contents)
+ response = await app.aio.models.generate_content(
+ model=model_id,
+ contents=contents,
+ config=GenerateContentConfig(response_modalities=["IMAGE"]),
+ )
+ logger.trace(response)
+ await app.aio.aclose()
+ caption = ""
+ media = []
+ for part in flatten(glom(response, "candidates.*.content.parts", default=[])):
+ if part.text:
+ caption += part.text
+ elif image := part.as_image():
+ ext = part.inline_data.mime_type.split("/")[-1]
+ save_path = Path(DOWNLOAD_DIR) / f"{rand_number()}.{ext}"
+ image.save(save_path)
+ media.append({"photo": save_path})
+ logger.success(f"🍌{model_name}: {caption}")
+ if media:
+ sent_message = await send2tg(client, message, caption_above=True, texts=f"🍌**{model_name}**:", media=media, **kwargs)
+ await modify_progress(del_status=True, **kwargs)
+ return {
+ "prefix": f"🍌**{model_name}**:",
+ "model_name": model_name,
+ "texts": caption,
+ "sent_message": sent_message,
+ }
+ except Exception as e:
+ logger.error(e)
+ await modify_progress(del_status=True, **kwargs)
return {}
-async def gemini_non_stream(
- client: Client,
- message: Message,
- model_name: str,
- params: dict,
- retry: int = 0,
- **kwargs,
-) -> dict:
- """Gemini non-stream response.
-
- Returns:
- dict: {"texts": str, "thoughts": str, "prefix": str, "model_name": str, "sent_messages": list[Message]}
- """
- results = {}
- try:
- api_keys = strings_list(kwargs.get("gemini_api_keys", GEMINI.API_KEY))
- if retry > len(api_keys) - 1:
- return {}
- api_key = kwargs.get("gemini_api_key", api_keys[retry])
- http_options = types.HttpOptions(base_url=GEMINI.BASE_URL, async_client_args={"proxy": GEMINI.PROXY})
- http_options = hook_gemini_httpoptions(http_options, message)
- app = genai.Client(api_key=api_key, http_options=http_options)
- # Construct the request params
- if "conversations" in params: # convert conversations to contents
- params["contents"] = await get_conversation_contexts(client, params.pop("conversations"), model_id=params["model"], ctx_format="gemini", app=app)
- clean_gemini_sourcemarks(params["contents"])
- genai_params = {"model": params["model"], "contents": params["contents"], "config": params["config"]}
- response = await app.aio.models.generate_content(**genai_params)
- await app.aio.aclose()
- prefix = f"🤖**{model_name}**:{BOT_TIPS}\n"
- res = parse_response(response.model_dump())
- texts = res.get("texts", "")
- results |= {"prefix": prefix, "model_name": model_name, "texts": texts, "thoughts": ""}
- media = res.get("media", [])
- total = prefix + texts.strip()
- length = await count_without_entities(total)
- single_msg_length = CAPTION_LENGTH if media else TEXT_LENGTH
- if length <= GPT.COLLAPSE_LENGTH:
- results["sent_message"] = await send2tg(client, message, caption_above=True, texts=total, media=media, **kwargs)
- elif GPT.COLLAPSE_LENGTH < length <= single_msg_length:
- final = prefix + blockquote(texts.strip())
- results["sent_message"] = await send2tg(client, message, caption_above=True, texts=final, media=media, **kwargs)
- else: # multiple messages
- for idx, txt in await smart_split(total, single_msg_length):
- if idx == 0:
- results["sent_message"] = await send2tg(client, message, caption_above=True, texts=txt, media=media, **kwargs)
- else:
- results["sent_message"] = await send2tg(client, message, texts=txt, **kwargs)
- await modify_progress(del_status=True, **kwargs)
- except Exception as e:
- logger.error(e)
- error = str(e)
- if "res" in locals():
- error += f"\n{res}" # type: ignore
- if "response" in locals():
- error += f"\n{response}" # type: ignore
- await modify_progress(text=error, force_update=True, **kwargs)
- return await gemini_non_stream(client, message, model_name, params, retry + 1, **kwargs) # type: ignore
- return results
+async def gen_prompts(client: Client, message: Message) -> ContentListUnion:
+ """Generate prompts."""
+ prompts = []
+ for msg in get_conversations(message): # old to new
+ messages = await client.get_media_group(msg.chat.id, msg.id) if msg.media_group_id else [msg]
+ role = "model" if any(m.content.startswith(f"🍌{TEXT2IMG.GEMINI_MODEL.title()}") for m in messages) else "user"
+ parts = []
+ for m in messages:
+ m = command_alias(m) # noqa: PLW2901
+ try:
+ if m.photo:
+ buffer: BytesIO = await m.download(in_memory=True) # type: ignore
+ ext = Path(buffer.name).suffix.removeprefix(".").replace("jpg", "jpeg")
+ parts.append(Part.from_bytes(data=buffer.getvalue(), mime_type=f"image/{ext}"))
+ if role == "user" and m.content:
+ text = remove_prefix(m.content, PREFIX.GENIMG)
+ text = remove_prefix(text, "@gemini")
+ parts.append(Part.from_text(text=text))
+ except Exception as e:
+ logger.error(e)
+ prompts.append({"role": role, "parts": parts})
+ return prompts
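
For orientation, the contents list that gen_prompts assembles follows the google-genai multi-turn shape: one dict per conversation message with a role of "user" or "model" and a list of Part objects. A minimal sketch of that structure (the bytes and text below are illustrative placeholders, not taken from the commit):

    # Illustrative shape of the list gen_prompts returns.
    from google.genai.types import Part

    contents = [
        {
            "role": "user",
            "parts": [
                Part.from_bytes(data=b"<jpeg bytes>", mime_type="image/jpeg"),  # replied-to photo
                Part.from_text(text="make the sky purple"),  # prompt with the /gen prefix stripped
            ],
        },
        {
            "role": "model",
            "parts": [Part.from_bytes(data=b"<png bytes>", mime_type="image/png")],  # earlier bot image
        },
    ]
    # gemini_text2img passes this list as contents= to generate_content with
    # GenerateContentConfig(response_modalities=["IMAGE"]).
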
src/llm/models.py
@@ -184,8 +184,8 @@ def get_model_id_from_prefix(minfo: dict) -> tuple[str, bool, str]:
model_id = GPT.KIMI_MODEL
elif startswith_prefix(minfo["reply_text"], prefix=f"🤖{GPT.GROK_MODEL_NAME}:{BOT_TIPS}") and "grok" in text_providers:
model_id = GPT.GROK_MODEL
- elif startswith_prefix(minfo["reply_text"], prefix=f"🤖{GEMINI.IMG_MODEL_NAME}:{BOT_TIPS}") and "gemini" in img_providers:
- model_id = GEMINI.IMG_MODEL
+ elif startswith_prefix(minfo["reply_text"], prefix=f"🍌{TEXT2IMG.GEMINI_MODEL.title()}") and "gemini" in img_providers:
+ model_id = TEXT2IMG.GEMINI_MODEL
resp_modality = "image"
elif matched := re.match(rf"^🤖(.*?):{BOT_TIPS}", minfo["reply_text"]):
return matched.group(1).lower(), True, "text"
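
The new reply-detection prefix relies on str.title() capitalizing each hyphen- and dot-delimited segment of the model id, so with the default model it matches the 🍌Gemini-2.5-Flash-Image header that gemini_text2img puts on its captions (assuming the bold markup is rendered as entities rather than literal asterisks in reply_text). A quick check of that assumption:

    # str.title() treats every non-alphabetic character as a word boundary.
    assert "gemini-2.5-flash-image".title() == "Gemini-2.5-Flash-Image"
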
src/llm/text2img.py
@@ -2,11 +2,12 @@
# -*- coding: utf-8 -*-
from collections import defaultdict
+from glom import glom
from loguru import logger
from pyrogram.client import Client
from pyrogram.types import Message
-from config import GEMINI, PREFIX, TEXT2IMG
+from config import PREFIX, TEXT2IMG
from llm.ali.text2img import ali_text2img
from llm.cloudflare.text2img import cloudflare_text2img
from llm.doubao.text2img import doubao_genimg
@@ -18,11 +19,11 @@ TEXT2IMG_HELP = f"""🌠**AI生图**
`{PREFIX.GENIMG}` 后接提示词即可生成
⚙️模型配置:
-- `{PREFIX.GENIMG}`: {TEXT2IMG.DEFAULT_MODEL}模型
-- `{PREFIX.GENIMG} @doubao`: Seedream模型
-- `{PREFIX.GENIMG} @flux`: Flux模型
-- `{PREFIX.GENIMG} @sd`: Stable Diffusion模型
-- `{PREFIX.GENIMG} @gemini`: Gemini模型
+- `{PREFIX.GENIMG}`: 默认模型 ({TEXT2IMG.DEFAULT_MODEL})
+- `/sd`: 豆包Seedream模型
+- `/flux`: Flux模型
+- `/stable`: Stable Diffusion模型
+- `/nano`: Gemini Nano Banana
上下文说明:
- Gemini模型会把整个回复消息链上的所有消息加入上下文
@@ -30,6 +31,9 @@ TEXT2IMG_HELP = f"""🌠**AI生图**
- 其余模型不会把历史消息加入上下文
"""
+# The /sd, /flux, /stable and /nano commands are implemented via the "alias" feature (src/others/alias.py)
+# Their full forms are `/gen @doubao`, `/gen @flux`, ...
+
async def text2img(client: Client, message: Message, **kwargs) -> dict:
"""Text to image generation.
@@ -49,6 +53,8 @@ async def text2img(client: Client, message: Message, **kwargs) -> dict:
categories = list(all_models) # ['gemini', 'flux', 'sd']
models = all_models.get(TEXT2IMG.DEFAULT_MODEL, [])
prompt = texts
+ if glom(message, "reply_to_message.content", default="").startswith(f"🍌{TEXT2IMG.GEMINI_MODEL.title()}"):
+ models = all_models.get("gemini", [])
for category in categories:
if texts.lower().startswith(f"@{category}"):
models = all_models[category]
@@ -59,7 +65,7 @@ async def text2img(client: Client, message: Message, **kwargs) -> dict:
provider, model_id = model.split("/", 1)
try:
if provider == "gemini":
- return await gemini_text2img(client, message, **kwargs)
+ return await gemini_text2img(client, message, model_id, prompt, **kwargs)
if provider == "ali":
return await ali_text2img(client, message, model_id, prompt, **kwargs)
if provider == "cloudflare":
@@ -86,7 +92,7 @@ def enabled_models() -> dict[str, list]:
_, img_providers = enabled_providers()
for provider in img_providers:
if provider == "gemini":
- models["gemini"] = [f"gemini/{GEMINI.IMG_MODEL}"]
+ models["gemini"] = [f"gemini/{TEXT2IMG.GEMINI_MODEL}"]
if provider == "ali" and TEXT2IMG.ALI_FLUX_MODEL and "ali" in strings_list(TEXT2IMG.FLUX_PROVIDER):
models["flux"].extend([f"ali/{model}" for model in strings_list(TEXT2IMG.ALI_FLUX_MODEL)])
if provider == "ali" and TEXT2IMG.ALI_STABLE_DIFFUSION_MODEL and "ali" in strings_list(TEXT2IMG.STABLE_DIFFUSION_PROVIDER):
src/llm/utils.py
@@ -45,7 +45,7 @@ def enabled_providers() -> tuple[list[str], list[str]]:
text_providers.append("gemini")
img_providers = []
- if all([GEMINI.API_KEY, GEMINI.BASE_URL, GEMINI.IMG_MODEL, GEMINI.IMG_MODEL_NAME]):
+ if all([TEXT2IMG.GEMINI_API_KEY, TEXT2IMG.GEMINI_BASE_URL, TEXT2IMG.GEMINI_MODEL]):
img_providers.append("gemini")
if all([TEXT2IMG.ALI_API_KEY]):
img_providers.append("ali")
src/config.py
@@ -469,12 +469,6 @@ class GEMINI: # Official Gemini
TEXT_MAX_TOKEN = int(os.getenv("GEMINI_TEXT_MAX_TOKEN", "250000")) # 250K
TEXT_TOKENS_FALLBACK_MODEL = os.getenv("GEMINI_TEXT_TOKENS_FALLBACK_MODEL", "gemini-2.0-flash") # model id when the token count is larger than GEMINI.TEXT_MAX_TOKEN
- # response modality: image
- IMG_MODEL = os.getenv("GEMINI_IMG_MODEL", "gemini-2.0-flash-exp")
- IMG_MODEL_NAME = os.getenv("GEMINI_IMG_MODEL_NAME", "Gemini-2.0-Flash")
- IMG_THINKING_BUDGET = os.getenv("GEMINI_IMG_THINKING_BUDGET", None) # 0 to disable thinking. DO NOT set this if the model is not a thinking model
- IMG_CONFIG = os.getenv("GEMINI_IMG_CONFIG", "{}") # default config passed to GenerateContentConfig. Should be a json string: '{"key": "value"}'
-
# ASR related
ASR_MAX_DURATION = int(os.getenv("GEMINI_ASR_MAX_DURATION", "34200")) # 9.5 hour
ASR_MODEL = os.getenv("GEMINI_ASR_MODEL", "gemini-2.5-flash")
@@ -498,3 +492,7 @@ class TEXT2IMG:
DOUBAO_API_KEY = os.getenv("TEXT2IMG_DOUBAO_API_KEY", "") # comma separated keys
DOUBAO_SEEDREAM_MODEL = os.getenv("TEXT2IMG_DOUBAO_SEEDREAM_MODEL", "doubao-seedream-4-0-250828")
DOUBAO_PROXY = os.getenv("TEXT2IMG_DOUBAO_PROXY", None)
+ GEMINI_MODEL = os.getenv("TEXT2IMG_GEMINI_MODEL", "gemini-2.5-flash-image")
+ GEMINI_BASE_URL = os.getenv("TEXT2IMG_GEMINI_BASE_URL", "https://generativelanguage.googleapis.com")
+ GEMINI_PROXY = os.getenv("TEXT2IMG_GEMINI_PROXY", None)
+ GEMINI_API_KEY = os.getenv("TEXT2IMG_GEMINI_API_KEY", "") # comma separated keys for load balance. e.g. "key1,key2,key3"
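
The four new TEXT2IMG_GEMINI_* variables replace the removed GEMINI_IMG_* ones. A minimal configuration sketch, assuming a dotenv-style setup with placeholder values:

    # Placeholder values; TEXT2IMG_GEMINI_PROXY is optional and defaults to None.
    import os

    os.environ["TEXT2IMG_GEMINI_MODEL"] = "gemini-2.5-flash-image"
    os.environ["TEXT2IMG_GEMINI_BASE_URL"] = "https://generativelanguage.googleapis.com"
    os.environ["TEXT2IMG_GEMINI_API_KEY"] = "key1,key2"  # comma separated; keys are shuffled for load balancing
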