Commit 947fdce

benny-dou <60535774+benny-dou@users.noreply.github.com>
2025-03-26 09:48:00
refactor(gpt): refactor force model parsing
1 parent e38427c
Changed files (2)
src/llm/gpt.py
@@ -7,7 +7,7 @@ from pyrogram.types import Message
 
 from config import GPT, PREFIX, TEXT_LENGTH, cache
 from llm.contexts import get_conversation_contexts, get_conversations
-from llm.models import get_context_type, get_gpt_config
+from llm.models import get_context_type, get_gpt_config, parse_force_model
 from llm.response import send_to_gpt
 from llm.response_stream import send_to_gpt_stream
 from llm.tools import merge_tools_response
@@ -67,32 +67,13 @@ async def gpt_response(client: Client, message: Message, *, gpt_stream: bool = G
     if not is_gpt_conversation(message):
         return
 
-    # /gpt = OpenAI, /gemini = Gemini, /ds = DeepSeek, /qwen = Qwen, /doubao = Doubao
-    force_model = "NOT_SET"
     reply_text = ""
     if message.reply_to_message:
         reply_info = parse_msg(message.reply_to_message, silent=True)
         reply_text = reply_info["text"]
-    if startswith_prefix(info["text"], prefix=["/gpt"]) or reply_text.startswith(f"🤖{GPT.OPENAI_MODEL_NAME}"):
-        force_model = GPT.OPENAI_MODEL
-        if not GPT.OPENAI_API_KEY:
-            return await send2tg(client, message, texts=f"⚠️GPT未配置API Key, 请尝试其他命令\n\n{HELP}", **kwargs)
-    elif startswith_prefix(info["text"], prefix=["/gemini"]) or reply_text.startswith(f"🤖{GPT.GEMINI_MODEL_NAME}"):
-        force_model = GPT.GEMINI_MODEL
-        if not GPT.GEMINI_API_KEY:
-            return await send2tg(client, message, texts=f"⚠️Gemini未配置API Key, 请尝试其他命令\n\n{HELP}", **kwargs)
-    elif startswith_prefix(info["text"], prefix=["/ds"]) or reply_text.startswith(f"🤖{GPT.DEEPSEEK_MODEL_NAME}"):
-        force_model = GPT.DEEPSEEK_MODEL
-        if not GPT.DEEPSEEK_API_KEY:
-            return await send2tg(client, message, texts=f"⚠️DeepSeek未配置API Key, 请尝试其他命令\n\n{HELP}", **kwargs)
-    elif startswith_prefix(info["text"], prefix=["/qwen"]) or reply_text.startswith(f"🤖{GPT.QWEN_MODEL_NAME}"):
-        force_model = GPT.QWEN_MODEL
-        if not GPT.QWEN_API_KEY:
-            return await send2tg(client, message, texts=f"⚠️通义千问未配置API Key, 请尝试其他命令\n\n{HELP}", **kwargs)
-    elif startswith_prefix(info["text"], prefix=["/doubao"]) or reply_text.startswith(f"🤖{GPT.DOUBAO_MODEL_NAME}"):
-        force_model = GPT.DOUBAO_MODEL
-        if not GPT.DOUBAO_API_KEY:
-            return await send2tg(client, message, texts=f"⚠️豆包未配置API Key, 请尝试其他命令\n\n{HELP}", **kwargs)
+
+    # /gpt = OpenAI, /gemini = Gemini, /ds = DeepSeek, /qwen = Qwen, /doubao = Doubao
+    force_model = parse_force_model(info["text"], reply_text)
 
     # cache media_group message, only process once
     if media_group_id := message.media_group_id:
@@ -104,6 +85,9 @@ async def gpt_response(client: Client, message: Message, *, gpt_stream: bool = G
     context_type = get_context_type(conversations)
     contexts = await get_conversation_contexts(client, conversations)
     config = get_gpt_config(context_type["type"], contexts, force_model)
+    if not config["client"]["api_key"]:
+        logger.error(f"⚠️**{config['friendly_name']}** 未配置API Key")
+        return await send2tg(client, message, texts=f"⚠️**{config['friendly_name']}** 未配置API Key, 请尝试其他命令\n\n{HELP}", **kwargs)
     msg = f"🤖**{config['friendly_name']}**: 思考中..."
     status_msg = (await send2tg(client, message, texts=msg, **kwargs))[0]
     kwargs["progress"] = status_msg
src/llm/models.py
@@ -6,6 +6,7 @@ from pyrogram.types import Message
 
 from config import GPT, PREFIX, PROXY
 from messages.parser import parse_msg
+from messages.utils import startswith_prefix
 
 
 def get_context_type(conversations: list[Message]) -> dict:
@@ -26,7 +27,40 @@ def get_context_type(conversations: list[Message]) -> dict:
     return res
 
 
-def get_gpt_config(model_type: str, contexts: list[dict], force_model: str = "NOT_SET") -> dict:
+def parse_force_model(text: str, reply_text: str) -> str:
+    """Parse the force model from the text or reply text.
+
+    /gpt = OpenAI, /gemini = Gemini, /ds = DeepSeek, /qwen = Qwen, /doubao = Doubao
+    """
+    force_model = ""
+    # parse from bot reply
+    if reply_text.startswith(f"🤖{GPT.OPENAI_MODEL_NAME}"):
+        force_model = GPT.OPENAI_MODEL
+    elif reply_text.startswith(f"🤖{GPT.GEMINI_MODEL_NAME}"):
+        force_model = GPT.GEMINI_MODEL
+    elif reply_text.startswith(f"🤖{GPT.DEEPSEEK_MODEL_NAME}"):
+        force_model = GPT.DEEPSEEK_MODEL
+    elif reply_text.startswith(f"🤖{GPT.QWEN_MODEL_NAME}"):
+        force_model = GPT.QWEN_MODEL
+    elif reply_text.startswith(f"🤖{GPT.DOUBAO_MODEL_NAME}"):
+        force_model = GPT.DOUBAO_MODEL
+
+    # parse from command prefix
+    if startswith_prefix(text, prefix=["/gpt"]):
+        force_model = GPT.OPENAI_MODEL
+    elif startswith_prefix(text, prefix=["/gemini"]):
+        force_model = GPT.GEMINI_MODEL
+    elif startswith_prefix(text, prefix=["/ds"]):
+        force_model = GPT.DEEPSEEK_MODEL
+    elif startswith_prefix(text, prefix=["/qwen"]):
+        force_model = GPT.QWEN_MODEL
+    elif startswith_prefix(text, prefix=["/doubao"]):
+        force_model = GPT.DOUBAO_MODEL
+
+    return force_model
+
+
+def get_gpt_config(model_type: str, contexts: list[dict], force_model: str = "") -> dict:
     """Get GPT configurations.
 
     contexts:
@@ -47,7 +81,7 @@ def get_gpt_config(model_type: str, contexts: list[dict], force_model: str = "NO
 
     model = models[model_type]
     model_name = model_names[model_type]
-    force_model = model if force_model == "NOT_SET" else force_model
+    force_model = force_model or model
 
     # params for OpenAI client
     client = {  # this config is based on model type (text or image)