main
  1#!/usr/bin/env python
  2# -*- coding: utf-8 -*-
  3import re
  4
  5from loguru import logger
  6from pyrogram.client import Client
  7from pyrogram.types import Message
  8
  9from ai.texts.contexts import context_types
 10from ai.utils import BOT_TIPS, EMOJI_TEXT_BOT, deep_merge, text_generation_docs
 11from config import AI, PREFIX
 12from database.kv import get_cf_kv
 13from messages.utils import startswith_prefix
 14from utils import strings_list
 15
 16
 17# ruff: noqa: RUF002
 18async def get_text_model_configs(message: Message) -> list[dict]:
 19    r"""Get model config based on the message.
 20
 21    Model config is retrieved from CF-KV with key: {AI.TEXT_MODEL_CONFIG_KEY}
 22
 23    A sample config:
 24    {
 25    "docs": "🤖AI对话: `/ai` + 提示词\n回复消息可将其加入历史上下文\n默认使用**Gemini-2.5-Flash**模型\n\n🔄使用以下命令强制切换模型:\n/gpt: GPT-5.2\n/gemini: Gemini-2.5-Flash\n/g3: Gemini-3-Flash (不支持网络搜索)\n/grok: Grok-4\n/claude: Claude-Opus-4.5\n/doubao: Doubao-Seed-1.8\n/ds: DeepSeek-R1\n/qwen: Qwen3-Max\n/kimi: Kimi-K2\n/glm: GLM-4.7\n/mimo: MiMo-V2-Flash",
 26    "gemini": {
 27        "common_config": {
 28            "api_type": "gemini",
 29            "gemini_base_url": "https://generativelanguage.googleapis.com",
 30            "gemini_api_keys": "key1,key2,key3...",
 31        },
 32        "models": [
 33            {
 34                "model_id": "gemini-3-flash-preview",
 35                "model_name": "Gemini-3-Flash",
 36                "gemini_generate_content_config": {
 37                    "max_output_tokens": 65536,
 38                    "thinking_config": {"include_thoughts": true, "thinking_level": "high"},
 39                    "tools": [{"url_context": {}}, {"code_execution": {}}]
 40                }
 41            },
 42            {
 43                "model_name": "Gemini-2.5-Flash",
 44                "model_id": "gemini-2.5-flash-preview-09-2025",
 45                "gemini_generate_content_config": {
 46                    "max_output_tokens": 65536,
 47                    "thinking_config": {"include_thoughts": true, "thinking_budget": 24576},
 48                    "tools": [{"google_search": {}}, {"url_context": {}}, {"code_execution": {}}]
 49                }
 50            },
 51        ]
 52    },
 53    "gpt": {
 54        "common_config": {
 55            "api_type": "gemini",
 56            "gemini_base_url": "https://generativelanguage.googleapis.com",
 57            "gemini_api_keys": "key1,key2,key3...",
 58        },
 59        "models": [
 60            {
 61                "model_id": "gpt-4o",
 62                "model_name": "GPT-4o",
 63                "api_type": "openai_chat",
 64                "openai_base_url": "https://api.openai.com/v1",
 65                "openai_api_keys": "key1,key2,key3...",
 66                "openai_completions_config": {
 67                    "temperature": 1.0,
 68                    "max_completion_tokens": 4096
 69                }
 70            },
 71            {
 72                "model_id": "gpt-5.2",
 73                "model_name": "GPT-5.2",
 74                "api_type": "openai_responses",
 75                "cache_response_ttl": 86400,
 76                "openai_base_url": "https://gateway.helicone.ai/v1",
 77                "openai_api_keys": "key1,key2,key3,...",
 78                "openai_default_headers": {
 79                    "helicone-auth": "Bearer HELICONE_API_KEY",
 80                    "helicone-target-url": "https://api.openai.com"
 81                },
 82                "openai_responses_config": {
 83                    "reasoning": { "effort": "high" },
 84                    "max_output_tokens": 4096,
 85                    "tools": [ { "type": "web_search_preview","search_context_size": "high" } ]
 86                }
 87            }
 88        ]
 89    }
 90    "tool_call_model": {
 91        "models": [
 92            {
 93                "model_id": "gpt-4o-mini",
 94                "model_name": "Web Search",
 95                "api_type": "openai_chat"
 96                "openai_base_url": "https://api.openai.com/v1",
 97                "openai_api_keys": "key1,key2,key3"
 98            }
 99        ]
100    }
101
102
103    Suppose this message is:
104        Message(text="/ai hello") -> use `default` as model identifier
105        Message(text="/ai @gpt-4.1 hello") -> use `gpt-4.1` as model identifier
106
107    Reply to a message:
108        Message(text="🤖Gemini-2.5-Flash:(回复以继续)\nHello") -> find the model_alias via model_name=`Gemini-2.5-Flash`
109        Message(text="🤖GPT-4o:(回复以继续)\nHello") -> find the model_alias via model_name=`GPT-4o`
110
111    Returns:
112        [{
113            "model_id": "gpt-4o",
114            "model_name": "GPT-4o",
115            "openai_api_type": "chat",
116            "openai_base_url": "https://api.openai.com/v1",
117            "openai_api_keys": "key1,key2,...",
118            "openai_default_headers": {},
119            "openai_completions_config": {},
120            "openai_responses_config": {},
121            ....  # other fileds will also be passed to the function
122        }]
123    """
124    texts = str(message.content).strip()
125    if texts.startswith(EMOJI_TEXT_BOT) and BOT_TIPS in texts:
126        # DO NOT respond to AI responses to avoid potential infinitely loop
127        return []
128
129    # this message starts with /ai
130    if startswith_prefix(message.content, PREFIX.AI_TEXT_GENERATION):
131        prompt = texts.removeprefix(PREFIX.AI_TEXT_GENERATION).strip()
132        prompt = re.sub(r"^@([a-zA-Z0-9_\-\.]+)(\s+)?", "", prompt, flags=re.DOTALL).strip()
133        if not prompt and not message.reply_to_message:  # no prompt & no reply_msg
134            await message.reply_text(text=await text_generation_docs(), quote=True)
135            return []
136        if matched := re.match(rf"^{PREFIX.AI_TEXT_GENERATION}\s+@([a-zA-Z0-9_\-\.]+)(\s+)?", texts):  # match /ai @custom_model_id
137            model_alias = matched.group(1).strip()
138            return await get_config_by_model_alias(model_alias)
139        return await get_config_by_model_alias(AI.TEXT_GENERATION_DEFAULT_MODEL)
140
141    # this message is not /ai, try to find model id from reply_message
142    reply_msg = message.reply_to_message
143    if not isinstance(reply_msg, Message):
144        return []
145
146    if matched := re.match(rf"^{EMOJI_TEXT_BOT}(.*?):{BOT_TIPS}", str(reply_msg.content)):
147        model_name = matched.group(1).strip()
148        return await get_config_by_model_name(model_name)
149    return []
150
151
async def get_config_by_model_alias(model_alias: str, *, fallback_to_default: bool = True) -> list[dict]:
    """Get model configs by model_alias.

    The alias is looked up as a top-level key of the KV object stored under
    `AI.TEXT_MODEL_CONFIG_KEY`. Each item of the entry's `models` list is
    deep-merged with the entry's `common_config`, and one config is produced
    per model id yielded by `strings_list(merged["model_id"], ...)`.

    Args:
        model_alias: Top-level key in the KV config (e.g. `gemini`, `gpt`).
        fallback_to_default: When the alias has no usable entry, return the
            built-in Gemini / Anthropic / OpenAI configs taken from `AI`
            instead of an empty list.

    Returns:
        A list of model configs; empty only when the alias is not configured
        and `fallback_to_default` is False, or the entry lists no models.
    """
    kv = await get_cf_kv(AI.TEXT_MODEL_CONFIG_KEY, cache_ttl=600, silent=True)
    entry = kv[model_alias] if model_alias in kv else None
    # The KV object also holds non-model entries (e.g. the `docs` string), so
    # an alias such as `docs` must be treated as "not configured", not crash.
    if isinstance(entry, dict):
        common = entry.get("common_config", {})
        configs = []
        for config in entry.get("models", []):
            merged = deep_merge(common, config)
            # strategy is either "load-balance" (shuffle the ids) or "fallback" (keep order)
            shuffle = merged.get("strategy", "fallback") == "load-balance"
            for model_id in strings_list(merged["model_id"], shuffle=shuffle):
                merged["model_id"] = model_id
                configs.append(merged.copy())
        return configs

    if not fallback_to_default:
        return []

    logger.warning(f"Model Alias `{model_alias}` is not configured in KV, fallback to default config")
    return [
        {
            "model_id": AI.GEMINI_MODEL_ID,
            "model_name": AI.GEMINI_MODEL_ID,
            "api_type": "gemini",
            "gemini_base_url": AI.GEMINI_BASE_URL,
            "gemini_api_keys": AI.GEMINI_API_KEYS,
        },
        {
            "model_id": AI.ANTHROPIC_MODEL_ID,
            "model_name": AI.ANTHROPIC_MODEL_ID,
            "api_type": "anthropic",
            "anthropic_base_url": AI.ANTHROPIC_BASE_URL,
            "anthropic_api_keys": AI.ANTHROPIC_API_KEYS,
        },
        {
            "model_id": AI.OPENAI_MODEL_ID,
            "model_name": AI.OPENAI_MODEL_ID,
            "api_type": "openai_chat",
            "openai_base_url": AI.OPENAI_BASE_URL,
            "openai_api_keys": AI.OPENAI_API_KEYS,
        },
    ]
197
198
async def get_config_by_model_name(model_name: str) -> list[dict]:
    """Collect every model config whose `model_name` equals the given name.

    All dict-valued entries of the KV object stored under
    `AI.TEXT_MODEL_CONFIG_KEY` are scanned, except the aliases reserved for
    tool calls and for podcast / chat / subtitle summaries. Each model is
    deep-merged with its entry's `common_config`; matching models yield one
    config per id produced by `strings_list(model_id, ...)`.

    Returns:
        The matching model configs (empty when nothing matches).
    """
    kv = await get_cf_kv(AI.TEXT_MODEL_CONFIG_KEY, cache_ttl=600, silent=True)
    matches = []
    for alias, entry in kv.items():
        # skip non-model entries such as plain strings
        if not isinstance(entry, dict):
            continue
        # aliases reserved for internal tasks are never selected by name
        if alias in {AI.TOOL_CALL_MODEL_ALIAS, AI.PODCAST_SUMMARY_MODEL_ALIAS, AI.CHAT_SUMMARY_MODEL_ALIAS, AI.SUBTITLE_SUMMARY_MODEL_ALIAS}:
            continue
        shared = entry.get("common_config", {})
        for spec in entry.get("models", []):
            resolved = deep_merge(shared, spec)
            if resolved.get("model_name", "") == model_name:
                # "load-balance" shuffles the ids, "fallback" keeps their order
                balanced = resolved.get("strategy", "fallback") == "load-balance"
                for single_id in strings_list(resolved["model_id"], shuffle=balanced):
                    resolved["model_id"] = single_id
                    matches.append(resolved.copy())
    return matches
222
223
async def reorder_model_configs(client: Client, message: Message, configs: list[dict], params: dict) -> list[dict]:
    """Move the configs that can handle the message's media to the front.

    The context types come from `context_types(client, message,
    params["additional_contexts"])`. When none of youtube / audio / video /
    image / file is present, `configs` is returned unchanged. Otherwise the
    highest-priority media kind present (youtube > audio > video > image >
    file) decides which configs are preferred:

    - `api_type == "gemini"` is always preferred;
    - for youtube, no other config is preferred;
    - for audio / video / image / file, a config is preferred when its
      `api_type` starts with `openai` and the matching `openai_allow_*`
      flag is truthy.

    The relative order inside the preferred and the remaining group is kept.

    Returns:
        model_configs, preferred ones first
    """
    types = await context_types(client, message, params.get("additional_contexts", []))

    media_priority = ("youtube", "audio", "video", "image", "file")
    present = [kind for kind in media_priority if types.get(kind)]
    if not present:
        return configs  # text only

    # Flag a non-Gemini config must set for the dominant media kind;
    # None for youtube, which only Gemini can handle.
    allow_flag = {
        "audio": "openai_allow_audio",
        "video": "openai_allow_video",
        "image": "openai_allow_image",
        "file": "openai_allow_file",
    }.get(present[0])

    def is_preferred(config: dict) -> bool:
        api_type = config.get("api_type", "")
        if api_type == "gemini":
            return True
        is_openai = api_type.startswith("openai")
        if allow_flag is None:
            return False
        return is_openai and config.get(allow_flag, False)

    front = []
    back = []
    for candidate in configs:
        bucket = front if is_preferred(candidate) else back
        bucket.append(candidate)

    return front + back