Commit 53e5f44

benny-dou <60535774+benny-dou@users.noreply.github.com>
2025-11-28 09:55:44
feat(gemini): add support for Gemini nano banana model
1 parent bcbdd3e
Changed files (5)
src/llm/gemini/text2img.py
@@ -1,32 +1,33 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
-import contextlib
-import json
+from pathlib import Path
+from typing import TYPE_CHECKING
 
+from glom import flatten, glom
 from google import genai
-from google.genai import types
+from google.genai.types import ContentListUnion, GenerateContentConfig, HttpOptions, Part
 from loguru import logger
 from pyrogram.client import Client
 from pyrogram.types import Message
 
-from config import CAPTION_LENGTH, GEMINI, GPT, TEXT_LENGTH
-from llm.contexts import get_conversation_contexts, get_conversations
-from llm.gemini.utils import parse_response
-from llm.hooks import hook_gemini_httpoptions
-from llm.utils import BOT_TIPS, clean_cmd_prefix, clean_gemini_sourcemarks
-from messages.parser import parse_msg
+from config import DOWNLOAD_DIR, PREFIX, TEXT2IMG
+from llm.contexts import get_conversations
 from messages.progress import modify_progress
 from messages.sender import send2tg
-from messages.utils import blockquote, count_without_entities, smart_split
-from utils import strings_list
+from messages.utils import remove_prefix
+from others.alias import command_alias
+from utils import rand_number, strings_list
+
+if TYPE_CHECKING:
+    from io import BytesIO
 
 
 async def gemini_text2img(
     client: Client,
     message: Message,
+    model_id: str,
+    prompt: str,
     *,
-    disable_thinking: bool = False,
-    system_prompt: str | None = None,
     silent: bool = False,
     **kwargs,
 ) -> dict:
@@ -35,92 +36,73 @@ async def gemini_text2img(
     Args:
         client (Client): The Pyrogram client.
         message (Message): The trigger message object.
-        disable_thinking (bool, optional): Whether to disable thinking. Defaults to False.
-        include_thoughts (bool, optional): Whether to include thoughts. Defaults to True.
-        system_prompt (str | None, optional): System prompt. Defaults to None.
         silent (bool, optional): Whether to disable progressing. Defaults to False.
+        model_id (str): Gemini image model id (e.g. "gemini-2.5-flash-image").
+        prompt (str): Text prompt, shown in the progress message.
+
+    Returns:
+        dict: {"texts": str, "prefix": str, "model_name": str, "sent_message": list[Message]}
     """
-    info = parse_msg(message, silent=True, use_cache=False)
-    # parse config from environment variable
-    genconfig = {}
-    with contextlib.suppress(Exception):
-        extra_config_str = GEMINI.IMG_CONFIG
-        genconfig = json.loads(extra_config_str)
-    try:
-        real_prompt = clean_cmd_prefix(info["text"]) or clean_cmd_prefix(info["reply_text"])
-        msg = f"🤖**{GEMINI.IMG_MODEL_NAME}**: 思考中...\n👤**[{info['full_name'] or info['ctitle']}](tg://user?id={info['uid']})**: “{real_prompt}”"[:TEXT_LENGTH]
-        if not silent and kwargs.get("show_progress"):
-            kwargs["progress"] = (await send2tg(client, message, texts=msg, **kwargs))[0]
-        genconfig |= {"response_modalities": ["TEXT", "IMAGE"]}
-        if system_prompt is not None:
-            genconfig |= {"system_instruction": system_prompt}
-        if GEMINI.IMG_THINKING_BUDGET is not None and not disable_thinking:
-            thinking_budget = min(round(float(GEMINI.IMG_THINKING_BUDGET)), GEMINI.MAX_THINKING_BUDGET)
-            genconfig |= {"thinking_config": types.ThinkingConfig(include_thoughts=False, thinking_budget=thinking_budget)}
-        params = {"model": GEMINI.IMG_MODEL, "conversations": get_conversations(message), "config": types.GenerateContentConfig(**genconfig)}
-        logger.trace(params)
-        return await gemini_non_stream(client, message, GEMINI.IMG_MODEL_NAME, params, **kwargs)
-    except Exception as e:
-        logger.error(e)
+    model_name = model_id.split("/")[-1].title()
+    if not silent and kwargs.get("show_progress"):
+        kwargs["progress"] = (await send2tg(client, message, texts=f"🍌**{model_name}**:\n{prompt}", **kwargs))[0]
+
+    for api_key in strings_list(TEXT2IMG.GEMINI_API_KEY, shuffle=True):
+        try:
+            http_options = HttpOptions(base_url=TEXT2IMG.GEMINI_BASE_URL, async_client_args={"proxy": TEXT2IMG.GEMINI_PROXY})
+            app = genai.Client(api_key=api_key, http_options=http_options)
+            contents = await gen_prompts(client, message)
+            logger.trace(contents)
+            response = await app.aio.models.generate_content(
+                model=model_id,
+                contents=contents,
+                config=GenerateContentConfig(response_modalities=["IMAGE"]),
+            )
+            logger.trace(response)
+            await app.aio.aclose()
+            caption = ""
+            media = []
+            for part in flatten(glom(response, "candidates.*.content.parts", default=[])):
+                if part.text:
+                    caption += part.text
+                elif image := part.as_image():
+                    ext = part.inline_data.mime_type.split("/")[-1]
+                    save_path = Path(DOWNLOAD_DIR) / f"{rand_number()}.{ext}"
+                    image.save(save_path)
+                    media.append({"photo": save_path})
+            logger.success(f"🍌{model_name}: {caption}")
+            if media:
+                sent_message = await send2tg(client, message, caption_above=True, texts=f"🍌**{model_name}**:", media=media, **kwargs)
+                await modify_progress(del_status=True, **kwargs)
+                return {
+                    "prefix": f"🍌**{model_name}**:",
+                    "model_name": model_name,
+                    "texts": caption,
+                    "sent_message": sent_message,
+                }
+        except Exception as e:
+            logger.error(e)
+    await modify_progress(del_status=True, **kwargs)
     return {}
 
 
-async def gemini_non_stream(
-    client: Client,
-    message: Message,
-    model_name: str,
-    params: dict,
-    retry: int = 0,
-    **kwargs,
-) -> dict:
-    """Gemini non-stream response.
-
-    Returns:
-        dict: {"texts": str, "thoughts": str, "prefix": str, "model_name": str, "sent_messages": list[Message]}
-    """
-    results = {}
-    try:
-        api_keys = strings_list(kwargs.get("gemini_api_keys", GEMINI.API_KEY))
-        if retry > len(api_keys) - 1:
-            return {}
-        api_key = kwargs.get("gemini_api_key", api_keys[retry])
-        http_options = types.HttpOptions(base_url=GEMINI.BASE_URL, async_client_args={"proxy": GEMINI.PROXY})
-        http_options = hook_gemini_httpoptions(http_options, message)
-        app = genai.Client(api_key=api_key, http_options=http_options)
-        # Construct the request params
-        if "conversations" in params:  # convert conversations to contents
-            params["contents"] = await get_conversation_contexts(client, params.pop("conversations"), model_id=params["model"], ctx_format="gemini", app=app)
-        clean_gemini_sourcemarks(params["contents"])
-        genai_params = {"model": params["model"], "contents": params["contents"], "config": params["config"]}
-        response = await app.aio.models.generate_content(**genai_params)
-        await app.aio.aclose()
-        prefix = f"🤖**{model_name}**:{BOT_TIPS}\n"
-        res = parse_response(response.model_dump())
-        texts = res.get("texts", "")
-        results |= {"prefix": prefix, "model_name": model_name, "texts": texts, "thoughts": ""}
-        media = res.get("media", [])
-        total = prefix + texts.strip()
-        length = await count_without_entities(total)
-        single_msg_length = CAPTION_LENGTH if media else TEXT_LENGTH
-        if length <= GPT.COLLAPSE_LENGTH:
-            results["sent_message"] = await send2tg(client, message, caption_above=True, texts=total, media=media, **kwargs)
-        elif GPT.COLLAPSE_LENGTH < length <= single_msg_length:
-            final = prefix + blockquote(texts.strip())
-            results["sent_message"] = await send2tg(client, message, caption_above=True, texts=final, media=media, **kwargs)
-        else:  # multiple messages
-            for idx, txt in await smart_split(total, single_msg_length):
-                if idx == 0:
-                    results["sent_message"] = await send2tg(client, message, caption_above=True, texts=txt, media=media, **kwargs)
-                else:
-                    results["sent_message"] = await send2tg(client, message, texts=txt, **kwargs)
-        await modify_progress(del_status=True, **kwargs)
-    except Exception as e:
-        logger.error(e)
-        error = str(e)
-        if "res" in locals():
-            error += f"\n{res}"  # type: ignore
-        if "response" in locals():
-            error += f"\n{response}"  # type: ignore
-        await modify_progress(text=error, force_update=True, **kwargs)
-        return await gemini_non_stream(client, message, model_name, params, retry + 1, **kwargs)  # type: ignore
-    return results
+async def gen_prompts(client: Client, message: Message) -> ContentListUnion:
+    """Build the Gemini content list (role-tagged text/image parts) from the conversation chain, oldest to newest."""
+    prompts = []
+    for msg in get_conversations(message):  # old to new
+        messages = await client.get_media_group(msg.chat.id, msg.id) if msg.media_group_id else [msg]
+        role = "model" if any(m.content.startswith(f"🍌{TEXT2IMG.GEMINI_MODEL.title()}") for m in messages) else "user"
+        parts = []
+        for m in messages:
+            m = command_alias(m)  # noqa: PLW2901
+            try:
+                if m.photo:
+                    buffer: BytesIO = await m.download(in_memory=True)  # type: ignore
+                    ext = Path(buffer.name).suffix.removeprefix(".").replace("jpg", "jpeg")
+                    parts.append(Part.from_bytes(data=buffer.getvalue(), mime_type=f"image/{ext}"))
+                if role == "user" and m.content:
+                    text = remove_prefix(m.content, PREFIX.GENIMG)
+                    text = remove_prefix(text, "@gemini")
+                    parts.append(Part.from_text(text=text))
+            except Exception as e:
+                logger.error(e)
+        prompts.append({"role": role, "parts": parts})
+    return prompts
src/llm/models.py
@@ -184,8 +184,8 @@ def get_model_id_from_prefix(minfo: dict) -> tuple[str, bool, str]:
         model_id = GPT.KIMI_MODEL
     elif startswith_prefix(minfo["reply_text"], prefix=f"🤖{GPT.GROK_MODEL_NAME}:{BOT_TIPS}") and "grok" in text_providers:
         model_id = GPT.GROK_MODEL
-    elif startswith_prefix(minfo["reply_text"], prefix=f"🤖{GEMINI.IMG_MODEL_NAME}:{BOT_TIPS}") and "gemini" in img_providers:
-        model_id = GEMINI.IMG_MODEL
+    elif startswith_prefix(minfo["reply_text"], prefix=f"🍌{TEXT2IMG.GEMINI_MODEL.title()}") and "gemini" in img_providers:
+        model_id = TEXT2IMG.GEMINI_MODEL
         resp_modality = "image"
     elif matched := re.match(rf"^🤖(.*?):{BOT_TIPS}", minfo["reply_text"]):
         return matched.group(1).lower(), True, "text"
src/llm/text2img.py
@@ -2,11 +2,12 @@
 # -*- coding: utf-8 -*-
 from collections import defaultdict
 
+from glom import glom
 from loguru import logger
 from pyrogram.client import Client
 from pyrogram.types import Message
 
-from config import GEMINI, PREFIX, TEXT2IMG
+from config import PREFIX, TEXT2IMG
 from llm.ali.text2img import ali_text2img
 from llm.cloudflare.text2img import cloudflare_text2img
 from llm.doubao.text2img import doubao_genimg
@@ -18,11 +19,11 @@ TEXT2IMG_HELP = f"""🌠**AI生图**
 `{PREFIX.GENIMG}` 后接提示词即可生成
 
 ⚙️模型配置:
-- `{PREFIX.GENIMG}`: {TEXT2IMG.DEFAULT_MODEL}模型
-- `{PREFIX.GENIMG} @doubao`: Seedream模型
-- `{PREFIX.GENIMG} @flux`: Flux模型
-- `{PREFIX.GENIMG} @sd`: Stable Diffusion模型
-- `{PREFIX.GENIMG} @gemini`: Gemini模型
+- `{PREFIX.GENIMG}`: 默认模型 ({TEXT2IMG.DEFAULT_MODEL})
+- `/sd`: 豆包Seedream模型
+- `/flux`: Flux模型
+- `/stable`: Stable Diffusion模型
+- `/nano`: Gemini Nano Banana
 
 上下文说明:
 - Gemini模型会把整个回复消息链上的所有消息加入上下文
@@ -30,6 +31,9 @@ TEXT2IMG_HELP = f"""🌠**AI生图**
 - 其余模型不会把历史消息加入上下文
 """
 
+# /sd, /flux, /stable, /nano等命令是通过"别名"功能实现的 (src/others/alias.py)
+# 完整调用方式为 `/gen @doubao`, `/gen @flux`, ...
+
 
 async def text2img(client: Client, message: Message, **kwargs) -> dict:
     """Text to image generation.
@@ -49,6 +53,8 @@ async def text2img(client: Client, message: Message, **kwargs) -> dict:
     categories = list(all_models)  # ['gemini', 'flux', 'sd']
     models = all_models.get(TEXT2IMG.DEFAULT_MODEL, [])
     prompt = texts
+    if glom(message, "reply_to_message.content", default="").startswith(f"🍌{TEXT2IMG.GEMINI_MODEL.title()}"):
+        models = all_models.get("gemini", [])
     for category in categories:
         if texts.lower().startswith(f"@{category}"):
             models = all_models[category]
@@ -59,7 +65,7 @@ async def text2img(client: Client, message: Message, **kwargs) -> dict:
         provider, model_id = model.split("/", 1)
         try:
             if provider == "gemini":
-                return await gemini_text2img(client, message, **kwargs)
+                return await gemini_text2img(client, message, model_id, prompt, **kwargs)
             if provider == "ali":
                 return await ali_text2img(client, message, model_id, prompt, **kwargs)
             if provider == "cloudflare":
@@ -86,7 +92,7 @@ def enabled_models() -> dict[str, list]:
     _, img_providers = enabled_providers()
     for provider in img_providers:
         if provider == "gemini":
-            models["gemini"] = [f"gemini/{GEMINI.IMG_MODEL}"]
+            models["gemini"] = [f"gemini/{TEXT2IMG.GEMINI_MODEL}"]
         if provider == "ali" and TEXT2IMG.ALI_FLUX_MODEL and "ali" in strings_list(TEXT2IMG.FLUX_PROVIDER):
             models["flux"].extend([f"ali/{model}" for model in strings_list(TEXT2IMG.ALI_FLUX_MODEL)])
         if provider == "ali" and TEXT2IMG.ALI_STABLE_DIFFUSION_MODEL and "ali" in strings_list(TEXT2IMG.STABLE_DIFFUSION_PROVIDER):
src/llm/utils.py
@@ -45,7 +45,7 @@ def enabled_providers() -> tuple[list[str], list[str]]:
         text_providers.append("gemini")
 
     img_providers = []
-    if all([GEMINI.API_KEY, GEMINI.BASE_URL, GEMINI.IMG_MODEL, GEMINI.IMG_MODEL_NAME]):
+    if all([TEXT2IMG.GEMINI_API_KEY, TEXT2IMG.GEMINI_BASE_URL, TEXT2IMG.GEMINI_MODEL]):
         img_providers.append("gemini")
     if all([TEXT2IMG.ALI_API_KEY]):
         img_providers.append("ali")
src/config.py
@@ -469,12 +469,6 @@ class GEMINI:  # Official Gemini
     TEXT_MAX_TOKEN = int(os.getenv("GEMINI_TEXT_MAX_TOKEN", "250000"))  # 250K
     TEXT_TOKENS_FALLBACK_MODEL = os.getenv("GEMINI_TEXT_TOKENS_FALLBACK_MODEL", "gemini-2.0-flash")  # model id when the token count is larger than GEMINI.TEXT_MAX_TOKEN
 
-    # response modality: image
-    IMG_MODEL = os.getenv("GEMINI_IMG_MODEL", "gemini-2.0-flash-exp")
-    IMG_MODEL_NAME = os.getenv("GEMINI_IMG_MODEL_NAME", "Gemini-2.0-Flash")
-    IMG_THINKING_BUDGET = os.getenv("GEMINI_IMG_THINKING_BUDGET", None)  # 0 to disable thinking. DO NOT set this if the model is not a thinking model
-    IMG_CONFIG = os.getenv("GEMINI_IMG_CONFIG", "{}")  # default config passed to GenerateContentConfig. Should be a json string: '{"key": "value"}'
-
     # ASR related
     ASR_MAX_DURATION = int(os.getenv("GEMINI_ASR_MAX_DURATION", "34200"))  # 9.5 hour
     ASR_MODEL = os.getenv("GEMINI_ASR_MODEL", "gemini-2.5-flash")
@@ -498,3 +492,7 @@ class TEXT2IMG:
     DOUBAO_API_KEY = os.getenv("TEXT2IMG_DOUBAO_API_KEY", "")  # comma separated keys
     DOUBAO_SEEDREAM_MODEL = os.getenv("TEXT2IMG_DOUBAO_SEEDREAM_MODEL", "doubao-seedream-4-0-250828")
     DOUBAO_PROXY = os.getenv("TEXT2IMG_DOUBAO_PROXY", None)
+    GEMINI_MODEL = os.getenv("TEXT2IMG_GEMINI_MODEL", "gemini-2.5-flash-image")
+    GEMINI_BASE_URL = os.getenv("TEXT2IMG_GEMINI_BASE_URL", "https://generativelanguage.googleapis.com")
+    GEMINI_PROXY = os.getenv("TEXT2IMG_GEMINI_PROXY", None)
+    GEMINI_API_KEY = os.getenv("TEXT2IMG_GEMINI_API_KEY", "")  # comma separated keys for load balance. e.g. "key1,key2,key3"