Commit 33b1d4a

benny-dou <60535774+benny-dou@users.noreply.github.com>
2025-02-11 09:34:58
refactor(gpt): use custom tool model to get function call result
1 parent d4cfff1
src/llm/gpt.py
@@ -9,7 +9,7 @@ from pyrogram.types import Message
 from config import DOWNLOAD_DIR, ENABLE, GPT, PREFIX, cache
 from llm.contexts import get_conversation_contexts, get_conversations
 from llm.models import get_model_config_with_contexts, get_model_type
-from llm.response import get_gpt_response
+from llm.response import merge_tools_response, send_to_gpt
 from llm.utils import llm_cleanup_files
 from messages.parser import parse_msg
 from messages.progress import modify_progress
@@ -79,7 +79,8 @@ async def gpt_response(client: Client, message: Message, **kwargs):
     if kwargs.get("show_progress"):
         res = await send2tg(client, message, texts=msg, **kwargs)
         kwargs["progress"] = res[0]
-    response = await get_gpt_response(config, **kwargs)
+    config = await merge_tools_response(config, **kwargs)
+    response = await send_to_gpt(config, **kwargs)
     media = []
     if reasoning := response.get("reasoning"):
         reasoning_model = f"推理模型: {response['reasoning_model']}\n\n" if response.get("reasoning_model") else ""
src/llm/models.py
@@ -4,7 +4,6 @@ from openai import DefaultAsyncHttpxClient
 from pyrogram.types import Message
 
 from config import GPT, PROXY
-from llm.tool_scheme import get_tools
 from llm.utils import BOT_TIPS
 from messages.parser import parse_msg
 
@@ -57,12 +56,9 @@ def get_model_config_with_contexts(model_type: str, contexts: list[dict]) -> dic
         "timeout": round(float(timeouts[model_type])),
         "http_client": DefaultAsyncHttpxClient(proxy=PROXY.GPT),
     }
-    # get tool call params, and adjust contexts if nessary
-    tools, contexts = get_tools(contexts)
 
     # params for `openai.chat.completions.create()`
     completions = {"model": model, "messages": contexts, "temperature": float(GPT.TEMPERATURE)}
-    completions |= tools
     completions |= openrouter_hook(client["base_url"])
 
     return {
@@ -73,12 +69,16 @@ def get_model_config_with_contexts(model_type: str, contexts: list[dict]) -> dic
     }
 
 
-def openrouter_hook(base_url: str) -> dict:
+def openrouter_hook(base_url: str, *, for_tools: bool = False) -> dict:
     """Add special parameters for OpenRouter."""
     if "openrouter" not in base_url:
         return {}
     params = {}
-    params |= {"extra_body": {"include_reasoning": True}}
-    if models := [x.strip() for x in GPT.FALLBACK_MODELS.split(",") if x.strip()]:
-        params["extra_body"]["models"] = models
+    if for_tools:
+        if models := [x.strip() for x in GPT.FALLBACK_TOOLS_MODELS.split(",") if x.strip()]:
+            params |= {"extra_body": {"models": models}}
+    else:
+        params |= {"extra_body": {"include_reasoning": True}}
+        if models := [x.strip() for x in GPT.FALLBACK_MODELS.split(",") if x.strip()]:
+            params["extra_body"]["models"] = models
     return params
src/llm/README.md
@@ -0,0 +1,13 @@
+# GPT调用流程
+
+程序主入口为 `llm/gpt.py` 的 `gpt_response` 函数。
+
+接到消息后, 首先解析出本消息的所有回复消息组成历史上下文,然后根据消息内容判断调用哪种类型的GPT。(文本 or 图片)
+
+目前我们使用OpenRouter接口站, 主model为 `deepseek-r1`, 备用model为 `gpt-4o`。
+
+由于`deepseek-r1` 不支持 `function call` 功能,为了联网搜索最新消息,所以我们的调用流程分为两个阶段。
+
+1. 第一阶段, 将附带`function call`的原始prompt发送给一个支持`function call`的模型 (TOOL_MODEL), 此模型会返回是否需要调用`get_online_search_result`函数以及`query`内容。我们并不关心此模型返回的`content`, 只关心是否调用`get_online_search_result`函数以及`query`内容。TOOL_MODEL模型只需要速度快且价格便宜。
+
+2. 第二阶段, 根据TOOL_MODEL的结果, 获取联网搜索结果, 将更新后的上下文和原始prompt发送给主模型 `deepseek-r1` 进行对话。
src/llm/response.py
@@ -1,6 +1,7 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
 import contextlib
+import copy
 import json
 
 from glom import glom
@@ -8,18 +9,61 @@ from loguru import logger
 from openai import AsyncOpenAI
 
 from config import ENABLE, GPT, TZ
-from llm.tool_call import get_online_search_result
-from llm.tool_scheme import remove_tool
-from llm.utils import BOT_TIPS, change_system_prompt
+from llm.models import openrouter_hook
+from llm.tools import add_tools, get_online_search_result
+from llm.utils import change_system_prompt
 from messages.progress import modify_progress
 from utils import nowdt
 
 
-async def get_gpt_response(config: dict, retry: int = 0, **kwargs) -> dict[str, str]:
-    """Get GPT response for text model.
+async def merge_tools_response(config: dict, **kwargs) -> dict:
+    """Use tool model to get function call result."""
+    if not GPT.TOOLS_API_KEY:
+        return config
+    # tool model should be fast and cheap
+    completions = {
+        "model": GPT.TOOLS_MODEL,
+        "temperature": 0,
+        "max_tokens": 1024,
+        "messages": copy.deepcopy(config["completions"]["messages"]),
+    }
+    completions |= openrouter_hook(GPT.TOOLS_BASE_URL, for_tools=True)
+    tools_config = {
+        "friendly_name": config["friendly_name"],
+        "bot_msg_prefix": config["bot_msg_prefix"],
+        "client": {"base_url": GPT.TOOLS_BASE_URL, "api_key": GPT.TOOLS_API_KEY, "http_client": config["client"]["http_client"]},
+        "completions": add_tools(completions),
+    }
+    try:
+        response = await send_to_gpt(tools_config, retry=0)
+        tool_call = glom(response, "choices.0.message.tool_calls.0", default={})
+        if not tool_call or glom(tool_call, "function.name", default="") != "get_online_search_result":
+            return config
+        args = json.loads(glom(tool_call, "function.arguments", default="{}"))
+        logger.debug(f"Online search tool call args: {args}")
+        await modify_progress(text=f"正在联网搜索信息:\n{args.get('query', '')}", force_update=True, **kwargs)
+        if tool_result := await get_online_search_result(**args):
+            config["completions"]["messages"] = change_system_prompt(
+                config["completions"]["messages"],
+                f"于{nowdt(TZ):%Y-%m-%d}日进行了一次网络搜索, 请参考以下网络搜索结果进行回答. 如果结果中包含原始链接, 请以源域名附带上Markdown格式的链接. 例如: [exaplme.com](https://www.exaplme.com/some-page)\n网络搜索结果如下:\n{json.dumps(tool_result, ensure_ascii=False)}\n请再次注意, 当前的现实日期就是{nowdt(TZ):%Y-%m-%d}, 以上搜索结果并非预测中的结果",
+            )
+            return config
+    except Exception as e:
+        logger.error(f"Tools_response failed: {e}")
+    return config
+
+
+async def send_to_gpt(config: dict, retry: int = 0, **kwargs) -> dict[str, str]:
+    """Get GPT response.
+
+    # See `llm/README.md` for more details.
+
+    Args:
+        config: dict, contains model configuration
+        retry: int, number of retries
 
     Returns:
-        {"content": str, "reasoning": str}
+        {"content": str, "reasoning": str, "reasoning_model": str, "bot_msg_prefix": str}
     """
     try:
         openai = AsyncOpenAI(**config["client"])
@@ -28,16 +72,16 @@ async def get_gpt_response(config: dict, retry: int = 0, **kwargs) -> dict[str,
         resp = resp.model_dump()
         error = await parse_error(resp, retry, **kwargs)
         if error["retry"]:
-            return await get_gpt_response(config, retry=retry + 1, **kwargs)
+            return await send_to_gpt(config, retry=retry + 1, **kwargs)
         if not error["error"]:
-            return await parse_tool_call(config, resp, retry, **kwargs)
+            return await parse_response(config, resp)
     except Exception as e:
         error = f"🤖{config['friendly_name']}请求失败, 重试次数: {retry + 1}/{GPT.MAX_RETRY + 1}\n{e}"
         logger.error(error)
         await modify_progress(text=error, force_update=True, **kwargs)
         if retry < GPT.MAX_RETRY:
-            return await get_gpt_response(config, retry=retry + 1, **kwargs)
-    return {"content": "", "reasoning": ""}
+            return await send_to_gpt(config, retry=retry + 1, **kwargs)
+    return {"content": "", "reasoning": "", "reasoning_model": "", "bot_msg_prefix": ""}
 
 
 async def parse_error(resp: dict, retry: int, **kwargs) -> dict:
@@ -49,10 +93,15 @@ async def parse_error(resp: dict, retry: int, **kwargs) -> dict:
     error_result = {"error": False, "retry": False}
     error_code = glom(resp, "error.code", default=0)
     error_msg = ""
+    content = ""
+    tool_call = {}
     with contextlib.suppress(Exception):
-        metadata = glom(resp, "error.metadata.raw", default={})
+        metadata = glom(resp, "error.metadata.raw", default="{}")
         error_msg = glom(json.loads(metadata), "error.message", default="")
-    if error_code != 0:
+        choice = glom(resp, "choices.0", default={})
+        content = glom(choice, "message.content", default="") or ""
+        tool_call = glom(choice, "message.tool_calls.0", default={})
+    if error_code != 0 or not (content or tool_call):
         logger.warning(resp)
         error_result["error"] = True
         await modify_progress(text=f"[{error_code}] {error_msg}\n重试次数: {retry + 1}/{GPT.MAX_RETRY + 1}", force_update=True, **kwargs)
@@ -61,43 +110,29 @@ async def parse_error(resp: dict, retry: int, **kwargs) -> dict:
     return error_result
 
 
-async def parse_tool_call(config: dict, response: dict, retry: int = 0, **kwargs) -> dict[str, str]:
-    """Parse tool call.
+async def parse_response(config: dict, response: dict) -> dict[str, str]:
+    """Parse GPT response.
 
     Returns:
-        {"content": str, "reasoning": str}
+        {"content": str, "reasoning": str, "reasoning_model": str, "bot_msg_prefix": str}
     """
-    choice = glom(response, "choices.0", default={})
-    if choice.get("finish_reason", "") not in ["stop", "tool_calls", "length"]:
-        logger.warning(response)
-        raise  # noqa: PLE0704
     logger.debug(response)
+    choice = glom(response, "choices.0", default={})
+    if glom(choice, "message.tool_calls.0", default={}):  # this is a function call response
+        return response | {"content": "", "reasoning": "", "reasoning_model": "", "bot_msg_prefix": config["bot_msg_prefix"]}
     try:
-        if tool_call := glom(choice, "message.tool_calls.0", default={}):
-            args = json.loads(glom(tool_call, "function.arguments", default="{}"))
-            tool_result = {}
-            if glom(tool_call, "function.name", default="") == "get_online_search_result":
-                logger.debug(f"Online search tool call args: {args}")
-                await modify_progress(text=f"正在联网搜索信息:\n{args.get('query', '')}", force_update=True, **kwargs)
-                tool_result = await get_online_search_result(**args)
-                config["completions"]["messages"] = change_system_prompt(
-                    config["completions"]["messages"],
-                    f"于{nowdt(TZ):%Y-%m-%d}日进行了一次网络搜索, 请参考以下网络搜索结果进行回答. 如果结果中包含原始链接, 请以源域名附带上Markdown格式的链接. 例如: [exaplme.com](https://www.exaplme.com/some-page)\n网络搜索结果如下:\n{json.dumps(tool_result, ensure_ascii=False)}\n请再次注意, 当前的现实日期就是{nowdt(TZ):%Y-%m-%d}, 以上搜索结果并非预测中的结果",
-                )
-                config["completions"] = remove_tool(config["completions"], "get_online_search_result")
-                return await get_gpt_response(config, retry, **kwargs)
-        content = glom(response, "choices.0.message.content", default="") or ""
-        reasoning = glom(response, "choices.0.message.reasoning", default="") or ""
+        content = glom(choice, "message.content", default="") or ""
+        reasoning = glom(choice, "message.reasoning", default="") or ""
         primary_model = glom(config, "completions.model", default="") or ""
         used_model = glom(response, "model", default="") or ""
-        res = {"content": content.strip(), "reasoning": reasoning.strip(), "reasoning_model": used_model, "bot_msg_prefix": config["bot_msg_prefix"]}
+        response = {"content": content.strip(), "reasoning": reasoning.strip(), "reasoning_model": used_model, "bot_msg_prefix": config["bot_msg_prefix"]}
         if not (used_model in primary_model or primary_model in used_model):
             # do not use `!=` to compare. (deepseek/deepseek-r1:free != deepseek/deepseek-r1,  gpt-4o != gpt-4o-2024-07-18)
             used_model = used_model.split("/")[-1]
             logger.warning(f"Fallback model {primary_model} -> {used_model}")
             if ENABLE.GPT_WARN_FALLBACK:
-                res["bot_msg_prefix"] = res["bot_msg_prefix"].replace(f"({BOT_TIPS})", f"(发生回退: {used_model})\n({BOT_TIPS})")
+                response["bot_msg_prefix"] = response["bot_msg_prefix"].replace(config["friendly_name"], used_model)
     except Exception as e:
-        logger.error(f"GPT failed: {e}")
+        logger.error(f"Parse  GPT response failed: {e}")
         raise
-    return res
+    return response
src/llm/summary.py
@@ -11,7 +11,7 @@ from pyrogram.types import Message
 from config import ENABLE, GPT, MAX_MESSAGE_SUMMARY, PREFIX, PROXY, cache
 from llm.contexts import combine_consecutive_role_contexts, simplify_text_contents
 from llm.models import openrouter_hook
-from llm.response import get_gpt_response
+from llm.response import send_to_gpt
 from messages.chat_history import get_parsed_chat_history
 from messages.parser import parse_msg
 from messages.progress import modify_progress
@@ -89,7 +89,7 @@ async def ai_summary(client: Client, message: Message, **kwargs):
     if kwargs.get("show_progress"):
         res = await send2tg(client, message, texts=msg, **kwargs)
         kwargs["progress"] = res[0]
-    response = await get_gpt_response(config, **kwargs)
+    response = await send_to_gpt(config, **kwargs)
     if texts := response.get("content"):
         logger.debug(response)
         await send2tg(client, message, texts=texts.strip("`"), **kwargs)
src/llm/tool_scheme.py
@@ -1,11 +1,6 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
 
-
-from config import ENABLE, TZ
-from llm.utils import change_system_prompt
-from utils import nowdt
-
 ONLINE_SEARCH = {
     "type": "function",
     "function": {
@@ -20,36 +15,3 @@ ONLINE_SEARCH = {
         },
     },
 }
-
-
-def get_tools(contexts: list[dict]) -> tuple[dict, list[dict]]:
-    """Get tools for GPT.
-
-    Returns: (tools_params, contexts)
-        tools_params:
-            {
-                "tools": [{tool_1}, {tool_2}, ...],  # list of dict
-                "tool_choice": "auto",
-            }
-    """
-    tools = []
-    if ENABLE.GPT_ONLINE_SEARCH:
-        tools = [ONLINE_SEARCH]
-        contexts = change_system_prompt(contexts, f"你是一个具备网络访问能力的智能助手. 在需要时可以访问互联网进行相关搜索获取信息以确保用户得到最新、准确的帮助。当前日期是 {nowdt(TZ):%Y-%m-%d}")
-    tools_params = {"tools": tools, "tool_choice": "auto"} if tools else {}
-    return tools_params, contexts
-
-
-def remove_tool(params: dict, tool_name: str) -> dict:
-    """Remove tool from contexts.
-
-    Returns: list[dict]
-    """
-    keep_tools = [tool for tool in params.get("tools", []) if tool.get("function", {}).get("name") != tool_name]
-
-    if keep_tools:
-        params["tools"] = keep_tools
-    else:
-        params.pop("tools", None)
-        params.pop("tool_choice", None)
-    return params
src/llm/tool_call.py → src/llm/tools.py
@@ -6,11 +6,15 @@ from glom import glom
 from loguru import logger
 from openai import AsyncOpenAI, DefaultAsyncHttpxClient
 
-from config import GPT, PROXY, TOKEN
+from config import ENABLE, GPT, PROXY, TOKEN, TZ
+from llm.tool_scheme import ONLINE_SEARCH
+from llm.utils import change_system_prompt
 from networking import hx_req
+from utils import nowdt
 
 
 async def get_online_search_result(query: str) -> list[dict]:
+    results = []
     if GPT.PRIMARY_SEARCH_ENGINE == "google":
         results = await google_search(query)
         if not results:
@@ -19,10 +23,12 @@ async def get_online_search_result(query: str) -> list[dict]:
         results = await glm_search(query)
         if not results:
             return await google_search(query)
-    return []
+    return results
 
 
 async def google_search(query: str) -> list[dict]:
+    if not (TOKEN.GOOGLE_SEARCH_API_KEY and TOKEN.GOOGLE_SEARCH_CX):
+        return []
     try:
         url = f"https://www.googleapis.com/customsearch/v1?key={TOKEN.GOOGLE_SEARCH_API_KEY}&cx={TOKEN.GOOGLE_SEARCH_CX}&q={query}"
         response = await hx_req(url, proxy=PROXY.GOOGLE_SEARCH, check_keys=["items"])
@@ -40,6 +46,8 @@ async def google_search(query: str) -> list[dict]:
 
 
 async def glm_search(query: str) -> list[dict]:
+    if not (GPT.GLM_API_KEY and GPT.GLM_BASE_URL):
+        return []
     try:
         client = AsyncOpenAI(
             api_key=GPT.GLM_API_KEY,
@@ -65,3 +73,47 @@ async def glm_search(query: str) -> list[dict]:
     except Exception as e:
         logger.error(e)
     return []
+
+
+def add_tools(params: dict) -> dict:
+    """Add tools for GPT.
+
+    Args:
+        params: dict, params for `openai.chat.completions.create()`
+
+    Returns:
+        tools_params:
+            {
+                "tools": [{tool_1}, {tool_2}, ...],  # list of dict
+                "tool_choice": "auto",
+            }
+    """
+    tools = []
+    if ENABLE.GPT_ONLINE_SEARCH:
+        tools = [ONLINE_SEARCH]
+        system_prompt = f"你是一个具备网络访问能力的智能助手. 在需要时可以访问互联网进行相关搜索获取信息以确保用户得到最新、准确的帮助。当前日期是 {nowdt(TZ):%Y-%m-%d}"
+        params["messages"] = change_system_prompt(params["messages"], system_prompt)
+    if tools:
+        params["tools"] = tools
+        params["tool_choice"] = "auto"
+    return params
+
+
+def remove_tool(params: dict, tool_name: str) -> dict:
+    """Remove tool from contexts.
+
+    Returns: list[dict]
+    """
+    if tool_name.upper() == "ALL":
+        params.pop("tools", None)
+        params.pop("tool_choice", None)
+        return params
+
+    keep_tools = [tool for tool in params.get("tools", []) if tool.get("function", {}).get("name") != tool_name]
+
+    if keep_tools:
+        params["tools"] = keep_tools
+    else:
+        params.pop("tools", None)
+        params.pop("tool_choice", None)
+    return params
src/config.py
@@ -135,11 +135,14 @@ class COOKIE:  # See: https://github.com/easychen/CookieCloud
     CLOUD_PASS = os.getenv("COOKIE_CLOUD_PASS", "")
 
 
-class GPT:
+class GPT:  # see `llm/README.md`
     TEXT_MODEL = os.getenv("GPT_TEXT_MODEL", "gpt-4o")
     IMAGE_MODEL = os.getenv("GPT_IMAGE_MODEL", "gpt-4o")
     VIDEO_MODEL = os.getenv("GPT_VIDEO_MODEL", "glm-4v-plus")
-    FALLBACK_MODELS = os.getenv("GPT_FALLBACK_MODELS", "")  # comma separated fallback models for OpenRouter (e.g. openai/gpt-4o,anthropic/claude-3.5-sonnet)
+    TOOLS_MODEL = os.getenv("GPT_TOOLS_MODEL", "gpt-4o-mini")  # this model should be fast and cheap
+    # comma separated fallback models for OpenRouter (e.g. openai/gpt-4o,anthropic/claude-3.5-sonnet)
+    FALLBACK_MODELS = os.getenv("GPT_FALLBACK_MODELS", "")
+    FALLBACK_TOOLS_MODELS = os.getenv("GPT_FALLBACK_TOOLS_MODELS", "")  # comma separated fallback tool models for OpenRouter
     TEXT_MODEL_NAME = os.getenv("GPT_TEXT_MODEL_NAME", "gpt-4o")  # custom name
     IMAGE_MODEL_NAME = os.getenv("GPT_IMAGE_MODEL_NAME", "gpt-4o")
     VIDEO_MODEL_NAME = os.getenv("GPT_VIDEO_MODEL_NAME", "glm-4v-plus")
@@ -160,6 +163,8 @@ class GPT:
     IMAGE_BASE_URL = os.getenv("GPT_IMAGE_BASE_URL", "https://api.openai.com/v1")
     VIDEO_API_KEY = os.getenv("GPT_VIDEO_API_KEY", "")
     VIDEO_BASE_URL = os.getenv("GPT_VIDEO_BASE_URL", "https://open.bigmodel.cn/api/paas/v4")
+    TOOLS_API_KEY = os.getenv("GPT_TOOLS_API_KEY", "")
+    TOOLS_BASE_URL = os.getenv("GPT_TOOLS_BASE_URL", "https://api.openai.com/v1")
     TOKEN_ENCODING = os.getenv("GPT_TOKEN_ENCODING", "o200k_base")  # https://github.com/openai/tiktoken
     MAX_RETRY = int(os.getenv("GPT_MAX_RETRY", "2"))