Commit af05ad8

benny-dou <60535774+benny-dou@users.noreply.github.com>
2025-02-03 03:43:50
feat(llm): add online search tool call
1 parent 3ba13c8
src/llm/response.py
@@ -1,15 +1,21 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
+import json
+
 from loguru import logger
 from openai import AsyncOpenAI, DefaultAsyncHttpxClient
+from openai.types.chat import ChatCompletion
 
-from config import PROXY
+from config import PROXY, TZ
+from llm.tool_call import get_online_search_result
+from llm.tool_scheme import ONLINE_SEARCH
+from llm.utils import change_system_prompt
 from messages.progress import modify_progress
+from utils import nowdt
 
 
 async def get_gpt_response(config: dict, contexts: list[dict], **kwargs) -> str:
     """Get GPT response for text model."""
-    response = f"🤖{config['friendly_name']}请求失败, 请稍后重试."
     logger.trace(contexts)
     try:
         openai = AsyncOpenAI(
@@ -18,16 +24,53 @@ async def get_gpt_response(config: dict, contexts: list[dict], **kwargs) -> str:
             timeout=config["timeout"],
             http_client=DefaultAsyncHttpxClient(proxy=PROXY.GPT),
         )
+        contexts = change_system_prompt(contexts, f"你是一个具备网络访问能力的智能助手. 在需要时可以访问互联网进行相关搜索获取信息以确保用户得到最新、准确的帮助。当前日期是 {nowdt(TZ):%Y-%m-%d}")
         resp = await openai.chat.completions.create(
             model=config["model"],
             messages=contexts,  # type: ignore
             temperature=config["temperature"],
+            tools=ONLINE_SEARCH,  # type: ignore
+            tool_choice="auto",
         )
-        if choices := resp.model_dump().get("choices", []):
-            response = choices[0].get("message", {}).get("content")
+        return await parse_tool_call(openai, config, resp, contexts, **kwargs)
     except Exception as e:
         error = f"🤖{config['friendly_name']}请求失败, 请稍后重试.\n{e}"
         logger.error(f"GPT request failed: {e}")
         await modify_progress(text=error, force_update=True, **kwargs)
         return error
-    return response
+
+
+async def parse_tool_call(
+    openai: AsyncOpenAI,
+    config: dict,
+    response: ChatCompletion,
+    contexts: list[dict],
+    **kwargs,
+) -> str:
+    if not (choices := response.model_dump().get("choices", [])):
+        return ""
+    try:
+        choice = choices[0]
+        if choice.get("message", {}).get("tool_calls", []):
+            tool_call = choice["message"]["tool_calls"][0]
+            args = json.loads(tool_call.get("function", {}).get("arguments", ""))
+            tool_result = {}
+            if tool_call.get("function", {}).get("name", "") == "get_online_search_result":
+                logger.debug(f"Online search tool call args: {args}")
+                await modify_progress(text=f"正在联网搜索信息:\n{args.get('query', '')}", force_update=True, **kwargs)
+                tool_result = await get_online_search_result(**args)
+            contexts = change_system_prompt(
+                contexts,
+                f"于{nowdt(TZ):%Y-%m-%d}日进行了一次网络搜索, 请参考以下网络搜索结果进行回答. 如果结果中包含原始链接, 请以源域名附带上Markdown格式的链接. 例如: [nytimes.com](https://nytimes.com/some-page)\n网络搜索结果如下:\n{json.dumps(tool_result, ensure_ascii=False)}",
+            )
+            response = await openai.chat.completions.create(
+                model=config["model"],
+                messages=contexts,  # type: ignore
+                temperature=config["temperature"],
+            )
+        if choices := response.model_dump().get("choices", []):
+            return choices[0].get("message", {}).get("content", "")
+    except Exception as e:
+        logger.error(f"Parse tool call failed: {e}")
+        raise
+    return ""
src/llm/tool_call.py
@@ -0,0 +1,28 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+from loguru import logger
+from openai import AsyncOpenAI, DefaultAsyncHttpxClient
+
+from config import GPT, PROXY
+
+
+async def get_online_search_result(query: str) -> list[dict]:
+    try:
+        client = AsyncOpenAI(
+            api_key=GPT.SEARCH_API_KEY,
+            base_url=GPT.SEARCH_BASE_URL,
+            http_client=DefaultAsyncHttpxClient(proxy=PROXY.GPT),
+        )
+        tools = [{"type": "web_search", "web_search": {"enable": True, "search_query": query, "search_result": True}}]
+        response = await client.chat.completions.create(
+            model="web-search-pro",
+            messages=[{"role": "user", "content": query}],
+            temperature=0,
+            stream=False,
+            tools=tools,  # type: ignore
+        )
+        res = response.choices[0].message.model_dump().get("tool_calls", [])
+        return next((x["search_result"] for x in res if x.get("search_result")), [])
+    except Exception as e:
+        logger.error(e)
+        return []
src/llm/tool_scheme.py
@@ -0,0 +1,16 @@
+ONLINE_SEARCH = [
+    {
+        "type": "function",
+        "function": {
+            "name": "get_online_search_result",
+            "description": "",
+            "parameters": {
+                "type": "object",
+                "properties": {
+                    "query": {"description": "搜索关键词", "type": "string"},
+                },
+                "required": ["query"],
+            },
+        },
+    },
+]
src/llm/utils.py
@@ -80,3 +80,13 @@ def count_tokens(string: str, encoding_name: str | None = None) -> int:
     except Exception as e:
         logger.error(f"Error in count_tokens: {e}")
         return 0
+
+
+def change_system_prompt(context: list[dict], prompt: str) -> list[dict]:
+    if not context:
+        return [{"role": "system", "content": prompt}]
+    if context[0].get("role") == "system":
+        context[0]["content"] = prompt
+        return context
+    context.insert(0, {"role": "system", "content": prompt})
+    return context
src/config.py
@@ -125,6 +125,9 @@ class GPT:
     TEXT_MODEL_NAME = os.getenv("GPT_TEXT_MODEL_NAME", "gpt-4o")  # custom name
     IMAGE_MODEL_NAME = os.getenv("GPT_IMAGE_MODEL_NAME", "gpt-4o")
     VIDEO_MODEL_NAME = os.getenv("GPT_VIDEO_MODEL_NAME", "glm-4v-plus")
+    SEARCH_API_KEY = os.getenv("GPT_SEARCH_API_KEY", "")  # online search (currently, we use GLM)
+    SEARCH_BASE_URL = os.getenv("GPT_SEARCH_BASE_URL", "https://open.bigmodel.cn/api/paas/v4")
+    SEARCH_MODEL = os.getenv("GPT_SEARCH_MODEL", "web-search-pro")
     TEXT_TIMEOUT = os.getenv("GPT_TEXT_TIMEOUT", "15")
     IMAGE_TIMEOUT = os.getenv("GPT_IMAGE_TIMEOUT", "30")
     VIDEO_TIMEOUT = os.getenv("GPT_VIDEO_TIMEOUT", "30")