Commit 3d3fdb8

benny-dou <60535774+benny-dou@users.noreply.github.com>
2025-02-10 12:44:09
refactor(gpt): remove tool from contexts after using it
1 parent ef84345
Changed files (3)
src/llm/response.py
@@ -9,6 +9,7 @@ from openai import AsyncOpenAI
 
 from config import ENABLE, GPT, TZ
 from llm.tool_call import get_online_search_result
+from llm.tool_scheme import remove_tool
 from llm.utils import BOT_TIPS, change_system_prompt
 from messages.progress import modify_progress
 from utils import nowdt
@@ -29,7 +30,7 @@ async def get_gpt_response(config: dict, retry: int = 0, **kwargs) -> dict[str,
         if error["retry"]:
             return await get_gpt_response(config, retry=retry + 1, **kwargs)
         if not error["error"]:
-            return await parse_tool_call(openai, config, resp, **kwargs)
+            return await parse_tool_call(config, resp, retry, **kwargs)
     except Exception as e:
         error = f"🤖{config['friendly_name']}请求失败, 重试次数: {retry + 1}/{GPT.MAX_RETRY + 1}\n{e}"
         logger.error(error)
@@ -60,7 +61,7 @@ async def parse_error(resp: dict, retry: int, **kwargs) -> dict:
     return error_result
 
 
-async def parse_tool_call(openai: AsyncOpenAI, config: dict, response: dict, **kwargs) -> dict[str, str]:
+async def parse_tool_call(config: dict, response: dict, retry: int = 0, **kwargs) -> dict[str, str]:
     """Parse tool call.
 
     Returns:
@@ -78,17 +79,12 @@ async def parse_tool_call(openai: AsyncOpenAI, config: dict, response: dict, **k
                 logger.debug(f"Online search tool call args: {args}")
                 await modify_progress(text=f"正在联网搜索信息:\n{args.get('query', '')}", force_update=True, **kwargs)
                 tool_result = await get_online_search_result(**args)
-                contexts = change_system_prompt(
+                config["completions"]["messages"] = change_system_prompt(
                     config["completions"]["messages"],
-                    f"于{nowdt(TZ):%Y-%m-%d}日进行了一次网络搜索, 请参考以下网络搜索结果进行回答. 如果结果中包含原始链接, 请以源域名附带上Markdown格式的链接. 例如: [nytimes.com](https://nytimes.com/some-page)\n网络搜索结果如下:\n{json.dumps(tool_result, ensure_ascii=False)}",
+                    f"于{nowdt(TZ):%Y-%m-%d}日进行了一次网络搜索, 请参考以下网络搜索结果进行回答. 如果结果中包含原始链接, 请以源域名附带上Markdown格式的链接. 例如: [exaplme.com](https://www.exaplme.com/some-page)\n网络搜索结果如下:\n{json.dumps(tool_result, ensure_ascii=False)}",
                 )
-                resp = await openai.chat.completions.create(
-                    model=config["completions"]["model"],
-                    messages=contexts,  # type: ignore
-                    temperature=config["completions"]["temperature"],
-                )
-                response = resp.model_dump()
-                logger.debug(response)
+                config["completions"] = remove_tool(config["completions"], "get_online_search_result")
+                return await get_gpt_response(config, retry, **kwargs)
         content = glom(response, "choices.0.message.content", default="") or ""
         reasoning = glom(response, "choices.0.message.reasoning", default="") or ""
         res = {"content": content.strip(), "reasoning": reasoning.strip(), "bot_msg_prefix": config["bot_msg_prefix"]}
src/llm/tool_call.py
@@ -1,5 +1,6 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
+from glom import glom
 from loguru import logger
 from openai import AsyncOpenAI, DefaultAsyncHttpxClient
 
@@ -21,8 +22,13 @@ async def get_online_search_result(query: str) -> list[dict]:
             stream=False,
             tools=tools,  # type: ignore
         )
-        res = response.choices[0].message.model_dump().get("tool_calls", [])
-        return next((x["search_result"] for x in res if x.get("search_result")), [])
+        tool_calls = glom(response.model_dump(), "choices.0.message.tool_calls", default=[]) or []
+        results = next((x["search_result"] for x in tool_calls if x.get("search_result")), [])
+        for x in results:
+            x.pop("icon", None)
+            x.pop("index", None)
+            x.pop("refer", None)
+        return results
     except Exception as e:
         logger.error(e)
         return []
src/llm/tool_scheme.py
@@ -1,6 +1,7 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
 
+
 from config import ENABLE, TZ
 from llm.utils import change_system_prompt
 from utils import nowdt
@@ -37,3 +38,21 @@ def get_tools(contexts: list[dict]) -> tuple[dict, list[dict]]:
         contexts = change_system_prompt(contexts, f"你是一个具备网络访问能力的智能助手. 在需要时可以访问互联网进行相关搜索获取信息以确保用户得到最新、准确的帮助。当前日期是 {nowdt(TZ):%Y-%m-%d}")
     tools_params = {"tools": tools, "tool_choice": "auto"} if tools else {}
     return tools_params, contexts
+
+
+def remove_tool(params: dict, tool_name: str) -> dict:
+    """Remove tool from contexts.
+
+    Returns: list[dict]
+    """
+    keep_tools = []
+    for tool in params.get("tools", []):
+        if tool.get("function", {}).get("name") != tool_name:
+            keep_tools.append(tool)
+
+    if keep_tools:
+        params["tools"] = keep_tools
+    else:
+        params.pop("tools", None)
+        params.pop("tool_choice", None)
+    return params