main
  1#!/usr/bin/env python
  2# -*- coding: utf-8 -*-
  3import re
  4
  5from loguru import logger
  6from pyrogram.types import Message
  7
  8from ai.utils import BOT_TIPS, EMOJI_TEXT_BOT, deep_merge, text_generation_docs
  9from config import AI, PREFIX
 10from database.kv import get_cf_kv
 11from messages.utils import startswith_prefix
 12
 13
 14# ruff: noqa: RUF002
 15async def get_text_model_configs(message: Message) -> list[dict]:
 16    r"""Get model config based on the message.
 17
 18    Model config is retrieved from CF-KV with key: {AI.TEXT_MODEL_CONFIG_KEY}
 19
 20    A sample config:
 21    {
 22    "docs": "🤖AI对话: `/ai` + 提示词\n回复消息可将其加入历史上下文\n默认使用**Gemini-2.5-Flash**模型\n\n🔄使用以下命令强制切换模型:\n/gpt: GPT-5.2\n/gemini: Gemini-2.5-Flash\n/g3: Gemini-3-Flash (不支持网络搜索)\n/grok: Grok-4\n/claude: Claude-Opus-4.5\n/doubao: Doubao-Seed-1.8\n/ds: DeepSeek-R1\n/qwen: Qwen3-Max\n/kimi: Kimi-K2\n/glm: GLM-4.7\n/mimo: MiMo-V2-Flash",
 23    "gemini": {
 24        "common_config": {
 25            "api_type": "gemini",
 26            "gemini_base_url": "https://generativelanguage.googleapis.com",
 27            "gemini_api_keys": "key1,key2,key3...",
 28        },
 29        "models": [
 30            {
 31                "model_id": "gemini-3-flash-preview",
 32                "model_name": "Gemini-3-Flash",
 33                "gemini_generate_content_config": {
 34                    "max_output_tokens": 65536,
 35                    "thinking_config": {"include_thoughts": true, "thinking_level": "high"},
 36                    "tools": [{"url_context": {}}, {"code_execution": {}}]
 37                }
 38            },
 39            {
 40                "model_name": "Gemini-2.5-Flash",
 41                "model_id": "gemini-2.5-flash-preview-09-2025",
 42                "gemini_generate_content_config": {
 43                    "max_output_tokens": 65536,
 44                    "thinking_config": {"include_thoughts": true, "thinking_budget": 24576},
 45                    "tools": [{"google_search": {}}, {"url_context": {}}, {"code_execution": {}}]
 46                }
 47            },
 48        ]
 49    },
 50    "gpt": {
 51        "common_config": {
 52            "api_type": "gemini",
 53            "gemini_base_url": "https://generativelanguage.googleapis.com",
 54            "gemini_api_keys": "key1,key2,key3...",
 55        },
 56        "models": [
 57            {
 58                "model_id": "gpt-4o",
 59                "model_name": "GPT-4o",
 60                "api_type": "openai_chat",
 61                "openai_base_url": "https://api.openai.com/v1",
 62                "openai_api_keys": "key1,key2,key3...",
 63                "openai_completions_config": {
 64                    "temperature": 1.0,
 65                    "max_completion_tokens": 4096
 66                }
 67            },
 68            {
 69                "model_id": "gpt-5.2",
 70                "model_name": "GPT-5.2",
 71                "api_type": "openai_responses",
 72                "cache_response_ttl": 86400,
 73                "openai_base_url": "https://gateway.helicone.ai/v1",
 74                "openai_api_keys": "key1,key2,key3,...",
 75                "openai_default_headers": {
 76                    "helicone-auth": "Bearer HELICONE_API_KEY",
 77                    "helicone-target-url": "https://api.openai.com"
 78                },
 79                "openai_responses_config": {
 80                    "reasoning": { "effort": "high" },
 81                    "max_output_tokens": 4096,
 82                    "tools": [ { "type": "web_search_preview","search_context_size": "high" } ]
 83                }
 84            }
 85        ]
 86    }
 87    "tool_call_model": {
 88        "models": [
 89            {
 90                "model_id": "gpt-4o-mini",
 91                "model_name": "Web Search",
 92                "api_type": "openai_chat"
 93                "openai_base_url": "https://api.openai.com/v1",
 94                "openai_api_keys": "key1,key2,key3"
 95            }
 96        ]
 97    }
 98
 99
100    Suppose this message is:
101        Message(text="/ai hello") -> use `default` as model identifier
102        Message(text="/ai @gpt-4.1 hello") -> use `gpt-4.1` as model identifier
103
104    Reply to a message:
105        Message(text="🤖Gemini-2.5-Flash:(回复以继续)\nHello") -> find the model_alias via model_name=`Gemini-2.5-Flash`
106        Message(text="🤖GPT-4o:(回复以继续)\nHello") -> find the model_alias via model_name=`GPT-4o`
107
108    Returns:
109        [{
110            "model_id": "gpt-4o",
111            "model_name": "GPT-4o",
112            "openai_api_type": "chat",
113            "openai_base_url": "https://api.openai.com/v1",
114            "openai_api_keys": "key1,key2,...",
115            "openai_default_headers": {},
116            "openai_completions_config": {},
117            "openai_responses_config": {},
118            ....  # other fileds will also be passed to the function
119        }]
120    """
121    texts = str(message.content).strip()
122    if texts.startswith(EMOJI_TEXT_BOT) and BOT_TIPS in texts:
123        # DO NOT respond to AI responses to avoid potential infinitely loop
124        return []
125
126    # this message starts with /ai
127    if startswith_prefix(message.content, PREFIX.AI_TEXT_GENERATION):
128        prompt = texts.removeprefix(PREFIX.AI_TEXT_GENERATION).strip()
129        prompt = re.sub(r"^@([a-zA-Z0-9_\-\.]+)(\s+)?", "", prompt, flags=re.DOTALL).strip()
130        if not prompt and not message.reply_to_message:  # no prompt & no reply_msg
131            await message.reply_text(text=await text_generation_docs(), quote=True)
132            return []
133        if matched := re.match(rf"^{PREFIX.AI_TEXT_GENERATION}\s+@([a-zA-Z0-9_\-\.]+)(\s+)?", texts):  # match /ai @custom_model_id
134            model_alias = matched.group(1).strip()
135            return await get_config_by_model_alias(model_alias)
136        return await get_config_by_model_alias(AI.TEXT_GENERATION_DEFAULT_MODEL)
137
138    # this message is not /ai, try to find model id from reply_message
139    reply_msg = message.reply_to_message
140    if not isinstance(reply_msg, Message):
141        return []
142
143    if matched := re.match(rf"^{EMOJI_TEXT_BOT}(.*?):{BOT_TIPS}", str(reply_msg.content)):
144        model_name = matched.group(1).strip()
145        return await get_config_by_model_name(model_name)
146    return []
147
148
async def get_config_by_model_alias(model_alias: str, *, fallback_to_default: bool = True) -> list[dict]:
    """Get the merged model configs registered under ``model_alias``.

    The alias is looked up in the KV document stored under
    ``AI.TEXT_MODEL_CONFIG_KEY``. For a matching entry, its ``common_config``
    is deep-merged with every item of its ``models`` list.

    Args:
        model_alias: Top-level key of the KV document, e.g. ``"gpt"``.
        fallback_to_default: When the alias is not configured, return the
            built-in Gemini / Anthropic / OpenAI configs (True) or an empty
            list (False).

    Returns:
        A list of merged model configs. It may be empty when the alias has no
        ``models`` or when the alias is missing and fallback is disabled.
    """
    kv = await get_cf_kv(AI.TEXT_MODEL_CONFIG_KEY, cache_ttl=600, silent=True)

    config = kv.get(model_alias, {})
    # The KV document may also hold non-dict values (e.g. a plain help string).
    # Such entries are not model groups, so treat them as "not configured"
    # rather than crashing on ``.get``.
    if config and isinstance(config, dict):
        common_config = config.get("common_config", {})
        return [deep_merge(common_config, model_config) for model_config in config.get("models", [])]

    if not fallback_to_default:
        return []

    logger.warning(f"Model Alias `{model_alias}` is not configured in KV, fallback to default config")
    # Built-in defaults taken from the static AI settings.
    return [
        {
            "model_id": AI.GEMINI_MODEL_ID,
            "model_name": AI.GEMINI_MODEL_ID,
            "api_type": "gemini",
            "gemini_base_url": AI.GEMINI_BASE_URL,
            "gemini_api_keys": AI.GEMINI_API_KEYS,
        },
        {
            "model_id": AI.ANTHROPIC_MODEL_ID,
            "model_name": AI.ANTHROPIC_MODEL_ID,
            "api_type": "anthropic",
            "anthropic_base_url": AI.ANTHROPIC_BASE_URL,
            "anthropic_api_keys": AI.ANTHROPIC_API_KEYS,
        },
        {
            "model_id": AI.OPENAI_MODEL_ID,
            "model_name": AI.OPENAI_MODEL_ID,
            "api_type": "openai_chat",
            "openai_base_url": AI.OPENAI_BASE_URL,
            "openai_api_keys": AI.OPENAI_API_KEYS,
        },
    ]
188
189
async def get_config_by_model_name(model_name: str) -> list[dict]:
    """Collect every merged model config whose ``model_name`` equals ``model_name``.

    All dict-valued entries of the KV document stored under
    ``AI.TEXT_MODEL_CONFIG_KEY`` are scanned, except the aliases reserved for
    tool calls and the podcast / chat / subtitle summaries.

    Args:
        model_name: Display name to look for, compared by exact equality.

    Returns:
        The matching configs (``common_config`` deep-merged with the model
        entry), in KV iteration order; empty when nothing matches.
    """
    registry = await get_cf_kv(AI.TEXT_MODEL_CONFIG_KEY, cache_ttl=600, silent=True)

    # Aliases that are never selectable through a reply-to-continue lookup.
    reserved_aliases = {
        AI.TOOL_CALL_MODEL_ALIAS,
        AI.PODCAST_SUMMARY_MODEL_ALIAS,
        AI.CHAT_SUMMARY_MODEL_ALIAS,
        AI.SUBTITLE_SUMMARY_MODEL_ALIAS,
    }

    matches: list[dict] = []
    for alias, entry in registry.items():
        # Only dict entries describe model groups; skip everything else.
        if not isinstance(entry, dict) or alias in reserved_aliases:
            continue
        shared = entry.get("common_config", {})
        merged_specs = (deep_merge(shared, spec) for spec in entry.get("models", []))
        matches.extend(merged for merged in merged_specs if merged.get("model_name", "") == model_name)
    return matches
208    return model_configs