main
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3import re
4
5from loguru import logger
6from pyrogram.client import Client
7from pyrogram.types import Message
8
9from ai.texts.contexts import context_types
10from ai.utils import BOT_TIPS, EMOJI_TEXT_BOT, deep_merge, text_generation_docs
11from config import AI, PREFIX
12from database.kv import get_cf_kv
13from messages.utils import startswith_prefix
14from utils import strings_list
15
16
17# ruff: noqa: RUF002
async def get_text_model_configs(message: Message) -> list[dict]:
    r"""Get the text-generation model configs that should handle ``message``.

    Model configs are read from CF-KV under the key ``AI.TEXT_MODEL_CONFIG_KEY``.
    Each top-level dict value is a model group keyed by its model alias: an optional
    ``common_config`` plus a list of ``models``, each deep-merged over ``common_config``.
    ``docs`` holds a plain string (presumably the usage text), not a model group.

    A sample config (the Chinese strings are example user-facing content):
    {
        "docs": "🤖AI对话: `/ai` + 提示词\n回复消息可将其加入历史上下文\n默认使用**Gemini-2.5-Flash**模型\n\n🔄使用以下命令强制切换模型:\n/gpt: GPT-5.2\n/gemini: Gemini-2.5-Flash\n/g3: Gemini-3-Flash (不支持网络搜索)\n/grok: Grok-4\n/claude: Claude-Opus-4.5\n/doubao: Doubao-Seed-1.8\n/ds: DeepSeek-R1\n/qwen: Qwen3-Max\n/kimi: Kimi-K2\n/glm: GLM-4.7\n/mimo: MiMo-V2-Flash",
        "gemini": {
            "common_config": {
                "api_type": "gemini",
                "gemini_base_url": "https://generativelanguage.googleapis.com",
                "gemini_api_keys": "key1,key2,key3..."
            },
            "models": [
                {
                    "model_id": "gemini-3-flash-preview",
                    "model_name": "Gemini-3-Flash",
                    "gemini_generate_content_config": {
                        "max_output_tokens": 65536,
                        "thinking_config": {"include_thoughts": true, "thinking_level": "high"},
                        "tools": [{"url_context": {}}, {"code_execution": {}}]
                    }
                },
                {
                    "model_name": "Gemini-2.5-Flash",
                    "model_id": "gemini-2.5-flash-preview-09-2025",
                    "gemini_generate_content_config": {
                        "max_output_tokens": 65536,
                        "thinking_config": {"include_thoughts": true, "thinking_budget": 24576},
                        "tools": [{"google_search": {}}, {"url_context": {}}, {"code_execution": {}}]
                    }
                }
            ]
        },
        "gpt": {
            "common_config": {
                "api_type": "gemini",
                "gemini_base_url": "https://generativelanguage.googleapis.com",
                "gemini_api_keys": "key1,key2,key3..."
            },
            "models": [
                {
                    "model_id": "gpt-4o",
                    "model_name": "GPT-4o",
                    "api_type": "openai_chat",
                    "openai_base_url": "https://api.openai.com/v1",
                    "openai_api_keys": "key1,key2,key3...",
                    "openai_completions_config": {
                        "temperature": 1.0,
                        "max_completion_tokens": 4096
                    }
                },
                {
                    "model_id": "gpt-5.2",
                    "model_name": "GPT-5.2",
                    "api_type": "openai_responses",
                    "cache_response_ttl": 86400,
                    "openai_base_url": "https://gateway.helicone.ai/v1",
                    "openai_api_keys": "key1,key2,key3,...",
                    "openai_default_headers": {
                        "helicone-auth": "Bearer HELICONE_API_KEY",
                        "helicone-target-url": "https://api.openai.com"
                    },
                    "openai_responses_config": {
                        "reasoning": {"effort": "high"},
                        "max_output_tokens": 4096,
                        "tools": [{"type": "web_search_preview", "search_context_size": "high"}]
                    }
                }
            ]
        },
        "tool_call_model": {
            "models": [
                {
                    "model_id": "gpt-4o-mini",
                    "model_name": "Web Search",
                    "api_type": "openai_chat",
                    "openai_base_url": "https://api.openai.com/v1",
                    "openai_api_keys": "key1,key2,key3"
                }
            ]
        }
    }

    A model-level field (e.g. ``api_type`` above) overrides the same field in ``common_config``.

    Resolution rules, given the message text:
        - Message(text="🤖…(AI reply)…") starting with ``EMOJI_TEXT_BOT`` and containing
          ``BOT_TIPS`` -> ``[]``. The bot never answers its own AI replies.
        - Message(text="/ai") with no prompt and no replied-to message -> replies with
          the usage docs and returns ``[]``.
        - Message(text="/ai hello") -> use ``AI.TEXT_GENERATION_DEFAULT_MODEL`` as the
          model alias.
        - Message(text="/ai @gpt-4.1 hello") -> use ``gpt-4.1`` as the model alias.
          ``get_config_by_model_alias`` falls back to built-in defaults for unknown
          aliases.

    Any other message that replies to an AI reply is resolved by the ``model_name``
    found in the reply header:
        Message(text="🤖Gemini-2.5-Flash:(回复以继续)\nHello") -> model_name=`Gemini-2.5-Flash`
        Message(text="🤖GPT-4o:(回复以继续)\nHello") -> model_name=`GPT-4o`

    Args:
        message: The incoming message.

    Returns:
        A list of merged model configs, one per ``model_id``. It is empty when no model
        should respond. Example:
        [{
            "model_id": "gpt-4o",
            "model_name": "GPT-4o",
            "api_type": "openai_chat",
            "openai_base_url": "https://api.openai.com/v1",
            "openai_api_keys": "key1,key2,...",
            "openai_default_headers": {},
            "openai_completions_config": {},
            "openai_responses_config": {},
            ....  # other fields are passed through unchanged
        }]
    """
    texts = str(message.content).strip()
    if texts.startswith(EMOJI_TEXT_BOT) and BOT_TIPS in texts:
        # DO NOT respond to AI responses, to avoid a potential infinite reply loop.
        return []

    # This message starts with the /ai command prefix.
    if startswith_prefix(message.content, PREFIX.AI_TEXT_GENERATION):
        prompt = texts.removeprefix(PREFIX.AI_TEXT_GENERATION).strip()
        # Drop a leading "@model_alias" token so it is not counted as prompt text.
        prompt = re.sub(r"^@([a-zA-Z0-9_\-\.]+)(\s+)?", "", prompt, flags=re.DOTALL).strip()
        if not prompt and not message.reply_to_message:  # no prompt & no reply_msg
            await message.reply_text(text=await text_generation_docs(), quote=True)
            return []
        # Match "/ai @custom_model_alias". The prefix is literal text (it is used with
        # removeprefix above), so escape it before embedding it in a regex.
        alias_pattern = rf"^{re.escape(PREFIX.AI_TEXT_GENERATION)}\s+@([a-zA-Z0-9_\-\.]+)(\s+)?"
        if matched := re.match(alias_pattern, texts):
            model_alias = matched.group(1).strip()
            return await get_config_by_model_alias(model_alias)
        return await get_config_by_model_alias(AI.TEXT_GENERATION_DEFAULT_MODEL)

    # This message is not /ai: try to find the model from the replied-to AI message.
    reply_msg = message.reply_to_message
    if not isinstance(reply_msg, Message):
        return []

    # EMOJI_TEXT_BOT and BOT_TIPS are literal text (BOT_TIPS is checked with `in` above).
    # Escape them; otherwise an ASCII "(" in BOT_TIPS would become a regex group and the
    # reply header would never match.
    header_pattern = rf"^{re.escape(EMOJI_TEXT_BOT)}(.*?):{re.escape(BOT_TIPS)}"
    if matched := re.match(header_pattern, str(reply_msg.content)):
        model_name = matched.group(1).strip()
        return await get_config_by_model_name(model_name)
    return []
150
151
async def get_config_by_model_alias(model_alias: str, *, fallback_to_default: bool = True) -> list[dict]:
    """Get the model configs registered under ``model_alias`` in CF-KV.

    Each entry in the alias's ``models`` list is deep-merged over the alias's
    ``common_config``. The merged ``model_id`` is expanded with ``strings_list`` into
    one config per id. ``shuffle=True`` is passed when ``strategy`` is
    ``"load-balance"``; otherwise the ``"fallback"`` default applies.

    Args:
        model_alias: Top-level key of the CF-KV model config (e.g. ``gemini``, ``gpt``).
        fallback_to_default: If the alias is missing, or its value is not a model group
            (e.g. the ``docs`` string), return the built-in configs from ``AI``.
            When False, return ``[]`` instead.

    Returns:
        The list of merged model configs. It may be empty if the alias has no ``models``.
    """
    kv = await get_cf_kv(AI.TEXT_MODEL_CONFIG_KEY, cache_ttl=600, silent=True) or {}
    group = kv.get(model_alias)
    # Only dict values are model groups. Keys such as "docs" hold plain strings and
    # previously crashed here with AttributeError on `.get`.
    if isinstance(group, dict):
        common = group.get("common_config", {})
        configs = []
        for config in group.get("models", []):
            merged = deep_merge(common, config)
            shuffle = bool(merged.get("strategy", "fallback") == "load-balance")  # load-balance or fallback
            for model_id in strings_list(merged["model_id"], shuffle=shuffle):
                merged["model_id"] = model_id
                # Copy so that each entry keeps its own model_id.
                configs.append(merged.copy())
        return configs

    if not fallback_to_default:
        return []

    logger.warning(f"Model Alias `{model_alias}` is not configured in KV, fallback to default config")
    # Built-in configs from config.AI, listed in this order: Gemini, Anthropic, OpenAI chat.
    return [
        {
            "model_id": AI.GEMINI_MODEL_ID,
            "model_name": AI.GEMINI_MODEL_ID,
            "api_type": "gemini",
            "gemini_base_url": AI.GEMINI_BASE_URL,
            "gemini_api_keys": AI.GEMINI_API_KEYS,
        },
        {
            "model_id": AI.ANTHROPIC_MODEL_ID,
            "model_name": AI.ANTHROPIC_MODEL_ID,
            "api_type": "anthropic",
            "anthropic_base_url": AI.ANTHROPIC_BASE_URL,
            "anthropic_api_keys": AI.ANTHROPIC_API_KEYS,
        },
        {
            "model_id": AI.OPENAI_MODEL_ID,
            "model_name": AI.OPENAI_MODEL_ID,
            "api_type": "openai_chat",
            "openai_base_url": AI.OPENAI_BASE_URL,
            "openai_api_keys": AI.OPENAI_API_KEYS,
        },
    ]
197
198
async def get_config_by_model_name(model_name: str) -> list[dict]:
    """Collect the configs of every model whose ``model_name`` equals ``model_name``.

    All dict-valued model groups in the CF-KV config are searched, except the
    reserved aliases for tool calls and for podcast/chat/subtitle summaries.
    Each matching model is deep-merged over its group's ``common_config``. It is then
    expanded into one config per id produced by ``strings_list(model_id)``, with
    ``shuffle=True`` when ``strategy`` is ``"load-balance"``.

    Returns:
        The list of matching model configs (empty when nothing matches).
    """
    kv = await get_cf_kv(AI.TEXT_MODEL_CONFIG_KEY, cache_ttl=600, silent=True)
    reserved_aliases = {
        AI.TOOL_CALL_MODEL_ALIAS,
        AI.PODCAST_SUMMARY_MODEL_ALIAS,
        AI.CHAT_SUMMARY_MODEL_ALIAS,
        AI.SUBTITLE_SUMMARY_MODEL_ALIAS,
    }
    results = []
    for alias, group in kv.items():
        # Skip non-group values (e.g. the "docs" string) and internal-purpose aliases.
        if not isinstance(group, dict) or alias in reserved_aliases:
            continue
        shared = group.get("common_config", {})
        for entry in group.get("models", []):
            candidate = deep_merge(shared, entry)
            if candidate.get("model_name", "") == model_name:
                load_balance = bool(candidate.get("strategy", "fallback") == "load-balance")
                for model_id in strings_list(candidate["model_id"], shuffle=load_balance):
                    candidate["model_id"] = model_id
                    results.append(candidate.copy())
    return results
222
223
async def reorder_model_configs(client: Client, message: Message, configs: list[dict], params: dict) -> list[dict]:
    """Move the configs best suited to the message's media to the front.

    The media kinds reported by ``context_types`` are ranked
    youtube > audio > video > image > file.

    If none of them is present (text only), ``configs`` is returned unchanged (the
    same list object). Otherwise the configs are stably partitioned into two groups.
    The first group holds Gemini configs, plus OpenAI-type configs whose
    ``openai_allow_<kind>`` flag is truthy for the highest-ranked kind present.
    YouTube is only ever served by Gemini. Relative order inside each group is
    preserved.

    Returns:
        model_configs
    """
    # (media kind, OpenAI opt-in flag), from highest to lowest priority.
    # YouTube has no flag because only Gemini can handle it.
    media_priority = (
        ("youtube", None),
        ("audio", "openai_allow_audio"),
        ("video", "openai_allow_video"),
        ("image", "openai_allow_image"),
        ("file", "openai_allow_file"),
    )

    types = await context_types(client, message, params.get("additional_contexts", []))
    if not any(types.get(kind) for kind, _ in media_priority):
        return configs  # text only

    def goes_first(cfg: dict) -> bool:
        api_type = cfg.get("api_type", "")
        if api_type == "gemini":
            return True
        is_openai = api_type.startswith("openai")
        # Only the highest-priority media kind present decides the outcome.
        for kind, allow_key in media_priority:
            if not types.get(kind):
                continue
            if allow_key is None:
                return False
            return is_openai and cfg.get(allow_key, False)
        return False

    # sorted() is stable, so this is a stable partition with preferred configs first.
    return sorted(configs, key=lambda cfg: not goes_first(cfg))
272 return preferred_configs + remaining_configs