Commit d519f14

Author:  benny-dou <60535774+benny-dou@users.noreply.github.com>
Date:    2025-04-28 10:03:11
Message: refactor(gemini): rename AIGC to GEMINI
Parent:  d385bc8
src/llm/gemini.py
@@ -2,7 +2,6 @@
 # -*- coding: utf-8 -*-
 
 import contextlib
-import random
 from io import BytesIO
 from pathlib import Path
 
@@ -14,7 +13,7 @@ from PIL import Image
 from pyrogram.client import Client
 from pyrogram.types import Message
 
-from config import AIGC, DOWNLOAD_DIR, PREFIX, TEXT_LENGTH
+from config import DOWNLOAD_DIR, GEMINI, PREFIX, TEXT_LENGTH
 from llm.utils import BOT_TIPS, beautify_llm_response, clean_prefix, clean_source_marks
 from messages.parser import parse_msg
 from messages.progress import modify_progress
@@ -22,18 +21,18 @@ from messages.sender import send2tg
 from messages.utils import count_without_entities, smart_split
 from utils import number_to_emoji, rand_string
 
-HELP = f"""🌠**AIGC**
+HELP = f"""🌠**AI生图**
 `{PREFIX.GENIMG}` 后接提示词即可生成
 回复消息可继续对话重新修改生成结果
 
 ⚙️模型配置:
-🏞生图模型: **{AIGC.IMG_MODEL}
+🌠生图模型: **{GEMINI.IMG_MODEL}
 
 ⚠️目前只支持生成图片
 """
 
 
-async def gemini_response(client: Client, message: Message, gpt_contexts: list[dict], model: str = "", model_name: str = "", modality: str = "image", **kwargs):
+async def gemini_response(client: Client, message: Message, gpt_contexts: list[dict], modality: str = "image", **kwargs):
     r"""Get Gemini response.
 
     gpt_contexts: [
@@ -54,28 +53,31 @@ async def gemini_response(client: Client, message: Message, gpt_contexts: list[d
         model_name (str): friendly model name
         modality (str): response modality
     """
-    # ruff: noqa: RET502, RET503
     info = parse_msg(message)
-    api_keys = [x.strip() for x in AIGC.IMG_API_KEY.split(",") if x.strip()]
+    model = GEMINI.TEXT_MODEL if modality == "text" else GEMINI.IMG_MODEL
+    model_name = GEMINI.TEXT_MODEL_NAME if modality == "text" else GEMINI.IMG_MODEL_NAME
     response_modalities = ["TEXT", "IMAGE"] if modality == "image" else ["TEXT"]
     tools = [Tool(google_search=GoogleSearch())] if modality == "text" else None
     keep_marks = modality == "text"  # keep source marks for text response
-
     try:
-        app = genai.Client(api_key=random.choice(api_keys), http_options=HttpOptions(base_url=AIGC.IMG_BASR_URL, async_client_args={"proxy": AIGC.IMG_PROXY}))
-        count_tokens = await app.aio.models.count_tokens(model=model, contents=info["text"])
-        num_token = count_tokens.total_tokens or 0
-        if modality == "image" and num_token > AIGC.IMG_MAX_PROMPT_TOKEN:
-            await send2tg(client, message, texts=f"生成{modality.upper()}时提示词Token不得超过: {AIGC.IMG_MAX_PROMPT_TOKEN}\n当前提示词: {num_token} Tokens", **kwargs)
-            return
-        msg = f"🌠**{model_name}**: 思考中...\n{clean_prefix(info['text'])}"
+        msg = f"🤖**{model_name}**: 思考中...\n{clean_prefix(info['text'])}"
         status_msg = (await send2tg(client, message, texts=msg, **kwargs))[0]
         kwargs["progress"] = status_msg
         contexts = [openai_context_to_gemini(context, keep_marks=keep_marks) for context in gpt_contexts]
         gemini_logging(contexts)
+        params = {}
+        params |= {"model": model, "contents": contexts}
+        genconfig = {}
+        genconfig |= {"response_modalities": response_modalities}
+        if tools:
+            genconfig |= {"tools": tools}
+        if GEMINI.PREFER_LANG and modality == "text":
+            genconfig |= {"system_instruction": f"请优先使用{GEMINI.PREFER_LANG}回复"}
+        params |= {"config": GenerateContentConfig(**genconfig)}
+
         if modality == "image":
-            return await gemini_nonstream(client, message, contexts, model, model_name, response_modalities, tools, **kwargs)
-        return await gemini_stream(client, message, contexts, model, model_name, response_modalities, tools, **kwargs)
+            return await gemini_nonstream(client, message, model_name, params, **kwargs)
+        return await gemini_stream(client, message, model_name, params, **kwargs)
     except Exception as e:
         logger.error(e)
 
@@ -141,28 +143,19 @@ def gemini_logging(contexts: list):
 async def gemini_nonstream(
     client: Client,
     message: Message,
-    contexts: list[ContentUnionDict],
-    model: str,
     model_name: str,
-    response_modalities: list[str],
-    tools: list | None = None,
+    params: dict,
     retry: int = 0,
     **kwargs,
 ):
     try:
-        api_keys = [x.strip() for x in AIGC.IMG_API_KEY.split(",") if x.strip()]
+        api_keys = [x.strip() for x in GEMINI.API_KEYS.split(",") if x.strip()]
         if retry > len(api_keys) - 1:
-            return
-        app = genai.Client(api_key=api_keys[retry], http_options=HttpOptions(base_url=AIGC.IMG_BASR_URL, async_client_args={"proxy": AIGC.IMG_PROXY}))
-        response = await app.aio.models.generate_content(
-            model=model,
-            contents=contexts,
-            config=GenerateContentConfig(
-                response_modalities=response_modalities,
-                tools=tools,
-            ),
-        )
-        prefix = f"🌠**{model_name}**: ({BOT_TIPS})\n"
+            return None
+        app = genai.Client(api_key=api_keys[retry], http_options=HttpOptions(base_url=GEMINI.BASR_URL, async_client_args={"proxy": GEMINI.PROXY}))
+
+        response = await app.aio.models.generate_content(**params)
+        prefix = f"🤖**{model_name}**: ({BOT_TIPS})\n"
         res = parse_response(response.model_dump(), prefix=prefix)
         await send2tg(client, message, caption_above=True, **res, **kwargs)
         await modify_progress(del_status=True, **kwargs)
@@ -174,7 +167,7 @@ async def gemini_nonstream(
         if "response" in locals():
             error += f"\n{response}"
         await modify_progress(text=error, force_update=True, **kwargs)
-        return await gemini_nonstream(client, message, contexts, model, model_name, response_modalities, tools, retry + 1, **kwargs)  # type: ignore
+        return await gemini_nonstream(client, message, model_name, params, retry + 1, **kwargs)  # type: ignore
 
 
 def parse_response(data: dict, prefix: str = "") -> dict:
@@ -206,27 +199,20 @@ def parse_response(data: dict, prefix: str = "") -> dict:
 async def gemini_stream(
     client: Client,
     message: Message,
-    contexts: list[ContentUnionDict],
-    model: str,
     model_name: str,
-    response_modalities: list[str],
-    tools: list | None = None,
+    params: dict,
     retry: int = 0,
     **kwargs,
 ):
-    prefix = f"🌠**{model_name}**: ({BOT_TIPS})\n"
+    prefix = f"🤖**{model_name}**: ({BOT_TIPS})\n"
     answers = prefix
     try:
         status = kwargs.get("progress")
-        api_keys = [x.strip() for x in AIGC.IMG_API_KEY.split(",") if x.strip()]
+        api_keys = [x.strip() for x in GEMINI.API_KEYS.split(",") if x.strip()]
         if retry > len(api_keys) - 1:
-            return
-        app = genai.Client(api_key=api_keys[retry], http_options=HttpOptions(base_url=AIGC.IMG_BASR_URL, async_client_args={"proxy": AIGC.IMG_PROXY}))
-        async for chunk in await app.aio.models.generate_content_stream(
-            model=model,
-            contents=contexts,
-            config=GenerateContentConfig(response_modalities=response_modalities, tools=tools),
-        ):
+            return None
+        app = genai.Client(api_key=api_keys[retry], http_options=HttpOptions(base_url=GEMINI.BASR_URL, async_client_args={"proxy": GEMINI.PROXY}))
+        async for chunk in await app.aio.models.generate_content_stream(**params):
             resp = parse_response(chunk.model_dump())
             answer = resp.get("texts", "")
             answers += answer
@@ -249,4 +235,4 @@ async def gemini_stream(
         if "resp" in locals():
             error += f"\n{resp}"
         await modify_progress(text=error, force_update=True, **kwargs)
-        return await gemini_stream(client, message, contexts, model, model_name, response_modalities, tools, retry + 1, **kwargs)  # type: ignore
+        return await gemini_stream(client, message, model_name, params, retry + 1, **kwargs)  # type: ignore
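
Note on the gemini.py refactor above: model selection now comes from the GEMINI config (text vs. image modality), and both gemini_nonstream and gemini_stream receive a single params dict that is expanded straight into the google-genai call, while API keys are rotated one per retry. A minimal standalone sketch of that call shape, assuming the google-genai SDK and the GEMINI config class introduced later in this commit (helper names here are hypothetical):

    from google import genai
    from google.genai.types import GenerateContentConfig, HttpOptions

    from config import GEMINI  # config class introduced later in this commit

    def build_params(contexts: list, modality: str = "image") -> dict:
        # Same shape as the refactored gemini_response: model, contents and config in one dict.
        model = GEMINI.TEXT_MODEL if modality == "text" else GEMINI.IMG_MODEL
        modalities = ["TEXT", "IMAGE"] if modality == "image" else ["TEXT"]
        return {"model": model, "contents": contexts, "config": GenerateContentConfig(response_modalities=modalities)}

    async def call_once(contexts: list, retry: int = 0):
        # Key rotation mirrors gemini_nonstream/gemini_stream: one key per retry attempt.
        api_keys = [x.strip() for x in GEMINI.API_KEYS.split(",") if x.strip()]
        if retry > len(api_keys) - 1:
            return None
        app = genai.Client(api_key=api_keys[retry], http_options=HttpOptions(base_url=GEMINI.BASR_URL))
        return await app.aio.models.generate_content(**build_params(contexts))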
src/llm/gpt.py
@@ -5,7 +5,7 @@ from loguru import logger
 from pyrogram.client import Client
 from pyrogram.types import Message
 
-from config import AIGC, GPT, PREFIX, TEXT_LENGTH, cache
+from config import GEMINI, GPT, PREFIX, TEXT_LENGTH, cache
 from llm.contexts import get_conversation_contexts, get_conversations
 from llm.gemini import HELP as AIGC_HELP
 from llm.gemini import gemini_response
@@ -29,7 +29,7 @@ HELP = f"""🤖**GPT对话**
 
 🔄使用以下命令强制切换模型:
 `/gpt`: **{GPT.OPENAI_MODEL_NAME}** {image_emoji(GPT.OPENAI_IMAGE_CAPABILITY)}
-`/gemini`: **{GPT.GEMINI_MODEL_NAME}** {image_emoji(GPT.GEMINI_IMAGE_CAPABILITY)}
+`/gemini`: **{GEMINI.TEXT_MODEL_NAME}** {image_emoji(capability=True)}
 `/ds`: **{GPT.DEEPSEEK_MODEL_NAME}** {image_emoji(GPT.DEEPSEEK_IMAGE_CAPABILITY)}
 `/qwen`: **{GPT.QWEN_MODEL_NAME}** {image_emoji(GPT.QWEN_IMAGE_CAPABILITY)}
 `/doubao`: **{GPT.DOUBAO_MODEL_NAME}** {image_emoji(GPT.DOUBAO_IMAGE_CAPABILITY)}
@@ -55,16 +55,16 @@ def is_gpt_conversation(message: Message) -> bool:
     reply_info = parse_msg(reply_msg, silent=True)
     model_names = [
         GPT.OPENAI_MODEL_NAME,
-        GPT.GEMINI_MODEL_NAME,
         GPT.DEEPSEEK_MODEL_NAME,
         GPT.QWEN_MODEL_NAME,
         GPT.DOUBAO_MODEL_NAME,
         GPT.GROK_MODEL_NAME,
         GPT.TEXT_MODEL_NAME,
         GPT.IMAGE_MODEL_NAME,
-        AIGC.IMG_MODEL_NAME,
+        GEMINI.TEXT_MODEL_NAME,
+        GEMINI.IMG_MODEL_NAME,
     ]
-    return startswith_prefix(reply_info["text"], prefix=[f"🤖{x}".lower() for x in model_names] + [f"🌠{x}".lower() for x in model_names])
+    return startswith_prefix(reply_info["text"], prefix=[f"🤖{x}".lower() for x in model_names])
 
 
 async def gpt_response(client: Client, message: Message, *, gpt_stream: bool = GPT.STREAM_MODE, **kwargs):
@@ -92,7 +92,6 @@ async def gpt_response(client: Client, message: Message, *, gpt_stream: bool = G
         reply_text = reply_info["text"]
 
     force_model, modality = parse_force_model(info["text"], reply_text)
-
     # cache media_group message, only process once
     if media_group_id := message.media_group_id:
         if cache.get(f"gpt-{info['cid']}-{media_group_id}"):
@@ -103,12 +102,11 @@ async def gpt_response(client: Client, message: Message, *, gpt_stream: bool = G
     context_type = get_context_type(conversations)
     contexts = await get_conversation_contexts(client, conversations)
     config = get_gpt_config(context_type["type"], contexts, force_model)
+    if any("gemini" in x.lower() for x in [config["completions"]["model"], config["friendly_name"]]):
+        return await gemini_response(client, message, contexts, modality, **kwargs)
     if not config["client"]["api_key"]:
         logger.error(f"⚠️**{config['friendly_name']}** 未配置API Key")
         return await send2tg(client, message, texts=f"⚠️**{config['friendly_name']}** 未配置API Key, 请尝试其他命令\n\n{HELP}", **kwargs)
-    if "gemini" in config["completions"]["model"].lower():
-        model_name = AIGC.IMG_MODEL_NAME if startswith_prefix(info["text"], prefix=[PREFIX.GENIMG]) else config["friendly_name"]
-        return await gemini_response(client, message, contexts, config["completions"]["model"], model_name, modality, **kwargs)
 
     msg = f"🤖**{config['friendly_name']}**: 思考中...\n{clean_prefix(info['text'])}"
     status_msg = (await send2tg(client, message, texts=msg, **kwargs))[0]
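
Note on the gpt.py change above: Gemini dispatch now happens before the API-key check and matches on either the model id or the friendly display name, so native Gemini requests never hit the OpenAI-compatible path. A hedged restatement of that predicate (the config dict shape follows get_gpt_config in this repo):

    def routes_to_gemini(config: dict) -> bool:
        # Either field mentioning "gemini" hands the request to gemini_response,
        # which manages its own keys via GEMINI.API_KEYS.
        fields = [config["completions"]["model"], config["friendly_name"]]
        return any("gemini" in x.lower() for x in fields)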
src/llm/hooks.py
@@ -1,6 +1,6 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
-from config import GPT
+from config import GEMINI, GPT
 from llm.prompts import modify_prompts, refine_prompts
 from utils import unicode_to_ascii
 
@@ -8,9 +8,8 @@ from utils import unicode_to_ascii
 def pre_hooks(client: dict, completions: dict, message_info: dict | None = None):
     pre_openrouter_hook(client, completions)
     pre_helicone_hook(client, message_info)
-    # Gemini tends to respond in English, even when the user's query is in another language.
-    if GPT.GEMINI_PREFER_LANG and "gemini" in completions["model"].lower():
-        modify_prompts(completions["messages"], prompt=f"请使用{GPT.GEMINI_PREFER_LANG}回复。", role="system", method="append")
+    if GEMINI.PREFER_LANG and "gemini" in completions["model"].lower():
+        modify_prompts(completions["messages"], prompt=f"请使用{GEMINI.PREFER_LANG}回复。", role="system", method="append")
     completions["messages"] = refine_prompts(completions["messages"])
 
 
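
Note on hooks.py above: GEMINI.PREFER_LANG now drives both paths, the system prompt appended here for OpenAI-compatible requests and the system_instruction built in gemini_response for native google-genai requests. A minimal sketch of the hook's effect (the sample completions dict is hypothetical; modify_prompts presumably appends a system message):

    from config import GEMINI
    from llm.prompts import modify_prompts

    completions = {"model": "gemini-2.0-flash", "messages": [{"role": "user", "content": "hello"}]}
    if GEMINI.PREFER_LANG and "gemini" in completions["model"].lower():
        # ask the model to reply in the preferred language
        modify_prompts(completions["messages"], prompt=f"请使用{GEMINI.PREFER_LANG}回复。", role="system", method="append")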
src/llm/models.py
@@ -4,7 +4,7 @@
 from openai import DefaultAsyncHttpxClient
 from pyrogram.types import Message
 
-from config import AIGC, GPT, PREFIX, PROXY
+from config import GEMINI, GPT, PREFIX, PROXY
 from messages.parser import parse_msg
 from messages.utils import startswith_prefix
 
@@ -37,8 +37,6 @@ def parse_force_model(text: str, reply_text: str) -> tuple[str, str]:
     # parse from bot reply
     if reply_text.startswith(f"🤖{GPT.OPENAI_MODEL_NAME}"):
         force_model = GPT.OPENAI_MODEL
-    elif reply_text.startswith(f"🤖{GPT.GEMINI_MODEL_NAME}"):
-        force_model = GPT.GEMINI_MODEL
     elif reply_text.startswith(f"🤖{GPT.DEEPSEEK_MODEL_NAME}"):
         force_model = GPT.DEEPSEEK_MODEL
     elif reply_text.startswith(f"🤖{GPT.QWEN_MODEL_NAME}"):
@@ -47,14 +45,12 @@ def parse_force_model(text: str, reply_text: str) -> tuple[str, str]:
         force_model = GPT.DOUBAO_MODEL
     elif reply_text.startswith(f"🤖{GPT.GROK_MODEL_NAME}"):
         force_model = GPT.GROK_MODEL
-    elif reply_text.startswith(f"🌠{AIGC.IMG_MODEL_NAME}"):
-        force_model = AIGC.IMG_MODEL
+    elif reply_text.startswith(f"🤖{GEMINI.IMG_MODEL_NAME}"):
+        force_model = GEMINI.IMG_MODEL
         modality = "image"
     # parse from command prefix
     if startswith_prefix(text, prefix=["/gpt"]):
         force_model = GPT.OPENAI_MODEL
-    elif startswith_prefix(text, prefix=["/gemini"]):
-        force_model = GPT.GEMINI_MODEL
     elif startswith_prefix(text, prefix=["/ds"]):
         force_model = GPT.DEEPSEEK_MODEL
     elif startswith_prefix(text, prefix=["/qwen"]):
@@ -64,8 +60,11 @@ def parse_force_model(text: str, reply_text: str) -> tuple[str, str]:
     elif startswith_prefix(text, prefix=["/grok"]):
         force_model = GPT.GROK_MODEL
     elif startswith_prefix(text, prefix=[PREFIX.GENIMG]):
-        force_model = AIGC.IMG_MODEL
+        force_model = GEMINI.IMG_MODEL
         modality = "image"
+    elif startswith_prefix(text, prefix=["/gemini"]):
+        force_model = GEMINI.TEXT_MODEL
+        modality = "text"
     return force_model, modality
 
 
@@ -103,7 +102,6 @@ def get_gpt_config(model_type: str, contexts: list[dict], force_model: str = "")
     # align with force model
     model_factory = {
         GPT.OPENAI_MODEL: {"api_key": GPT.OPENAI_API_KEY, "base_url": GPT.OPENAI_BASE_URL, "model_name": GPT.OPENAI_MODEL_NAME},
-        GPT.GEMINI_MODEL: {"api_key": GPT.GEMINI_API_KEY, "base_url": GPT.GEMINI_BASE_URL, "model_name": GPT.GEMINI_MODEL_NAME},
         GPT.DEEPSEEK_MODEL: {"api_key": GPT.DEEPSEEK_API_KEY, "base_url": GPT.DEEPSEEK_BASE_URL, "model_name": GPT.DEEPSEEK_MODEL_NAME},
         GPT.QWEN_MODEL: {"api_key": GPT.QWEN_API_KEY, "base_url": GPT.QWEN_BASE_URL, "model_name": GPT.QWEN_MODEL_NAME},
         GPT.DOUBAO_MODEL: {"api_key": GPT.DOUBAO_API_KEY, "base_url": GPT.DOUBAO_BASE_URL, "model_name": GPT.DOUBAO_MODEL_NAME},
@@ -119,7 +117,6 @@ def get_gpt_config(model_type: str, contexts: list[dict], force_model: str = "")
         model_type == "image"  # check capabilities
         and (
             (force_model == GPT.OPENAI_MODEL and GPT.OPENAI_IMAGE_CAPABILITY)
-            or (force_model == GPT.GEMINI_MODEL and GPT.GEMINI_IMAGE_CAPABILITY)
             or (force_model == GPT.DEEPSEEK_MODEL and GPT.DEEPSEEK_IMAGE_CAPABILITY)
             or (force_model == GPT.QWEN_MODEL and GPT.QWEN_IMAGE_CAPABILITY)
             or (force_model == GPT.DOUBAO_MODEL and GPT.DOUBAO_IMAGE_CAPABILITY)
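
Note on parse_force_model above: the prefix checks are an if/elif chain, so PREFIX.GENIMG is tested before /gemini, and /gemini now maps to the native GEMINI.TEXT_MODEL with text modality. An illustrative check of the new routing (hypothetical, assuming PREFIX.GENIMG does not overlap the other command prefixes and the reply text is empty):

    from config import GEMINI, PREFIX
    from llm.models import parse_force_model

    assert parse_force_model("/gemini hello", "") == (GEMINI.TEXT_MODEL, "text")
    assert parse_force_model(f"{PREFIX.GENIMG} a cat on the moon", "") == (GEMINI.IMG_MODEL, "image")
    assert parse_force_model("again", f"🤖{GEMINI.IMG_MODEL_NAME}: ...") == (GEMINI.IMG_MODEL, "image")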
src/config.py
@@ -142,7 +142,7 @@ class COOKIE:  # See: https://github.com/easychen/CookieCloud
 
 
 class GPT:  # see `llm/README.md`
-    # See class AIGC for the AIGC configurations
+    # See class GEMINI for the GEMINI configurations
     STREAM_MODE = os.getenv("GPT_STREAM_MODE", "1").lower() in ["1", "y", "yes", "t", "true", "on"]
     TEXT_MODEL = os.getenv("GPT_TEXT_MODEL", "gpt-4o")
     IMAGE_MODEL = os.getenv("GPT_IMAGE_MODEL", "gpt-4o")
@@ -173,15 +173,6 @@ class GPT:  # see `llm/README.md`
     MAX_RETRY = int(os.getenv("GPT_MAX_RETRY", "2"))
     HELICONE_API_KEY = os.getenv("HELICONE_API_KEY", "")
 
-    # comma separated reasoning models, add system prompt to the models to ensure the output format.
-    REASONING_MODELS = os.getenv("GPT_REASONING_MODELS", "")  # deprecated, we do not need this anymore
-    # /gemini command
-    GEMINI_MODEL = os.getenv("GPT_GEMINI_MODEL", "gemini-2.0-flash")
-    GEMINI_MODEL_NAME = os.getenv("GPT_GEMINI_MODEL_NAME", "Gemini-2.0-Flash")
-    GEMINI_API_KEY = os.getenv("GPT_GEMINI_API_KEY", "")
-    GEMINI_BASE_URL = os.getenv("GPT_GEMINI_BASE_URL", "https://generativelanguage.googleapis.com/v1beta/openai")
-    GEMINI_IMAGE_CAPABILITY = os.getenv("GPT_GEMINI_IMAGE_CAPABILITY", "1").lower() in ["1", "y", "yes", "t", "true", "on"]
-    GEMINI_PREFER_LANG = os.getenv("GPT_GEMINI_PREFER_LANG", "")  # Set a prefer response language for Gemini
     # /gpt command
     OPENAI_MODEL = os.getenv("GPT_OPENAI_MODEL", "gpt-4o")
     OPENAI_MODEL_NAME = os.getenv("GPT_OPENAI_MODEL_NAME", "GPT-4o")
@@ -257,11 +248,17 @@ class ASR:
     TENCENT_SECRET_KEY = os.getenv("ASR_TENCENT_SECRET_KEY", "")
 
 
-class AIGC:
+class GEMINI:  # Official Gemini
     # https://ai.google.dev/gemini-api/docs/image-generation
-    IMG_BASR_URL = os.getenv("AIGC_IMG_BASR_URL", "https://generativelanguage.googleapis.com/")
-    IMG_API_KEY = os.getenv("AIGC_IMG_API_KEY", "")  # comma separated keys for load balance. e.g. "key1,key2,key3"
-    IMG_MODEL = os.getenv("AIGC_IMG_MODEL", "gemini-2.0-flash-exp")
-    IMG_MODEL_NAME = os.getenv("AIGC_IMG_MODEL_NAME", "Gemini-2.0-Flash")
-    IMG_PROXY = os.getenv("AIGC_IMG_PROXY", None)
-    IMG_MAX_PROMPT_TOKEN = int(os.getenv("AIGC_IMG_MAX_PROMPT_TOKEN", "480"))
+    BASR_URL = os.getenv("GEMINI_BASR_URL", "https://generativelanguage.googleapis.com/")
+    API_KEYS = os.getenv("GEMINI_API_KEYS", "")  # comma separated keys for load balance. e.g. "key1,key2,key3"
+    PROXY = os.getenv("GEMINI_PROXY", None)
+    PREFER_LANG = os.getenv("GEMINI_PREFER_LANG", "")  # Set a prefer response language for Gemini
+
+    # response modality: text
+    TEXT_MODEL = os.getenv("GEMINI_TEXT_MODEL", "gemini-2.5-pro-exp-03-25")
+    TEXT_MODEL_NAME = os.getenv("GEMINI_TEXT_MODEL_NAME", "Gemini-2.5-Pro")
+
+    # response modality: image
+    IMG_MODEL = os.getenv("GEMINI_IMG_MODEL", "gemini-2.0-flash-exp")
+    IMG_MODEL_NAME = os.getenv("GEMINI_IMG_MODEL_NAME", "Gemini-2.0-Flash")
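
Migration note for the config.py change above: the GPT_GEMINI_* and AIGC_IMG_* variables are no longer read, and the IMG_MAX_PROMPT_TOKEN limit (with its token-count guard in gemini.py) is dropped entirely; the new GEMINI class reads GEMINI_*-prefixed variables instead. A sketch of the replacement environment, shown as os.environ for illustration (values are placeholders or the defaults visible above):

    import os

    os.environ["GEMINI_API_KEYS"] = "key1,key2"                      # comma separated, replaces AIGC_IMG_API_KEY
    os.environ["GEMINI_BASR_URL"] = "https://generativelanguage.googleapis.com/"
    os.environ["GEMINI_TEXT_MODEL"] = "gemini-2.5-pro-exp-03-25"     # replaces GPT_GEMINI_MODEL for /gemini
    os.environ["GEMINI_IMG_MODEL"] = "gemini-2.0-flash-exp"          # replaces AIGC_IMG_MODEL
    os.environ["GEMINI_PREFER_LANG"] = "中文"                        # optional, replaces GPT_GEMINI_PREFER_LANG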
src/handler.py
@@ -306,7 +306,7 @@ def get_social_media_help(chat_id: int | str, ctype: str, prefixes: list[str] |
     if permission["ai"]:
         msg += f"\n🤖**AI对话**: `{PREFIX.GPT} /gpt /gemini /ds /qwen /doubao /grok`"
         msg += f"\n📖**AI总结**: `{PREFIX.AI_SUMMARY}` 总结历史聊天记录"
-        msg += f"\n🌠**AIGC**: `{PREFIX.GENIMG}`"
+        msg += f"\n🌠**AI生图**: `{PREFIX.GENIMG}`"
     if permission["asr"]:
         msg += f"\n🗣**语音转文字**: `{PREFIX.ASR}` 回复语音消息"
     if permission["audio"]: