Commit 673b698

benny-dou <60535774+benny-dou@users.noreply.github.com>
2025-02-27 16:31:45
feat(gpt): skip sending a new request if tool_model is the same as the current model
1 parent 10d501a
Changed files (2)
src/llm/gpt.py
@@ -97,8 +97,13 @@ async def gpt_response(client: Client, message: Message, **kwargs):
     if kwargs.get("show_progress"):
         res = await send2tg(client, message, texts=msg, **kwargs)
         kwargs["progress"] = res[0]
-    config = await merge_tools_response(config, **kwargs)
-    response = await send_to_gpt(config, **kwargs)
+
+    config, tool_response = await merge_tools_response(config, **kwargs)
+    # skip sending a new request if tool_model is the same as the current model
+    if tool_response and config["completions"]["model"] == GPT.TOOLS_MODEL:
+        response = tool_response
+    else:
+        response = await send_to_gpt(config, **kwargs)
     if content := response.get("content"):
         if reasoning := response.get("reasoning"):
             content = f"{reasoning}\n\n{content}"
src/llm/response.py
@@ -17,15 +17,20 @@ from messages.progress import modify_progress
 from utils import number_to_emoji
 
 
-async def merge_tools_response(config: dict, **kwargs) -> dict:
-    """Use tool model to get function call result."""
+async def merge_tools_response(config: dict, **kwargs) -> tuple[dict, dict]:
+    """Use tool model to get function call result.
+
+    If no function call is triggered, return the original config and the tool model response.
+    Otherwise, return the modified config and an empty response.
+
+    Returns:
+        (config, response)
+    """
     if not GPT.TOOLS_API_KEY:
-        return config
+        return config, {}
     # tool model should be fast and cheap
     completions = {
         "model": GPT.TOOLS_MODEL,
-        "temperature": 0,
-        "max_tokens": 1024,
         "messages": copy.deepcopy(config["completions"]["messages"]),
     }
     completions |= openrouter_hook(GPT.TOOLS_BASE_URL, for_tools=True)
@@ -38,17 +43,17 @@ async def merge_tools_response(config: dict, **kwargs) -> dict:
         response = await send_to_gpt(tools_config, retry=0)
         tool_call = glom(response, "choices.0.message.tool_calls.0", default={})
         if not tool_call or glom(tool_call, "function.name", default="") != "get_online_search_result":
-            return config
+            return config, response
         args = json.loads(glom(tool_call, "function.arguments", default="{}"))
         logger.debug(f"Online search tool call args: {args}")
         await modify_progress(text=f"正在联网搜索信息:\n{args.get('query', '')}", force_update=True, **kwargs)
         if tool_result := await get_online_search_result(**args):
             config["completions"] = add_search_results_to_prompts(tool_result, config["completions"])
             config["search_results"] = tool_result  # save search results for future use
-            return config
+            return config, {}
     except Exception as e:
         logger.error(f"Tools_response failed: {e}")
-    return config
+    return config, {}
 
 
 async def send_to_gpt(config: dict, retry: int = 0, **kwargs) -> dict[str, str]: