main
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3import re
4
5from loguru import logger
6from pyrogram.types import Message
7
8from ai.utils import BOT_TIPS, EMOJI_TEXT_BOT, deep_merge, text_generation_docs
9from config import AI, PREFIX
10from database.kv import get_cf_kv
11from messages.utils import startswith_prefix
12
13
14# ruff: noqa: RUF002
async def get_text_model_configs(message: Message) -> list[dict]:
    r"""Get model config based on the message.

    Model config is retrieved from CF-KV with key: {AI.TEXT_MODEL_CONFIG_KEY}

    A sample config:
    {
        "docs": "🤖AI对话: `/ai` + 提示词\n回复消息可将其加入历史上下文\n默认使用**Gemini-2.5-Flash**模型\n\n🔄使用以下命令强制切换模型:\n/gpt: GPT-5.2\n/gemini: Gemini-2.5-Flash\n/g3: Gemini-3-Flash (不支持网络搜索)\n/grok: Grok-4\n/claude: Claude-Opus-4.5\n/doubao: Doubao-Seed-1.8\n/ds: DeepSeek-R1\n/qwen: Qwen3-Max\n/kimi: Kimi-K2\n/glm: GLM-4.7\n/mimo: MiMo-V2-Flash",
        "gemini": {
            "common_config": {
                "api_type": "gemini",
                "gemini_base_url": "https://generativelanguage.googleapis.com",
                "gemini_api_keys": "key1,key2,key3...",
            },
            "models": [
                {
                    "model_id": "gemini-3-flash-preview",
                    "model_name": "Gemini-3-Flash",
                    "gemini_generate_content_config": {
                        "max_output_tokens": 65536,
                        "thinking_config": {"include_thoughts": true, "thinking_level": "high"},
                        "tools": [{"url_context": {}}, {"code_execution": {}}]
                    }
                },
                {
                    "model_name": "Gemini-2.5-Flash",
                    "model_id": "gemini-2.5-flash-preview-09-2025",
                    "gemini_generate_content_config": {
                        "max_output_tokens": 65536,
                        "thinking_config": {"include_thoughts": true, "thinking_budget": 24576},
                        "tools": [{"google_search": {}}, {"url_context": {}}, {"code_execution": {}}]
                    }
                },
            ]
        },
        "gpt": {
            "common_config": {
                "api_type": "gemini",
                "gemini_base_url": "https://generativelanguage.googleapis.com",
                "gemini_api_keys": "key1,key2,key3...",
            },
            "models": [
                {
                    "model_id": "gpt-4o",
                    "model_name": "GPT-4o",
                    "api_type": "openai_chat",
                    "openai_base_url": "https://api.openai.com/v1",
                    "openai_api_keys": "key1,key2,key3...",
                    "openai_completions_config": {
                        "temperature": 1.0,
                        "max_completion_tokens": 4096
                    }
                },
                {
                    "model_id": "gpt-5.2",
                    "model_name": "GPT-5.2",
                    "api_type": "openai_responses",
                    "cache_response_ttl": 86400,
                    "openai_base_url": "https://gateway.helicone.ai/v1",
                    "openai_api_keys": "key1,key2,key3,...",
                    "openai_default_headers": {
                        "helicone-auth": "Bearer HELICONE_API_KEY",
                        "helicone-target-url": "https://api.openai.com"
                    },
                    "openai_responses_config": {
                        "reasoning": { "effort": "high" },
                        "max_output_tokens": 4096,
                        "tools": [ { "type": "web_search_preview","search_context_size": "high" } ]
                    }
                }
            ]
        },
        "tool_call_model": {
            "models": [
                {
                    "model_id": "gpt-4o-mini",
                    "model_name": "Web Search",
                    "api_type": "openai_chat",
                    "openai_base_url": "https://api.openai.com/v1",
                    "openai_api_keys": "key1,key2,key3"
                }
            ]
        }
    }

    Suppose this message is:
        Message(text="/ai hello") -> use `default` as model identifier
        Message(text="/ai @gpt-4.1 hello") -> use `gpt-4.1` as model identifier

    Reply to a message:
        Message(text="🤖Gemini-2.5-Flash:(回复以继续)\nHello") -> find the model_alias via model_name=`Gemini-2.5-Flash`
        Message(text="🤖GPT-4o:(回复以继续)\nHello") -> find the model_alias via model_name=`GPT-4o`

    Returns:
        [{
            "model_id": "gpt-4o",
            "model_name": "GPT-4o",
            "openai_api_type": "chat",
            "openai_base_url": "https://api.openai.com/v1",
            "openai_api_keys": "key1,key2,...",
            "openai_default_headers": {},
            "openai_completions_config": {},
            "openai_responses_config": {},
            .... # other fields will also be passed to the function
        }]
    """
    texts = str(message.content).strip()
    if texts.startswith(EMOJI_TEXT_BOT) and BOT_TIPS in texts:
        # DO NOT respond to AI responses to avoid a potential infinite loop
        return []

    # this message starts with /ai
    if startswith_prefix(message.content, PREFIX.AI_TEXT_GENERATION):
        prompt = texts.removeprefix(PREFIX.AI_TEXT_GENERATION).strip()
        # drop a leading `@model_alias` token so only the user prompt remains
        prompt = re.sub(r"^@([a-zA-Z0-9_\-\.]+)(\s+)?", "", prompt, flags=re.DOTALL).strip()
        if not prompt and not message.reply_to_message:  # no prompt & no reply_msg
            await message.reply_text(text=await text_generation_docs(), quote=True)
            return []
        # match `/ai @custom_model_alias ...`; re.escape guards against regex
        # metacharacters in the configured command prefix
        if matched := re.match(rf"^{re.escape(PREFIX.AI_TEXT_GENERATION)}\s+@([a-zA-Z0-9_\-\.]+)(\s+)?", texts):
            model_alias = matched.group(1).strip()
            return await get_config_by_model_alias(model_alias)
        return await get_config_by_model_alias(AI.TEXT_GENERATION_DEFAULT_MODEL)

    # this message is not /ai, try to find model id from reply_message
    reply_msg = message.reply_to_message
    if not isinstance(reply_msg, Message):
        return []

    # re.escape is required here: BOT_TIPS (e.g. `(回复以继续)`) contains regex
    # metacharacters — unescaped, its parentheses would be parsed as a group and
    # the pattern would never match the literal reply text
    if matched := re.match(rf"^{re.escape(EMOJI_TEXT_BOT)}(.*?):{re.escape(BOT_TIPS)}", str(reply_msg.content)):
        model_name = matched.group(1).strip()
        return await get_config_by_model_name(model_name)
    return []
147
148
async def get_config_by_model_alias(model_alias: str, *, fallback_to_default: bool = True) -> list[dict]:
    """Get model config by model_alias.

    Looks the alias up in the CF-KV model-config document and merges the
    alias-level ``common_config`` into every entry of its ``models`` list.
    When the alias is absent and *fallback_to_default* is true, a built-in
    default list (Gemini / Anthropic / OpenAI) is returned instead.

    Returns:
        model_config
    """
    kv = await get_cf_kv(AI.TEXT_MODEL_CONFIG_KEY, cache_ttl=600, silent=True)

    config = kv.get(model_alias, {})
    if config:
        shared = config.get("common_config", {})
        return [deep_merge(shared, entry) for entry in config.get("models", [])]

    if not fallback_to_default:
        return []

    logger.warning(f"Model Alias `{model_alias}` is not configured in KV, fallback to default config")
    gemini_default = {
        "model_id": AI.GEMINI_MODEL_ID,
        "model_name": AI.GEMINI_MODEL_ID,
        "api_type": "gemini",
        "gemini_base_url": AI.GEMINI_BASE_URL,
        "gemini_api_keys": AI.GEMINI_API_KEYS,
    }
    anthropic_default = {
        "model_id": AI.ANTHROPIC_MODEL_ID,
        "model_name": AI.ANTHROPIC_MODEL_ID,
        "api_type": "anthropic",
        "anthropic_base_url": AI.ANTHROPIC_BASE_URL,
        "anthropic_api_keys": AI.ANTHROPIC_API_KEYS,
    }
    openai_default = {
        "model_id": AI.OPENAI_MODEL_ID,
        "model_name": AI.OPENAI_MODEL_ID,
        "api_type": "openai_chat",
        "openai_base_url": AI.OPENAI_BASE_URL,
        "openai_api_keys": AI.OPENAI_API_KEYS,
    }
    return [gemini_default, anthropic_default, openai_default]
188
189
async def get_config_by_model_name(model_name: str) -> list[dict]:
    """Get model config by model_name.

    Scans every user-facing alias in the CF-KV model-config document —
    skipping internal aliases (tool-call and the summary models) — and
    collects each merged model entry whose ``model_name`` equals the
    requested name.

    Returns:
        model_config
    """
    kv = await get_cf_kv(AI.TEXT_MODEL_CONFIG_KEY, cache_ttl=600, silent=True)
    # internal-use aliases that must never be matched via a display name
    internal_aliases = {
        AI.TOOL_CALL_MODEL_ALIAS,
        AI.PODCAST_SUMMARY_MODEL_ALIAS,
        AI.CHAT_SUMMARY_MODEL_ALIAS,
        AI.SUBTITLE_SUMMARY_MODEL_ALIAS,
    }
    found: list[dict] = []
    for alias, config in kv.items():
        if not isinstance(config, dict) or alias in internal_aliases:
            continue
        shared = config.get("common_config", {})
        for entry in config.get("models", []):
            merged = deep_merge(shared, entry)
            if merged.get("model_name", "") == model_name:
                found.append(merged)
    return found