Commit e2a026a
Changed files (6)
src/ai/texts/gemini.py
@@ -30,7 +30,7 @@ async def gemini_chat_completion(
gemini_base_url: str = AI.GEMINI_BASE_URL,
gemini_api_keys: str = AI.GEMINI_API_KEYS,
gemini_default_headers: str | dict = AI.GEMINI_DEFAULT_HEADERS,
- gemini_generate_content_config: str | dict = AI.GEMINI_GENERATE_CONTENT_CONFIG,
+ gemini_generate_content_config: str | dict = "",
gemini_proxy: str | None = PROXY.GOOGLE,
gemini_append_grounding: bool = True,
silent: bool = False,
src/ai/texts/models.py
@@ -2,7 +2,6 @@
# -*- coding: utf-8 -*-
import re
-from glom import glom
from loguru import logger
from pyrogram.types import Message
@@ -13,87 +12,88 @@ from messages.utils import startswith_prefix
# ruff: noqa: RUF002
-async def get_text_model_config(message: Message) -> dict:
+async def get_text_model_configs(message: Message) -> list[dict]:
r"""Get model config based on the message.
Model config is retrieved from CF-KV with key: {AI.TEXT_MODEL_CONFIG_KEY}
A sample config:
{
- "docs": "🤖AI对话: `/ai` + 提示词\n回复消息可将其加入历史上下文\n🔄使用以下命令强制切换模型:\n/gpt: GPT-5.2\n/gemini: Gemini-2.5-Flash\n/grok: Grok-4\n/claude: Claude-Opus-4.5\n/doubao: Doubao-Seed-1.8\n/ds: DeepSeek-R1\n/qwen: Qwen3-Max\n/kimi: Kimi-K2\n/glm: GLM-4.7\n/mimo: MiMo-V2-Flash",
- "default": {
- "model_id": "gemini-2.5-flash",
- "model_name": "Gemini-2.5-Flash",
- "api_type": "gemini",
- "gemini_base_url": "https://generativelanguage.googleapis.com",
- "gemini_api_keys": "key1,key2,key3,...",
- "gemini_generate_content_config": {
- "max_output_tokens": 65536,
- "media_resolution": "MEDIA_RESOLUTION_HIGH",
- "thinking_config": {"include_thoughts": true, "thinking_budget": 24576},
- "tools":[{"google_search": {}}, {"url_context": {}}]}
- }
- },
- "gpt": {
- "model_id": "gpt-4o",
- "model_name": "GPT-4o",
- "api_type": "openai_chat",
- "openai_base_url": "https://api.openai.com/v1",
- "openai_api_keys": "key1,key2,key3",
- "openai_completions_config": {
- "temperature": 1.0,
- "max_completion_tokens": 4096
- }
- },
- "gpt-helicone": {
- "model_id": "gpt-4o",
- "model_name": "GPT-4o",
- "api_type": "openai_chat",
- "openai_base_url": "https://gateway.helicone.ai/v1",
- "openai_api_keys": "key1,key2,key3,...",
- "openai_default_headers": {
- "helicone-auth": "Bearer HELICONE_API_KEY",
- "helicone-target-url": "https://api.openai.com"
+ "docs": "🤖AI对话: `/ai` + 提示词\n回复消息可将其加入历史上下文\n默认使用**Gemini-2.5-Flash**模型\n\n🔄使用以下命令强制切换模型:\n/gpt: GPT-5.2\n/gemini: Gemini-2.5-Flash\n/g3: Gemini-3-Flash (不支持网络搜索)\n/grok: Grok-4\n/claude: Claude-Opus-4.5\n/doubao: Doubao-Seed-1.8\n/ds: DeepSeek-R1\n/qwen: Qwen3-Max\n/kimi: Kimi-K2\n/glm: GLM-4.7\n/mimo: MiMo-V2-Flash",
+ "gemini": {
+ "common_config": {
+ "api_type": "gemini",
+ "gemini_base_url": "https://generativelanguage.googleapis.com",
+ "gemini_api_keys": "key1,key2,key3..."
},
- "openai_completions_config": {
- "temperature": 1.0,
- "max_completion_tokens": 4096
- }
+ "models": [
+ {
+ "model_id": "gemini-3-flash-preview",
+ "model_name": "Gemini-3-Flash",
+ "gemini_generate_content_config": {
+ "max_output_tokens": 65536,
+ "thinking_config": {"include_thoughts": true, "thinking_level": "high"},
+ "tools": [{"url_context": {}}, {"code_execution": {}}]
+ }
+ },
+ {
+ "model_name": "Gemini-2.5-Flash",
+ "model_id": "gemini-2.5-flash-preview-09-2025",
+ "gemini_generate_content_config": {
+ "max_output_tokens": 65536,
+ "thinking_config": {"include_thoughts": true, "thinking_budget": 24576},
+ "tools": [{"google_search": {}}, {"url_context": {}}, {"code_execution": {}}]
+ }
+ }
+ ]
},
- "doubao": {
- "model_id": "doubao-seed-1-8-251228",
- "model_name": "Doubao-Seed-1.8",
- "api_type": "openai_responses",
- "cache_response_ttl": 604800,
- "openai_base_url": "https://ark.cn-beijing.volces.com/api/v3",
- "openai_api_keys": "key1,key2,key3,...",
- "openai_responses_config": {
- "reasoning": { "effort": "high" },
- "max_output_tokens": 65536,
- "extra_body": {
- "thinking": { "type": "enabled" }
+ "gpt": {
+ "common_config": {
+ "api_type": "gemini",
+ "gemini_base_url": "https://generativelanguage.googleapis.com",
+ "gemini_api_keys": "key1,key2,key3...",
},
- "tools": [
+ "models": [
{
- "type": "web_search",
- "max_keyword": 5,
- "limit": 20
+ "model_id": "gpt-4o",
+ "model_name": "GPT-4o",
+ "api_type": "openai_chat",
+ "openai_base_url": "https://api.openai.com/v1",
+ "openai_api_keys": "key1,key2,key3...",
+ "openai_completions_config": {
+ "temperature": 1.0,
+ "max_completion_tokens": 4096
+ }
+ },
+ {
+ "model_id": "gpt-5.2",
+ "model_name": "GPT-5.2",
+ "api_type": "openai_responses",
+ "cache_response_ttl": 86400,
+ "openai_base_url": "https://gateway.helicone.ai/v1",
+ "openai_api_keys": "key1,key2,key3,...",
+ "openai_default_headers": {
+ "helicone-auth": "Bearer HELICONE_API_KEY",
+ "helicone-target-url": "https://api.openai.com"
+ },
+ "openai_responses_config": {
+ "reasoning": { "effort": "high" },
+ "max_output_tokens": 4096,
+ "tools": [ { "type": "web_search_preview","search_context_size": "high" } ]
+ }
}
- ],
- "max_tool_calls": 10
- }
- },
- "tool_call_model": {
- "model_id": "gpt-4o-mini",
- "model_name": "Web Search",
- "api_type": "openai_chat"
- "openai_base_url": "https://api.openai.com/v1",
- "openai_api_keys": "key1,key2,key3",
- "openai_completions_config": {
- "temperature": 1.0,
- "max_completion_tokens": 4096
- }
+ ]
},
+ "tool_call_model": {
+ "models": [
+ {
+ "model_id": "gpt-4o-mini",
+ "model_name": "Web Search",
+ "api_type": "openai_chat",
+ "openai_base_url": "https://api.openai.com/v1",
+ "openai_api_keys": "key1,key2,key3"
+ }
+ ]
}
@@ -106,7 +106,7 @@ async def get_text_model_config(message: Message) -> dict:
Message(text="🤖GPT-4o:(回复以继续)\nHello") -> find the model_alias via model_name=`GPT-4o`
Returns:
- {
+ [{
"model_id": "gpt-4o",
"model_name": "GPT-4o",
"openai_api_type": "chat",
@@ -116,12 +116,12 @@ async def get_text_model_config(message: Message) -> dict:
"openai_completions_config": {},
"openai_responses_config": {},
.... # other fileds will also be passed to the function
- }
+ }]
"""
texts = str(message.content).strip()
if texts.startswith(EMOJI_TEXT_BOT) and BOT_TIPS in texts:
# DO NOT respond to AI responses to avoid potential infinitely loop
- return {}
+ return []
# this message starts with /ai
if startswith_prefix(message.content, PREFIX.AI_TEXT_GENERATION):
@@ -129,66 +129,73 @@ async def get_text_model_config(message: Message) -> dict:
prompt = re.sub(r"^@([a-zA-Z0-9_\-\.]+)(\s+)?", "", prompt, flags=re.DOTALL).strip()
if not prompt and not message.reply_to_message: # no prompt & no reply_msg
await message.reply_text(text=await text_generation_docs(), quote=True)
- return {}
+ return []
if matched := re.match(rf"^{PREFIX.AI_TEXT_GENERATION}\s+@([a-zA-Z0-9_\-\.]+)(\s+)?", texts): # match /ai @custom_model_id
- model_id = matched.group(1).strip()
- return await get_config_by_model_id(model_id)
- return await get_config_by_model_id("default")
+ model_alias = matched.group(1).strip()
+ return await get_config_by_model_alias(model_alias)
+ return await get_config_by_model_alias(AI.TEXT_GENERATION_DEFAULT_MODEL)
# this message is not /ai, try to find model id from reply_message
reply_msg = message.reply_to_message
if not isinstance(reply_msg, Message):
- return {}
+ return []
if matched := re.match(rf"^{EMOJI_TEXT_BOT}(.*?):{BOT_TIPS}", str(reply_msg.content)):
model_name = matched.group(1).strip()
return await get_config_by_model_name(model_name)
- return {}
+ return []
-async def get_config_by_model_id(model_id: str, *, fallback_to_default: bool = True) -> dict:
- """Get model config by model_id.
+async def get_config_by_model_alias(model_alias: str, *, fallback_to_default: bool = True) -> list[dict]:
+ """Get model config by model_alias.
Returns:
list of model config dicts (common_config merged into each models[] entry)
"""
kv = await get_cf_kv(AI.TEXT_MODEL_CONFIG_KEY, cache_ttl=600, silent=True)
- default_config = kv.get("default", {})
- if not default_config:
- logger.warning(f"CF-KV key `{AI.TEXT_MODEL_CONFIG_KEY}` does not has `default` field")
- default_config = (
- {
- "model_id": AI.GEMINI_MODEL_ID,
- "model_name": AI.GEMINI_MODEL_ID,
- "api_type": "gemini",
- "gemini_base_url": AI.GEMINI_BASE_URL,
- "gemini_api_keys": AI.GEMINI_API_KEYS,
- }
- if AI.TEXT_DEFAULT_PROVIDER == "gemini"
- else {
- "model_id": AI.OPENAI_MODEL_ID,
- "model_name": AI.OPENAI_MODEL_ID,
- "openai_api_type": "chat",
- "openai_base_url": AI.OPENAI_BASE_URL,
- "openai_api_keys": AI.OPENAI_API_KEYS,
- }
- )
- custom_config = kv.get(model_id, {})
- if not custom_config:
- if fallback_to_default:
- logger.warning(f"Model `{model_id}` is not configured in KV, using default config")
- return default_config
- return {}
- return default_config | custom_config
+
+ if config := kv.get(model_alias, {}):
+ common_config = config.get("common_config", {})
+ return [common_config | model_config for model_config in config.get("models", [])]
+
+ if not fallback_to_default:
+ return []
+
+ logger.warning(f"Model Alias `{model_alias}` is not configured in KV, fallback to default config")
+ return [
+ {
+ "model_id": AI.GEMINI_MODEL_ID,
+ "model_name": AI.GEMINI_MODEL_ID,
+ "api_type": "gemini",
+ "gemini_base_url": AI.GEMINI_BASE_URL,
+ "gemini_api_keys": AI.GEMINI_API_KEYS,
+ },
+ {
+ "model_id": AI.OPENAI_MODEL_ID,
+ "model_name": AI.OPENAI_MODEL_ID,
+ "api_type": "openai_chat",
+ "openai_base_url": AI.OPENAI_BASE_URL,
+ "openai_api_keys": AI.OPENAI_API_KEYS,
+ },
+ ]
-async def get_config_by_model_name(model_name: str) -> dict:
+async def get_config_by_model_name(model_name: str) -> list[dict]:
"""Get model config by model_name.
Returns:
list of model config dicts whose model_name matches
"""
kv = await get_cf_kv(AI.TEXT_MODEL_CONFIG_KEY, cache_ttl=600, silent=True)
- model_names = {glom(v, "model_name", default=""): k for k, v in kv.items()}
- model_alias = model_names.get(model_name, model_name)
- return await get_config_by_model_id(model_alias)
+ model_configs = []
+ for alias, config in kv.items():
+ if not isinstance(config, dict):
+ continue
+ if alias in {AI.TOOL_CALL_MODEL_ALIAS, AI.PODCAST_SUMMARY_MODEL_ALIAS, AI.CHAT_SUMMARY_MODEL_ALIAS, AI.SUBTITLE_SUMMARY_MODEL_ALIAS}:
+ continue
+ common_config = config.get("common_config", {})
+ for model in config.get("models", []):
+ model_config = common_config | model
+ if model_config.get("model_name", "") == model_name:
+ model_configs.append(model_config)
+ return model_configs
src/ai/texts/openai_chat.py
@@ -27,9 +27,9 @@ async def openai_chat_completions(
model_name: str = AI.OPENAI_MODEL_ID,
openai_base_url: str = AI.OPENAI_BASE_URL,
openai_api_keys: str = AI.OPENAI_API_KEYS,
- openai_client_config: str | dict = AI.OPENAI_CLIENT_CONFIG,
- openai_default_headers: str | dict = AI.OPENAI_DEFAULT_HEADERS,
- openai_completions_config: str | dict = AI.OPENAI_COMPLETIONS_CONFIG,
+ openai_client_config: str | dict = "",
+ openai_default_headers: str | dict = "",
+ openai_completions_config: str | dict = "",
openai_proxy: str | None = PROXY.OPENAI,
openai_system_prompt: str = "",
openai_contexts: list[dict] | None = None,
src/ai/texts/openai_response.py
@@ -30,9 +30,9 @@ async def openai_responses_api(
model_name: str = AI.OPENAI_MODEL_ID,
openai_base_url: str = AI.OPENAI_BASE_URL,
openai_api_keys: str = AI.OPENAI_API_KEYS,
- openai_client_config: str | dict = AI.OPENAI_CLIENT_CONFIG,
- openai_default_headers: str | dict = AI.OPENAI_DEFAULT_HEADERS,
- openai_responses_config: str | dict = AI.OPENAI_RESPONSES_CONFIG,
+ openai_client_config: str | dict = "",
+ openai_default_headers: str | dict = "",
+ openai_responses_config: str | dict = "",
openai_proxy: str | None = PROXY.OPENAI,
cache_response_ttl: int = 0,
silent: bool = False,
src/ai/main.py
@@ -9,7 +9,7 @@ from ai.images.models import get_image_model_configs
from ai.images.openai_img import openai_image_generation
from ai.images.post import http_post_image_generation
from ai.texts.gemini import gemini_chat_completion
-from ai.texts.models import get_config_by_model_id, get_text_model_config
+from ai.texts.models import get_config_by_model_alias, get_text_model_configs
from ai.texts.openai_chat import openai_chat_completions
from ai.texts.openai_response import openai_responses_api
from ai.texts.tool_call import get_tool_call_results
@@ -21,30 +21,35 @@ from messages.sender import send2tg
from messages.utils import startswith_prefix
-async def ai_text_generation(client: Client, message: Message, *, silent: bool = False, **kwargs) -> dict:
+async def ai_text_generation(client: Client, message: Message, **kwargs) -> dict:
texts = str(message.content).strip()
this_msg = message
prompt = texts.removeprefix(PREFIX.AI_TEXT_GENERATION).strip()
prompt = re.sub(r"^@([a-zA-Z0-9_\-\.]+)(\s+)?", "", prompt, flags=re.DOTALL).strip()
if not prompt and message.reply_to_message:
message = this_msg.reply_to_message
- model_config = await get_text_model_config(this_msg)
- if not model_config.get("model_id"):
+ model_configs = await get_text_model_configs(this_msg)
+ if not model_configs:
return {}
- silent = silent or model_config.get("silent", False)
- params: dict = {"api_type": AI.TEXT_DEFAULT_PROVIDER} | model_config | kwargs | {"silent": silent}
- if params["api_type"] == "gemini":
- return await gemini_chat_completion(client, message, **params)
- if params["api_type"] == "openai_responses":
- return await openai_responses_api(client, message, **params)
- if params["api_type"] == "openai_chat":
- if params.get("openai_enable_tool_call", True):
- tool_config = await get_config_by_model_id("tool_call_model", fallback_to_default=False)
- if tool_config:
- tool_params = params | tool_config
- tool_results = await get_tool_call_results(client, message, **tool_params)
- params |= tool_results
- return await openai_chat_completions(client, message, **params)
+ for model_config in model_configs:
+ params: dict = model_config | kwargs
+ match model_config["api_type"]:
+ case "gemini":
+ if res := await gemini_chat_completion(client, message, **params):
+ return res
+ case "openai_responses":
+ if res := await openai_responses_api(client, message, **params):
+ return res
+ case "openai_chat":
+ if params.get("openai_enable_tool_call", True):
+ tool_configs = await get_config_by_model_alias(AI.TOOL_CALL_MODEL_ALIAS, fallback_to_default=False)
+ for tool_config in tool_configs:
+ tool_params = params | tool_config
+ if tool_results := await get_tool_call_results(client, message, **tool_params):
+ params |= tool_results
+ break
+ if res := await openai_chat_completions(client, message, **params):
+ return res
return {}
src/config.py
@@ -381,20 +381,16 @@ class AI:
# Text Generation
MAX_CONTEXTS_NUM = int(os.getenv("AI_MAX_CONTEXTS_NUM", "30"))
TEXT_MODEL_CONFIG_KEY = os.getenv("AI_MODEL_CONFIG_KEY", "AI-TEXT") # model configuration key in CF-KV
- TEXT_DEFAULT_PROVIDER = os.getenv("AI_TEXT_DEFAULT_PROVIDER", "gemini")
- OPENAI_MODEL_ID = os.getenv("AI_OPENAI_MODEL_ID", "gpt-4o")
- OPENAI_API_KEYS = os.getenv("AI_OPENAI_API_KEYS", "") # comma separated keys for load balance. e.g. "key1,key2,key3"
- OPENAI_BASE_URL = os.getenv("AI_OPENAI_BASE_URL", "https://api.openai.com/v1")
- OPENAI_CLIENT_CONFIG = os.getenv("AI_OPENAI_CLIENT_CONFIG", "{}") # client config passed to OpenAI API. Should be a json string: '{"key": "value"}'
- OPENAI_DEFAULT_HEADERS = os.getenv("AI_OPENAI_DEFAULT_HEADERS", "{}") # default headers passed to OpenAI API. Should be a json string: '{"key": "value"}'
- OPENAI_COMPLETIONS_CONFIG = os.getenv("AI_OPENAI_COMPLETIONS_CONFIG", "{}") # chat completions config. Should be a json string: '{"key": "value"}'
- OPENAI_RESPONSES_CONFIG = os.getenv("AI_OPENAI_RESPONSES_CONFIG", "{}") # response api config. Should be a json string: '{"key": "value"}'
+ TEXT_GENERATION_DEFAULT_MODEL = os.getenv("AI_TEXT_GENERATION_DEFAULT_MODEL", "gemini")
GEMINI_MODEL_ID = os.getenv("AI_GEMINI_MODEL_ID", "gemini-2.5-flash")
GEMINI_API_KEYS = os.getenv("AI_GEMINI_API_KEYS", "") # comma separated keys for load balance. e.g. "key1,key2,key3"
GEMINI_BASE_URL = os.getenv("AI_GEMINI_BASE_URL", "https://generativelanguage.googleapis.com")
GEMINI_DEFAULT_HEADERS = os.getenv("AI_GEMINI_DEFAULT_HEADERS", "{}") # default headers passed to Gemini API. Should be a json string: '{"key": "value"}'
- GEMINI_GENERATE_CONTENT_CONFIG = os.getenv("AI_GEMINI_GENERATE_CONTENT_CONFIG", "{}") # gemini generate_content config. Should be a json string: '{"key": "value"}'
GEMINI_FILES_TTL = int(os.getenv("AI_GEMINI_FILES_TTL", "172800")) # clean gemini files after 48 hours
+ OPENAI_MODEL_ID = os.getenv("AI_OPENAI_MODEL_ID", "gpt-4o")
+ OPENAI_API_KEYS = os.getenv("AI_OPENAI_API_KEYS", "") # comma separated keys for load balance. e.g. "key1,key2,key3"
+ OPENAI_BASE_URL = os.getenv("AI_OPENAI_BASE_URL", "https://api.openai.com/v1")
+ TOOL_CALL_MODEL_ALIAS = os.getenv("AI_TOOL_CALL_MODEL_ALIAS", "tool-call")
PODCAST_SUMMARY_MODEL_ALIAS = os.getenv("PODCAST_SUMMARY_MODEL_ALIAS", "podcast-summary")
SUBTITLE_SUMMARY_MODEL_ALIAS = os.getenv("SUBTITLE_SUMMARY_MODEL_ALIAS", "subtitle-summary")
CHAT_SUMMARY_MODEL_ALIAS = os.getenv("CHAT_SUMMARY_MODEL_ALIAS", "chat-summary")