Commit 87c7c14

benny-dou <60535774+benny-dou@users.noreply.github.com>
2025-09-01 06:08:33
fix(gpt): support replying to a message from a custom model
1 parent 1194acf
Changed files (2)
src/llm/gpt.py
@@ -48,6 +48,7 @@ async def gpt_response(
     message: Message,
     *,
     custom_model_id: str = "",
+    custom_model_name: str = "",
     enable_tools: bool = True,
     **kwargs,
 ) -> dict:
@@ -60,6 +61,7 @@ async def gpt_response(
         client (Client): The Pyrogram client.
         message (Message): The trigger message object.
         custom_model_id (str, optional): Custom model id.
+        custom_model_name (str, optional): Custom model name.
         enable_tools (bool, optional): Whether to enable tools. Defaults to True.
 
     Returns:
@@ -77,9 +79,11 @@ async def gpt_response(
     if info["mtype"] == "text" and equal_prefix(info["text"], prefix=PREFIX.GENIMG) and not message.reply_to_message:
         await send2tg(client, message, texts=TEXT2IMG_HELP, **kwargs)
         return {}
-    model_id, resp_modality = get_model_id(info, message)
+    model_id, is_custom_id, resp_modality = get_model_id(info, message)
     if not model_id:
         return {}
+    if is_custom_id:
+        custom_model_id = model_id
     # cache media_group message, only process once
     if media_group_id := message.media_group_id:
         if cache.get(f"gpt-{info['cid']}-{media_group_id}"):
@@ -98,7 +102,7 @@ async def gpt_response(
         await send2tg(client, message, texts=f"⚠️不支持自定义模型: {custom_model_id}\n\n⚙️支持自定义模型列表:\n{'\n'.join(allowed_model_ids)}", **kwargs)
         return {}
     if custom_model_id.lower() in [x.lower() for x in strings_list(GEMINI.ALLOWED_CUSTOM_MODEL_IDS)]:
-        return await gemini_chat_completion(client, message, model_id=custom_model_id, model_name=custom_model_id, enable_tools=enable_tools, **kwargs)
+        return await gemini_chat_completion(client, message, model_id=custom_model_id, model_name=custom_model_name or custom_model_id, enable_tools=enable_tools, **kwargs)
     if model_id == GEMINI.TEXT_MODEL and not custom_model_id:
         return await gemini_chat_completion(client, message, enable_tools=enable_tools, **kwargs)
 
src/llm/models.py
@@ -1,6 +1,7 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
 import os
+import re
 
 from openai import DefaultAsyncHttpxClient
 from pyrogram.types import Message
@@ -25,39 +26,39 @@ def get_context_type(conversations: list[Message]) -> str:
     return context_type
 
 
-def get_model_id(minfo: dict, message: Message) -> tuple[str, str]:
+def get_model_id(minfo: dict, message: Message) -> tuple[str, bool, str]:
     """Get model id with response modality.
 
     Returns:
-        (model_id, response_modality)
+        (model_id, is_custom_id, response_modality)
     """
     # to avoid potential infinitely loop,
     # we do not respond to bot message & GPT responses.
     if minfo["is_bot"]:
-        return "", ""
+        return "", False, ""
     if BOT_TIPS in minfo["text"]:
-        return "", ""
+        return "", False, ""
 
     model_id, response_modality = get_model_id_from_envars(minfo)
     if model_id:
-        return model_id, response_modality
+        return model_id, False, response_modality
 
-    model_id, response_modality = get_model_id_from_prefix(minfo)
+    model_id, is_custom_id, response_modality = get_model_id_from_prefix(minfo)
     if not model_id:
-        return "", ""
+        return "", is_custom_id, ""
 
     # early return for non-text generation
     if response_modality != "text":
-        return model_id, response_modality
+        return model_id, is_custom_id, response_modality
 
     # check if we need to fallback to omini model
     conversations = get_conversations(message)
     context_type = get_context_type(conversations)  # {"type": "text", "error": None}  # text, image
     if context_type == "text":  # no need to fallback if context type is text
-        return model_id, response_modality
+        return model_id, is_custom_id, response_modality
 
     if context_type in ["video", "audio", "voice"]:  # currently, only Gemini supports audio/video
-        return GEMINI.TEXT_MODEL, "text"
+        return GEMINI.TEXT_MODEL, is_custom_id, "text"
 
     if (
         (model_id == GPT.OPENAI_MODEL and not GPT.OPENAI_ACCEPT_IMAGE)
@@ -79,9 +80,9 @@ def get_model_id(minfo: dict, message: Message) -> tuple[str, str]:
         text_providers, _ = enabled_providers()
         # prefer gemini if OMNI_PROVIDER is not set
         model_id = omni_providers.get(GPT.OMNI_PROVIDER.lower()) or GEMINI.TEXT_MODEL or omni_providers[text_providers[0]]
-        return model_id, "text"
+        return model_id, is_custom_id, "text"
 
-    return model_id, response_modality
+    return model_id, is_custom_id, response_modality
 
 
 def get_model_id_from_envars(minfo: dict) -> tuple[str, str]:
@@ -130,26 +131,30 @@ def get_model_id_from_envars(minfo: dict) -> tuple[str, str]:
     return "", ""
 
 
-def get_model_id_from_prefix(minfo: dict) -> tuple[str, str]:
+def get_model_id_from_prefix(minfo: dict) -> tuple[str, bool, str]:
     text_providers, img_providers = enabled_providers()
+    model_id = ""
+    resp_modality = "text"
+    # start with /prefix
     if startswith_prefix(minfo["text"], prefix="/gpt") and "openai" in text_providers:
-        return GPT.OPENAI_MODEL, "text"
-    if startswith_prefix(minfo["text"], prefix="/gemini") and "gemini" in text_providers:
-        return GEMINI.TEXT_MODEL, "text"
-    if startswith_prefix(minfo["text"], prefix="/ds") and "deepseek" in text_providers:
-        return GPT.DEEPSEEK_MODEL, "text"
-    if startswith_prefix(minfo["text"], prefix="/doubao") and "doubao" in text_providers:
-        return GPT.DOUBAO_MODEL, "text"
-    if startswith_prefix(minfo["text"], prefix="/qwen") and "qwen" in text_providers:
-        return GPT.QWEN_MODEL, "text"
-    if startswith_prefix(minfo["text"], prefix="/kimi") and "kimi" in text_providers:
-        return GPT.KIMI_MODEL, "text"
-    if startswith_prefix(minfo["text"], prefix="/grok") and "grok" in text_providers:
-        return GPT.GROK_MODEL, "text"
-    if startswith_prefix(minfo["text"], prefix=PREFIX.GENIMG) and "gemini" in img_providers:
-        return GEMINI.IMG_MODEL, "image"
+        model_id = GPT.OPENAI_MODEL
+    elif startswith_prefix(minfo["text"], prefix="/gemini") and "gemini" in text_providers:
+        model_id = GEMINI.TEXT_MODEL
+    elif startswith_prefix(minfo["text"], prefix="/ds") and "deepseek" in text_providers:
+        model_id = GPT.DEEPSEEK_MODEL
+    elif startswith_prefix(minfo["text"], prefix="/doubao") and "doubao" in text_providers:
+        model_id = GPT.DOUBAO_MODEL
+    elif startswith_prefix(minfo["text"], prefix="/qwen") and "qwen" in text_providers:
+        model_id = GPT.QWEN_MODEL
+    elif startswith_prefix(minfo["text"], prefix="/kimi") and "kimi" in text_providers:
+        model_id = GPT.KIMI_MODEL
+    elif startswith_prefix(minfo["text"], prefix="/grok") and "grok" in text_providers:
+        model_id = GPT.GROK_MODEL
+    elif startswith_prefix(minfo["text"], prefix=PREFIX.GENIMG) and "gemini" in img_providers:
+        model_id = GEMINI.IMG_MODEL
+        resp_modality = "image"
     # start with /ai, auto detect model_id
-    if startswith_prefix(minfo["text"], prefix="/ai") and text_providers:
+    elif startswith_prefix(minfo["text"], prefix="/ai") and text_providers:
         providers = {
             "openai": GPT.OPENAI_MODEL,
             "deepseek": GPT.DEEPSEEK_MODEL,
@@ -161,26 +166,30 @@ def get_model_id_from_prefix(minfo: dict) -> tuple[str, str]:
         }
         # prefer gemini if DEFAULT_PROVIDER is not set
         model_id = providers.get(GPT.DEFAULT_PROVIDER.lower()) or GEMINI.TEXT_MODEL or providers[text_providers[0]]
-        return model_id, "text"
+    if model_id:
+        return model_id, False, resp_modality
 
     # is replying to AI response message
     if startswith_prefix(minfo["reply_text"], prefix=f"🤖{GPT.OPENAI_MODEL_NAME}:{BOT_TIPS}") and "openai" in text_providers:
-        return GPT.OPENAI_MODEL, "text"
-    if startswith_prefix(minfo["reply_text"], prefix=f"🤖{GEMINI.TEXT_MODEL_NAME}:{BOT_TIPS}") and "gemini" in text_providers:
-        return GEMINI.TEXT_MODEL, "text"
-    if startswith_prefix(minfo["reply_text"], prefix=f"🤖{GPT.DEEPSEEK_MODEL_NAME}:{BOT_TIPS}") and "deepseek" in text_providers:
-        return GPT.DEEPSEEK_MODEL, "text"
-    if startswith_prefix(minfo["reply_text"], prefix=f"🤖{GPT.DOUBAO_MODEL_NAME}:{BOT_TIPS}") and "doubao" in text_providers:
-        return GPT.DOUBAO_MODEL, "text"
-    if startswith_prefix(minfo["reply_text"], prefix=f"🤖{GPT.QWEN_MODEL_NAME}:{BOT_TIPS}") and "qwen" in text_providers:
-        return GPT.QWEN_MODEL, "text"
-    if startswith_prefix(minfo["reply_text"], prefix=f"🤖{GPT.KIMI_MODEL_NAME}:{BOT_TIPS}") and "kimi" in text_providers:
-        return GPT.KIMI_MODEL, "text"
-    if startswith_prefix(minfo["reply_text"], prefix=f"🤖{GPT.GROK_MODEL_NAME}:{BOT_TIPS}") and "grok" in text_providers:
-        return GPT.GROK_MODEL, "text"
-    if startswith_prefix(minfo["reply_text"], prefix=f"🤖{GEMINI.IMG_MODEL_NAME}:{BOT_TIPS}") and "gemini" in img_providers:
-        return GEMINI.IMG_MODEL, "image"
-    return "", ""
+        model_id = GPT.OPENAI_MODEL
+    elif startswith_prefix(minfo["reply_text"], prefix=f"🤖{GEMINI.TEXT_MODEL_NAME}:{BOT_TIPS}") and "gemini" in text_providers:
+        model_id = GEMINI.TEXT_MODEL
+    elif startswith_prefix(minfo["reply_text"], prefix=f"🤖{GPT.DEEPSEEK_MODEL_NAME}:{BOT_TIPS}") and "deepseek" in text_providers:
+        model_id = GPT.DEEPSEEK_MODEL
+    elif startswith_prefix(minfo["reply_text"], prefix=f"🤖{GPT.DOUBAO_MODEL_NAME}:{BOT_TIPS}") and "doubao" in text_providers:
+        model_id = GPT.DOUBAO_MODEL
+    elif startswith_prefix(minfo["reply_text"], prefix=f"🤖{GPT.QWEN_MODEL_NAME}:{BOT_TIPS}") and "qwen" in text_providers:
+        model_id = GPT.QWEN_MODEL
+    elif startswith_prefix(minfo["reply_text"], prefix=f"🤖{GPT.KIMI_MODEL_NAME}:{BOT_TIPS}") and "kimi" in text_providers:
+        model_id = GPT.KIMI_MODEL
+    elif startswith_prefix(minfo["reply_text"], prefix=f"🤖{GPT.GROK_MODEL_NAME}:{BOT_TIPS}") and "grok" in text_providers:
+        model_id = GPT.GROK_MODEL
+    elif startswith_prefix(minfo["reply_text"], prefix=f"🤖{GEMINI.IMG_MODEL_NAME}:{BOT_TIPS}") and "gemini" in img_providers:
+        model_id = GEMINI.IMG_MODEL
+        resp_modality = "image"
+    elif matched := re.match(rf"^🤖(.*?):{BOT_TIPS}", minfo["reply_text"]):
+        return matched.group(1).lower(), True, "text"
+    return model_id, False, resp_modality
 
 
 def get_gpt_config(model_id: str = "") -> dict: