Commit e2a026a

benny-dou <60535774+benny-dou@users.noreply.github.com>
2026-01-19 14:54:24
feat(ai): support multiple model configs for AI text generation
1 parent 2c201f7
src/ai/texts/gemini.py
@@ -30,7 +30,7 @@ async def gemini_chat_completion(
     gemini_base_url: str = AI.GEMINI_BASE_URL,
     gemini_api_keys: str = AI.GEMINI_API_KEYS,
     gemini_default_headers: str | dict = AI.GEMINI_DEFAULT_HEADERS,
-    gemini_generate_content_config: str | dict = AI.GEMINI_GENERATE_CONTENT_CONFIG,
+    gemini_generate_content_config: str | dict = "",
     gemini_proxy: str | None = PROXY.GOOGLE,
     gemini_append_grounding: bool = True,
     silent: bool = False,
src/ai/texts/models.py
@@ -2,7 +2,6 @@
 # -*- coding: utf-8 -*-
 import re
 
-from glom import glom
 from loguru import logger
 from pyrogram.types import Message
 
@@ -13,87 +12,88 @@ from messages.utils import startswith_prefix
 
 
 # ruff: noqa: RUF002
-async def get_text_model_config(message: Message) -> dict:
+async def get_text_model_configs(message: Message) -> list[dict]:
     r"""Get model config based on the message.
 
     Model config is retrieved from CF-KV with key: {AI.TEXT_MODEL_CONFIG_KEY}
 
     A sample config:
     {
-    "docs": "🤖AI对话: `/ai` + 提示词\n回复消息可将其加入历史上下文\n🔄使用以下命令强制切换模型:\n/gpt: GPT-5.2\n/gemini: Gemini-2.5-Flash\n/grok: Grok-4\n/claude: Claude-Opus-4.5\n/doubao: Doubao-Seed-1.8\n/ds: DeepSeek-R1\n/qwen: Qwen3-Max\n/kimi: Kimi-K2\n/glm: GLM-4.7\n/mimo: MiMo-V2-Flash",
-    "default": {
-        "model_id": "gemini-2.5-flash",
-        "model_name": "Gemini-2.5-Flash",
-        "api_type": "gemini",
-        "gemini_base_url": "https://generativelanguage.googleapis.com",
-        "gemini_api_keys": "key1,key2,key3,...",
-        "gemini_generate_content_config": {
-            "max_output_tokens": 65536,
-            "media_resolution": "MEDIA_RESOLUTION_HIGH",
-            "thinking_config": {"include_thoughts": true, "thinking_budget": 24576},
-            "tools":[{"google_search": {}}, {"url_context": {}}]}
-        }
-    },
-    "gpt": {
-        "model_id": "gpt-4o",
-        "model_name": "GPT-4o",
-        "api_type": "openai_chat",
-        "openai_base_url": "https://api.openai.com/v1",
-        "openai_api_keys": "key1,key2,key3",
-        "openai_completions_config": {
-            "temperature": 1.0,
-            "max_completion_tokens": 4096
-        }
-    },
-    "gpt-helicone": {
-        "model_id": "gpt-4o",
-        "model_name": "GPT-4o",
-        "api_type": "openai_chat",
-        "openai_base_url": "https://gateway.helicone.ai/v1",
-        "openai_api_keys": "key1,key2,key3,...",
-        "openai_default_headers": {
-            "helicone-auth": "Bearer HELICONE_API_KEY",
-            "helicone-target-url": "https://api.openai.com"
+    "docs": "🤖AI对话: `/ai` + 提示词\n回复消息可将其加入历史上下文\n默认使用**Gemini-2.5-Flash**模型\n\n🔄使用以下命令强制切换模型:\n/gpt: GPT-5.2\n/gemini: Gemini-2.5-Flash\n/g3: Gemini-3-Flash (不支持网络搜索)\n/grok: Grok-4\n/claude: Claude-Opus-4.5\n/doubao: Doubao-Seed-1.8\n/ds: DeepSeek-R1\n/qwen: Qwen3-Max\n/kimi: Kimi-K2\n/glm: GLM-4.7\n/mimo: MiMo-V2-Flash",
+    "gemini": {
+        "common_config": {
+            "api_type": "gemini",
+            "gemini_base_url": "https://generativelanguage.googleapis.com",
+            "gemini_api_keys": "key1,key2,key3...",
         },
-        "openai_completions_config": {
-            "temperature": 1.0,
-            "max_completion_tokens": 4096
-        }
+        "models": [
+            {
+                "model_id": "gemini-3-flash-preview",
+                "model_name": "Gemini-3-Flash",
+                "gemini_generate_content_config": {
+                    "max_output_tokens": 65536,
+                    "thinking_config": {"include_thoughts": true, "thinking_level": "high"},
+                    "tools": [{"url_context": {}}, {"code_execution": {}}]
+                }
+            },
+            {
+                "model_name": "Gemini-2.5-Flash",
+                "model_id": "gemini-2.5-flash-preview-09-2025",
+                "gemini_generate_content_config": {
+                    "max_output_tokens": 65536,
+                    "thinking_config": {"include_thoughts": true, "thinking_budget": 24576},
+                    "tools": [{"google_search": {}}, {"url_context": {}}, {"code_execution": {}}]
+                }
+            }
+        ]
     },
-    "doubao": {
-        "model_id": "doubao-seed-1-8-251228",
-        "model_name": "Doubao-Seed-1.8",
-        "api_type": "openai_responses",
-        "cache_response_ttl": 604800,
-        "openai_base_url": "https://ark.cn-beijing.volces.com/api/v3",
-        "openai_api_keys": "key1,key2,key3,...",
-        "openai_responses_config": {
-        "reasoning": { "effort": "high" },
-        "max_output_tokens": 65536,
-        "extra_body": {
-            "thinking": { "type": "enabled" }
+    "gpt": {
+        "common_config": {
+            "api_type": "gemini",
+            "gemini_base_url": "https://generativelanguage.googleapis.com",
+            "gemini_api_keys": "key1,key2,key3...",
         },
-        "tools": [
+        "models": [
             {
-            "type": "web_search",
-            "max_keyword": 5,
-            "limit": 20
+                "model_id": "gpt-4o",
+                "model_name": "GPT-4o",
+                "api_type": "openai_chat",
+                "openai_base_url": "https://api.openai.com/v1",
+                "openai_api_keys": "key1,key2,key3...",
+                "openai_completions_config": {
+                    "temperature": 1.0,
+                    "max_completion_tokens": 4096
+                }
+            },
+            {
+                "model_id": "gpt-5.2",
+                "model_name": "GPT-5.2",
+                "api_type": "openai_responses",
+                "cache_response_ttl": 86400,
+                "openai_base_url": "https://gateway.helicone.ai/v1",
+                "openai_api_keys": "key1,key2,key3,...",
+                "openai_default_headers": {
+                    "helicone-auth": "Bearer HELICONE_API_KEY",
+                    "helicone-target-url": "https://api.openai.com"
+                },
+                "openai_responses_config": {
+                    "reasoning": { "effort": "high" },
+                    "max_output_tokens": 4096,
+                    "tools": [ { "type": "web_search_preview","search_context_size": "high" } ]
+                }
             }
-        ],
-        "max_tool_calls": 10
-        }
-    },
-    "tool_call_model": {
-        "model_id": "gpt-4o-mini",
-        "model_name": "Web Search",
-        "api_type": "openai_chat"
-        "openai_base_url": "https://api.openai.com/v1",
-        "openai_api_keys": "key1,key2,key3",
-        "openai_completions_config": {
-            "temperature": 1.0,
-            "max_completion_tokens": 4096
-        }
+        ]
     },
+    "tool_call_model": {
+        "models": [
+            {
+                "model_id": "gpt-4o-mini",
+                "model_name": "Web Search",
+                "api_type": "openai_chat",
+                "openai_base_url": "https://api.openai.com/v1",
+                "openai_api_keys": "key1,key2,key3"
+            }
+        ]
+    }
     }
 
 
@@ -106,7 +106,7 @@ async def get_text_model_config(message: Message) -> dict:
         Message(text="🤖GPT-4o:(回复以继续)\nHello") -> find the model_alias via model_name=`GPT-4o`
 
     Returns:
-        {
+        [{
             "model_id": "gpt-4o",
             "model_name": "GPT-4o",
             "openai_api_type": "chat",
@@ -116,12 +116,12 @@ async def get_text_model_config(message: Message) -> dict:
             "openai_completions_config": {},
             "openai_responses_config": {},
             ...  # other fields will also be passed to the function
-        }
+        }]
     """
     texts = str(message.content).strip()
     if texts.startswith(EMOJI_TEXT_BOT) and BOT_TIPS in texts:
         # DO NOT respond to AI responses to avoid a potential infinite loop
-        return {}
+        return []
 
     # this message starts with /ai
     if startswith_prefix(message.content, PREFIX.AI_TEXT_GENERATION):
@@ -129,66 +129,73 @@ async def get_text_model_config(message: Message) -> dict:
         prompt = re.sub(r"^@([a-zA-Z0-9_\-\.]+)(\s+)?", "", prompt, flags=re.DOTALL).strip()
         if not prompt and not message.reply_to_message:  # no prompt & no reply_msg
             await message.reply_text(text=await text_generation_docs(), quote=True)
-            return {}
+            return []
         if matched := re.match(rf"^{PREFIX.AI_TEXT_GENERATION}\s+@([a-zA-Z0-9_\-\.]+)(\s+)?", texts):  # match /ai @custom_model_id
-            model_id = matched.group(1).strip()
-            return await get_config_by_model_id(model_id)
-        return await get_config_by_model_id("default")
+            model_alias = matched.group(1).strip()
+            return await get_config_by_model_alias(model_alias)
+        return await get_config_by_model_alias(AI.TEXT_GENERATION_DEFAULT_MODEL)
 
     # this message is not /ai, try to find model id from reply_message
     reply_msg = message.reply_to_message
     if not isinstance(reply_msg, Message):
-        return {}
+        return []
 
     if matched := re.match(rf"^{EMOJI_TEXT_BOT}(.*?):{BOT_TIPS}", str(reply_msg.content)):
         model_name = matched.group(1).strip()
         return await get_config_by_model_name(model_name)
-    return {}
+    return []
 
 
-async def get_config_by_model_id(model_id: str, *, fallback_to_default: bool = True) -> dict:
-    """Get model config by model_id.
+async def get_config_by_model_alias(model_alias: str, *, fallback_to_default: bool = True) -> list[dict]:
+    """Get model config by model_alias.
 
     Returns:
         model_config
     """
     kv = await get_cf_kv(AI.TEXT_MODEL_CONFIG_KEY, cache_ttl=600, silent=True)
-    default_config = kv.get("default", {})
-    if not default_config:
-        logger.warning(f"CF-KV key `{AI.TEXT_MODEL_CONFIG_KEY}` does not has `default` field")
-        default_config = (
-            {
-                "model_id": AI.GEMINI_MODEL_ID,
-                "model_name": AI.GEMINI_MODEL_ID,
-                "api_type": "gemini",
-                "gemini_base_url": AI.GEMINI_BASE_URL,
-                "gemini_api_keys": AI.GEMINI_API_KEYS,
-            }
-            if AI.TEXT_DEFAULT_PROVIDER == "gemini"
-            else {
-                "model_id": AI.OPENAI_MODEL_ID,
-                "model_name": AI.OPENAI_MODEL_ID,
-                "openai_api_type": "chat",
-                "openai_base_url": AI.OPENAI_BASE_URL,
-                "openai_api_keys": AI.OPENAI_API_KEYS,
-            }
-        )
-    custom_config = kv.get(model_id, {})
-    if not custom_config:
-        if fallback_to_default:
-            logger.warning(f"Model `{model_id}` is not configured in KV, using default config")
-            return default_config
-        return {}
-    return default_config | custom_config
+
+    if config := kv.get(model_alias, {}):
+        common_config = config.get("common_config", {})
+        return [common_config | model_config for model_config in config.get("models", [])]
+
+    if not fallback_to_default:
+        return []
+
+    logger.warning(f"Model Alias `{model_alias}` is not configured in KV, fallback to default config")
+    return [
+        {
+            "model_id": AI.GEMINI_MODEL_ID,
+            "model_name": AI.GEMINI_MODEL_ID,
+            "api_type": "gemini",
+            "gemini_base_url": AI.GEMINI_BASE_URL,
+            "gemini_api_keys": AI.GEMINI_API_KEYS,
+        },
+        {
+            "model_id": AI.OPENAI_MODEL_ID,
+            "model_name": AI.OPENAI_MODEL_ID,
+            "api_type": "openai_chat",
+            "openai_base_url": AI.OPENAI_BASE_URL,
+            "openai_api_keys": AI.OPENAI_API_KEYS,
+        },
+    ]
 
 
-async def get_config_by_model_name(model_name: str) -> dict:
+async def get_config_by_model_name(model_name: str) -> list[dict]:
     """Get model config by model_name.
 
     Returns:
         model_config
     """
     kv = await get_cf_kv(AI.TEXT_MODEL_CONFIG_KEY, cache_ttl=600, silent=True)
-    model_names = {glom(v, "model_name", default=""): k for k, v in kv.items()}
-    model_alias = model_names.get(model_name, model_name)
-    return await get_config_by_model_id(model_alias)
+    model_configs = []
+    for alias, config in kv.items():
+        if not isinstance(config, dict):
+            continue
+        if alias in {AI.TOOL_CALL_MODEL_ALIAS, AI.PODCAST_SUMMARY_MODEL_ALIAS, AI.CHAT_SUMMARY_MODEL_ALIAS, AI.SUBTITLE_SUMMARY_MODEL_ALIAS}:
+            continue
+        common_config = config.get("common_config", {})
+        for model in config.get("models", []):
+            model_config = common_config | model
+            if model_config.get("model_name", "") == model_name:
+                model_configs.append(model_config)
+    return model_configs
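
Each alias in the new schema pairs a `common_config` with a list of `models`, and `get_config_by_model_alias` flattens that with a dict union, so per-model keys override shared ones. A minimal sketch of those merge semantics, with values borrowed from the sample "gemini" entry above (illustrative only, not part of this commit):

# Dict union (Python 3.9+): right-hand keys win, so a model entry can
# override any shared field, e.g. the per-model api_type in the "gpt" sample.
common_config = {
    "api_type": "gemini",
    "gemini_base_url": "https://generativelanguage.googleapis.com",
    "gemini_api_keys": "key1,key2,key3...",
}
models = [
    {"model_id": "gemini-3-flash-preview", "model_name": "Gemini-3-Flash"},
    {"model_id": "gemini-2.5-flash-preview-09-2025", "model_name": "Gemini-2.5-Flash"},
]
merged = [common_config | model for model in models]
assert merged[0]["model_id"] == "gemini-3-flash-preview"
assert all(m["api_type"] == "gemini" for m in merged)

The list order matters: it is the order in which `ai_text_generation` will try the models.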
src/ai/texts/openai_chat.py
@@ -27,9 +27,9 @@ async def openai_chat_completions(
     model_name: str = AI.OPENAI_MODEL_ID,
     openai_base_url: str = AI.OPENAI_BASE_URL,
     openai_api_keys: str = AI.OPENAI_API_KEYS,
-    openai_client_config: str | dict = AI.OPENAI_CLIENT_CONFIG,
-    openai_default_headers: str | dict = AI.OPENAI_DEFAULT_HEADERS,
-    openai_completions_config: str | dict = AI.OPENAI_COMPLETIONS_CONFIG,
+    openai_client_config: str | dict = "",
+    openai_default_headers: str | dict = "",
+    openai_completions_config: str | dict = "",
     openai_proxy: str | None = PROXY.OPENAI,
     openai_system_prompt: str = "",
     openai_contexts: list[dict] | None = None,
src/ai/texts/openai_response.py
@@ -30,9 +30,9 @@ async def openai_responses_api(
     model_name: str = AI.OPENAI_MODEL_ID,
     openai_base_url: str = AI.OPENAI_BASE_URL,
     openai_api_keys: str = AI.OPENAI_API_KEYS,
-    openai_client_config: str | dict = AI.OPENAI_CLIENT_CONFIG,
-    openai_default_headers: str | dict = AI.OPENAI_DEFAULT_HEADERS,
-    openai_responses_config: str | dict = AI.OPENAI_RESPONSES_CONFIG,
+    openai_client_config: str | dict = "",
+    openai_default_headers: str | dict = "",
+    openai_responses_config: str | dict = "",
     openai_proxy: str | None = PROXY.OPENAI,
     cache_response_ttl: int = 0,
     silent: bool = False,
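
All three completion wrappers (`gemini_chat_completion`, `openai_chat_completions`, `openai_responses_api`) type these parameters as `str | dict`, and their defaults move from env-supplied JSON strings to `""` now that per-model config arrives via CF-KV. A hedged sketch of the coercion such a parameter presumably passes through (`coerce_config` is a hypothetical name; the real helper is not shown in this diff):

import json

def coerce_config(value: str | dict) -> dict:
    # Assumption: dicts pass through, JSON strings (the old '{}' env
    # defaults) are parsed, and the new "" default means "no extra config".
    if isinstance(value, dict):
        return value
    return json.loads(value) if value.strip() else {}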
src/ai/main.py
@@ -9,7 +9,7 @@ from ai.images.models import get_image_model_configs
 from ai.images.openai_img import openai_image_generation
 from ai.images.post import http_post_image_generation
 from ai.texts.gemini import gemini_chat_completion
-from ai.texts.models import get_config_by_model_id, get_text_model_config
+from ai.texts.models import get_config_by_model_alias, get_text_model_configs
 from ai.texts.openai_chat import openai_chat_completions
 from ai.texts.openai_response import openai_responses_api
 from ai.texts.tool_call import get_tool_call_results
@@ -21,30 +21,35 @@ from messages.sender import send2tg
 from messages.utils import startswith_prefix
 
 
-async def ai_text_generation(client: Client, message: Message, *, silent: bool = False, **kwargs) -> dict:
+async def ai_text_generation(client: Client, message: Message, **kwargs) -> dict:
     texts = str(message.content).strip()
     this_msg = message
     prompt = texts.removeprefix(PREFIX.AI_TEXT_GENERATION).strip()
     prompt = re.sub(r"^@([a-zA-Z0-9_\-\.]+)(\s+)?", "", prompt, flags=re.DOTALL).strip()
     if not prompt and message.reply_to_message:
         message = this_msg.reply_to_message
-    model_config = await get_text_model_config(this_msg)
-    if not model_config.get("model_id"):
+    model_configs = await get_text_model_configs(this_msg)
+    if not model_configs:
         return {}
-    silent = silent or model_config.get("silent", False)
-    params: dict = {"api_type": AI.TEXT_DEFAULT_PROVIDER} | model_config | kwargs | {"silent": silent}
-    if params["api_type"] == "gemini":
-        return await gemini_chat_completion(client, message, **params)
-    if params["api_type"] == "openai_responses":
-        return await openai_responses_api(client, message, **params)
-    if params["api_type"] == "openai_chat":
-        if params.get("openai_enable_tool_call", True):
-            tool_config = await get_config_by_model_id("tool_call_model", fallback_to_default=False)
-            if tool_config:
-                tool_params = params | tool_config
-                tool_results = await get_tool_call_results(client, message, **tool_params)
-                params |= tool_results
-        return await openai_chat_completions(client, message, **params)
+    for model_config in model_configs:
+        params: dict = model_config | kwargs
+        match model_config["api_type"]:
+            case "gemini":
+                if res := await gemini_chat_completion(client, message, **params):
+                    return res
+            case "openai_responses":
+                if res := await openai_responses_api(client, message, **params):
+                    return res
+            case "openai_chat":
+                if params.get("openai_enable_tool_call", True):
+                    tool_configs = await get_config_by_model_alias(AI.TOOL_CALL_MODEL_ALIAS, fallback_to_default=False)
+                    for tool_config in tool_configs:
+                        tool_params = params | tool_config
+                        if tool_results := await get_tool_call_results(client, message, **tool_params):
+                            params |= tool_results
+                            break
+                if res := await openai_chat_completions(client, message, **params):
+                    return res
     return {}
 
 
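The rewritten dispatch turns the config list into provider failover: merged configs are tried in KV order and the first call returning a non-empty dict wins, so with the sample `gemini` alias, Gemini-3-Flash runs first and Gemini-2.5-Flash only on an empty result. A stripped-down sketch of that pattern (the hypothetical `call_provider` stands in for the three real API wrappers):

from collections.abc import Awaitable, Callable

async def first_success(
    configs: list[dict],
    call_provider: Callable[..., Awaitable[dict]],
) -> dict:
    # Try each merged model config in order; an empty result means
    # "this model failed or was skipped", so fall through to the next.
    for config in configs:
        if res := await call_provider(**config):
            return res
    return {}
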
src/config.py
@@ -381,20 +381,16 @@ class AI:
     # Text Generation
     MAX_CONTEXTS_NUM = int(os.getenv("AI_MAX_CONTEXTS_NUM", "30"))
     TEXT_MODEL_CONFIG_KEY = os.getenv("AI_MODEL_CONFIG_KEY", "AI-TEXT")  # model configuration key in CF-KV
-    TEXT_DEFAULT_PROVIDER = os.getenv("AI_TEXT_DEFAULT_PROVIDER", "gemini")
-    OPENAI_MODEL_ID = os.getenv("AI_OPENAI_MODEL_ID", "gpt-4o")
-    OPENAI_API_KEYS = os.getenv("AI_OPENAI_API_KEYS", "")  # comma separated keys for load balance. e.g. "key1,key2,key3"
-    OPENAI_BASE_URL = os.getenv("AI_OPENAI_BASE_URL", "https://api.openai.com/v1")
-    OPENAI_CLIENT_CONFIG = os.getenv("AI_OPENAI_CLIENT_CONFIG", "{}")  # client config passed to OpenAI API. Should be a json string: '{"key": "value"}'
-    OPENAI_DEFAULT_HEADERS = os.getenv("AI_OPENAI_DEFAULT_HEADERS", "{}")  # default headers passed to OpenAI API. Should be a json string: '{"key": "value"}'
-    OPENAI_COMPLETIONS_CONFIG = os.getenv("AI_OPENAI_COMPLETIONS_CONFIG", "{}")  # chat completions config. Should be a json string: '{"key": "value"}'
-    OPENAI_RESPONSES_CONFIG = os.getenv("AI_OPENAI_RESPONSES_CONFIG", "{}")  # response api config. Should be a json string: '{"key": "value"}'
+    TEXT_GENERATION_DEFAULT_MODEL = os.getenv("AI_TEXT_GENERATION_DEFAULT_MODEL", "gemini")
     GEMINI_MODEL_ID = os.getenv("AI_GEMINI_MODEL_ID", "gemini-2.5-flash")
     GEMINI_API_KEYS = os.getenv("AI_GEMINI_API_KEYS", "")  # comma separated keys for load balance. e.g. "key1,key2,key3"
     GEMINI_BASE_URL = os.getenv("AI_GEMINI_BASE_URL", "https://generativelanguage.googleapis.com")
     GEMINI_DEFAULT_HEADERS = os.getenv("AI_GEMINI_DEFAULT_HEADERS", "{}")  # default headers passed to Gemini API. Should be a json string: '{"key": "value"}'
-    GEMINI_GENERATE_CONTENT_CONFIG = os.getenv("AI_GEMINI_GENERATE_CONTENT_CONFIG", "{}")  # gemini generate_content config. Should be a json string: '{"key": "value"}'
     GEMINI_FILES_TTL = int(os.getenv("AI_GEMINI_FILES_TTL", "172800"))  # clean gemini files after 48 hours
+    OPENAI_MODEL_ID = os.getenv("AI_OPENAI_MODEL_ID", "gpt-4o")
+    OPENAI_API_KEYS = os.getenv("AI_OPENAI_API_KEYS", "")  # comma separated keys for load balance. e.g. "key1,key2,key3"
+    OPENAI_BASE_URL = os.getenv("AI_OPENAI_BASE_URL", "https://api.openai.com/v1")
+    TOOL_CALL_MODEL_ALIAS = os.getenv("AI_TOOL_CALL_MODEL_ALIAS", "tool-call")
     PODCAST_SUMMARY_MODEL_ALIAS = os.getenv("PODCAST_SUMMARY_MODEL_ALIAS", "podcast-summary")
     SUBTITLE_SUMMARY_MODEL_ALIAS = os.getenv("SUBTITLE_SUMMARY_MODEL_ALIAS", "subtitle-summary")
     CHAT_SUMMARY_MODEL_ALIAS = os.getenv("CHAT_SUMMARY_MODEL_ALIAS", "chat-summary")
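
Since the remaining knobs are plain environment variables, switching the default text model is now a deployment-time change; an illustrative override (the `gpt` alias comes from the sample config above):

import os

# Illustrative only: serve the "gpt" alias by default instead of "gemini".
# config.py reads this once at import time, before the AI class is defined.
os.environ["AI_TEXT_GENERATION_DEFAULT_MODEL"] = "gpt"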