Commit a22c530

benny-dou <60535774+benny-dou@users.noreply.github.com>
2025-03-12 07:02:00
feat(gpt): skip online search if the model has built-in search capability
1 parent 2a66888
Changed files (3)
src/llm/gpt.py
@@ -8,8 +8,9 @@ from pyrogram.types import Message
 from config import GPT, PREFIX, cache
 from llm.contexts import get_conversation_contexts, get_conversations
 from llm.models import get_model_config_with_contexts, get_model_type
-from llm.response import merge_tools_response, send_to_gpt
+from llm.response import send_to_gpt
 from llm.response_stream import send_to_gpt_stream
+from llm.tools import merge_tools_response
 from llm.utils import BOT_TIPS, llm_cleanup_files
 from messages.parser import parse_msg
 from messages.progress import modify_progress
@@ -102,7 +103,7 @@ async def gpt_response(client: Client, message: Message, *, gpt_stream: bool = G
         return
     contexts = await get_conversation_contexts(client, conversations)
     config = get_model_config_with_contexts(model_type, contexts, force_model, info)
-    msg = f"🤖{config['friendly_name']}: 思考中..."
+    msg = f"🤖**{config['friendly_name']}**: 思考中..."
     status_msg = (await send2tg(client, message, texts=msg, **kwargs))[0]
     kwargs["progress"] = status_msg
 
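The only functional changes in src/llm/gpt.py are the import source (merge_tools_response now lives in llm.tools) and the bolded friendly name in the progress message; the helper's (config, response) contract is unchanged. A minimal, self-contained sketch of how a caller can consume that tuple, with a stub standing in for the real helper (the stub and surrounding names are illustrative, not the actual gpt_response body):

```python
import asyncio

# Illustrative stub for llm.tools.merge_tools_response: pretend a function call
# was triggered and search results were injected, so an empty response is returned.
async def merge_tools_response_stub(config: dict, **kwargs) -> tuple[dict, dict]:
    return dict(config, search_results=[{"title": "example hit"}]), {}

async def main() -> None:
    config = {"completions": {"model": "gpt-4o-mini", "messages": []}}
    config, tool_response = await merge_tools_response_stub(config)
    if tool_response:
        # No function call was triggered: the tool model's own reply came back.
        print("tool model answered directly")
    else:
        # Empty response: config may now carry injected search results, so the
        # request proceeds to the main model (send_to_gpt / send_to_gpt_stream).
        print("continue with main model:", config.get("search_results"))

asyncio.run(main())
```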
src/llm/response.py
@@ -1,7 +1,6 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
 import contextlib
-import copy
 import json
 
 from glom import Coalesce, glom
@@ -10,52 +9,10 @@ from openai import AsyncOpenAI
 from pyrogram.parser.markdown import BLOCKQUOTE_EXPANDABLE_DELIM, BLOCKQUOTE_EXPANDABLE_END_DELIM
 
 from config import GPT
-from llm.models import openrouter_hook
-from llm.prompts import add_search_results_to_prompts
-from llm.tools import add_tools, get_online_search_result
 from llm.utils import add_search_results_to_response, beautify_llm_response, beautify_model_name, extract_reasoning
 from messages.progress import modify_progress
 
 
-async def merge_tools_response(config: dict, **kwargs) -> tuple[dict, dict]:
-    """Use tool model to get function call result.
-
-    if no function call is triggered, return original config and the tool model response.
-    otherwise, return modified config and an empty response.
-
-    Returns:
-        (config, response)
-    """
-    if not GPT.TOOLS_API_KEY:
-        return config, {}
-    # tool model should be fast and cheap
-    completions = {
-        "model": GPT.TOOLS_MODEL,
-        "messages": copy.deepcopy(config["completions"]["messages"]),
-    }
-    completions |= openrouter_hook(GPT.TOOLS_BASE_URL, for_tools=True)
-    tools_config = {
-        "friendly_name": config["friendly_name"],
-        "client": config["client"] | {"base_url": GPT.TOOLS_BASE_URL, "api_key": GPT.TOOLS_API_KEY},
-        "completions": add_tools(completions),
-    }
-    try:
-        response = await send_to_gpt(tools_config, retry=0)
-        tool_call = glom(response, "choices.0.message.tool_calls.0", default={})
-        if not tool_call or glom(tool_call, "function.name", default="") != "get_online_search_result":
-            return config, response
-        args = json.loads(glom(tool_call, "function.arguments", default="{}"))
-        logger.debug(f"Online search tool call args: {args}")
-        await modify_progress(text=f"正在联网搜索信息:\n{args.get('query', '')}", force_update=True, **kwargs)
-        if tool_result := await get_online_search_result(**args):
-            config["completions"] = add_search_results_to_prompts(tool_result, config["completions"])
-            config["search_results"] = tool_result  # save search results for future use
-            return config, {}
-    except Exception as e:
-        logger.error(f"Tools_response failed: {e}")
-    return config, {}
-
-
 async def send_to_gpt(config: dict, retry: int = 0, **kwargs) -> dict[str, str]:
     """Get GPT response in non-stream mode.
 
src/llm/tools.py
@@ -1,14 +1,18 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
 import copy
+import json
 
 from glom import glom
 from loguru import logger
 from openai import AsyncOpenAI, DefaultAsyncHttpxClient
 
 from config import GPT, PROXY, TOKEN, TZ
-from llm.prompts import modify_prompts
+from llm.models import openrouter_hook
+from llm.prompts import add_search_results_to_prompts, modify_prompts
+from llm.response import send_to_gpt
 from llm.tool_scheme import ONLINE_SEARCH
+from messages.progress import modify_progress
 from networking import hx_req
 from utils import nowdt
 
@@ -118,3 +122,44 @@ def remove_tool(params: dict, tool_name: str) -> dict:
         params.pop("tools", None)
         params.pop("tool_choice", None)
     return params
+
+
+async def merge_tools_response(config: dict, **kwargs) -> tuple[dict, dict]:
+    """Use tool model to get function call result.
+
+    if no function call is triggered, return original config and the tool model response.
+    otherwise, return modified config and an empty response.
+
+    Returns:
+        (config, response)
+    """
+    if not GPT.TOOLS_API_KEY:
+        return config, {}
+    if any(x in config["completions"]["model"].lower() for x in ["search", "搜索"]):  # skip search model
+        return config, {}
+    # tool model should be fast and cheap
+    completions = {
+        "model": GPT.TOOLS_MODEL,
+        "messages": copy.deepcopy(config["completions"]["messages"]),
+    }
+    completions |= openrouter_hook(GPT.TOOLS_BASE_URL, for_tools=True)
+    tools_config = {
+        "friendly_name": config["friendly_name"],
+        "client": config["client"] | {"base_url": GPT.TOOLS_BASE_URL, "api_key": GPT.TOOLS_API_KEY},
+        "completions": add_tools(completions),
+    }
+    try:
+        response = await send_to_gpt(tools_config, retry=0)
+        tool_call = glom(response, "choices.0.message.tool_calls.0", default={})
+        if not tool_call or glom(tool_call, "function.name", default="") != "get_online_search_result":
+            return config, response
+        args = json.loads(glom(tool_call, "function.arguments", default="{}"))
+        logger.debug(f"Online search tool call args: {args}")
+        await modify_progress(text=f"正在联网搜索信息:\n{args.get('query', '')}", force_update=True, **kwargs)
+        if tool_result := await get_online_search_result(**args):
+            config["completions"] = add_search_results_to_prompts(tool_result, config["completions"])
+            config["search_results"] = tool_result  # save search results for future use
+            return config, {}
+    except Exception as e:
+        logger.error(f"Tools_response failed: {e}")
+    return config, {}
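The headline change is the new guard near the top of the relocated merge_tools_response: when the selected model's name already suggests built-in search (it contains "search" or "搜索"), the tool-model round-trip for get_online_search_result is skipped and the config is returned untouched. A minimal, self-contained sketch of that heuristic (the helper name and example model strings are mine, not from the diff):

```python
BUILTIN_SEARCH_HINTS = ("search", "搜索")  # keywords matched by the diff's guard

def has_builtin_search(model_name: str) -> bool:
    """True when the model name hints at native online-search support."""
    lowered = model_name.lower()
    return any(hint in lowered for hint in BUILTIN_SEARCH_HINTS)

print(has_builtin_search("example-Search-model"))  # True  -> guard returns (config, {})
print(has_builtin_search("gpt-4o-mini"))           # False -> tool model is consulted
```

Lowercasing the model name makes the check case-insensitive, and the Chinese keyword 搜索 covers model aliases named in Chinese.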