Commit ad69ddb

benny-dou <60535774+benny-dou@users.noreply.github.com>
2025-02-07 10:17:09
feat: add GPT fallback models for OpenRouter
1 parent e902978
src/llm/models.py
@@ -60,3 +60,11 @@ def get_model_with_contexts(model_type: str, contexts: list[dict]) -> tuple[dict
     contexts = simplify_text_contents(contexts)
     logger.trace(config)
     return config, contexts
+
+
+def get_fallback_models(base_url: str) -> dict:
+    """Get fallback models for OpenRouter."""
+    models = []
+    if "openrouter.ai" in base_url:
+        models = [x.strip() for x in GPT.FALLBACK_MODELS.split(",") if x.strip()]
+    return {"extra_body": {"models": models}} if models else {}
src/llm/response.py
@@ -7,8 +7,9 @@ from openai import AsyncOpenAI, DefaultAsyncHttpxClient
 from openai.types.chat import ChatCompletion
 
 from config import PROXY, TZ
+from llm.models import get_fallback_models
 from llm.tool_call import get_online_search_result
-from llm.tool_scheme import ONLINE_SEARCH
+from llm.tool_scheme import get_tools
 from llm.utils import change_system_prompt
 from messages.progress import modify_progress
 from utils import nowdt
@@ -16,7 +17,6 @@ from utils import nowdt
 
 async def get_gpt_response(config: dict, contexts: list[dict], **kwargs) -> str:
     """Get GPT response for text model."""
-    logger.trace(contexts)
     try:
         openai = AsyncOpenAI(
             api_key=config["key"],
@@ -24,14 +24,11 @@ async def get_gpt_response(config: dict, contexts: list[dict], **kwargs) -> str:
             timeout=config["timeout"],
             http_client=DefaultAsyncHttpxClient(proxy=PROXY.GPT),
         )
-        contexts = change_system_prompt(contexts, f"你是一个具备网络访问能力的智能助手. 在需要时可以访问互联网进行相关搜索获取信息以确保用户得到最新、准确的帮助。当前日期是 {nowdt(TZ):%Y-%m-%d}")
-        resp = await openai.chat.completions.create(
-            model=config["model"],
-            messages=contexts,  # type: ignore
-            temperature=config["temperature"],
-            tools=ONLINE_SEARCH,  # type: ignore
-            tool_choice="auto",
-        )
+        params = {"model": config["model"], "messages": contexts, "temperature": config["temperature"]}
+        params |= get_tools()
+        params |= get_fallback_models(config["base_url"])
+        logger.trace(params)
+        resp = await openai.chat.completions.create(**params)
         return await parse_tool_call(openai, config, resp, contexts, **kwargs)
     except Exception as e:
         error = f"🤖{config['friendly_name']}请求失败, 请稍后重试.\n{e}"
@@ -59,15 +56,15 @@ async def parse_tool_call(
                 logger.debug(f"Online search tool call args: {args}")
                 await modify_progress(text=f"正在联网搜索信息:\n{args.get('query', '')}", force_update=True, **kwargs)
                 tool_result = await get_online_search_result(**args)
-            contexts = change_system_prompt(
-                contexts,
-                f"于{nowdt(TZ):%Y-%m-%d}日进行了一次网络搜索, 请参考以下网络搜索结果进行回答. 如果结果中包含原始链接, 请以源域名附带上Markdown格式的链接. 例如: [nytimes.com](https://nytimes.com/some-page)\n网络搜索结果如下:\n{json.dumps(tool_result, ensure_ascii=False)}",
-            )
-            response = await openai.chat.completions.create(
-                model=config["model"],
-                messages=contexts,  # type: ignore
-                temperature=config["temperature"],
-            )
+                contexts = change_system_prompt(
+                    contexts,
+                    f"于{nowdt(TZ):%Y-%m-%d}日进行了一次网络搜索, 请参考以下网络搜索结果进行回答. 如果结果中包含原始链接, 请以源域名附带上Markdown格式的链接. 例如: [nytimes.com](https://nytimes.com/some-page)\n网络搜索结果如下:\n{json.dumps(tool_result, ensure_ascii=False)}",
+                )
+                response = await openai.chat.completions.create(
+                    model=config["model"],
+                    messages=contexts,  # type: ignore
+                    temperature=config["temperature"],
+                )
         if choices := response.model_dump().get("choices", []):
             return choices[0].get("message", {}).get("content", "")
     except Exception as e:
src/llm/tool_scheme.py
@@ -1,16 +1,33 @@
-ONLINE_SEARCH = [
-    {
-        "type": "function",
-        "function": {
-            "name": "get_online_search_result",
-            "description": "",
-            "parameters": {
-                "type": "object",
-                "properties": {
-                    "query": {"description": "搜索关键词", "type": "string"},
-                },
-                "required": ["query"],
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+from config import ENABLE
+
+ONLINE_SEARCH = {
+    "type": "function",
+    "function": {
+        "name": "get_online_search_result",
+        "description": "",
+        "parameters": {
+            "type": "object",
+            "properties": {
+                "query": {"description": "联网搜索关键词", "type": "string"},
             },
+            "required": ["query"],
         },
     },
-]
+}
+
+
+def get_tools() -> dict:
+    """Get tools for GPT.
+
+    Returns: {
+                "tools": [{tool_1}, {tool_2}, ...],  # list of dict
+                "tool_choice": "auto",
+            }
+    """
+    tools = []
+    if ENABLE.GPT_ONLINE_SEARCH:
+        tools = [ONLINE_SEARCH]
+    return {"tools": tools, "tool_choice": "auto"} if tools else {}
src/config.py
@@ -32,6 +32,7 @@ class ENABLE:
     CRONTAB = os.getenv("ENABLE_CRONTAB", "1").lower() in ["1", "y", "yes", "t", "true", "on"]
     DOUYIN = os.getenv("ENABLE_DOUYIN", "1").lower() in ["1", "y", "yes", "t", "true", "on"]
     GPT = os.getenv("ENABLE_GPT", "1").lower() in ["1", "y", "yes", "t", "true", "on"]
+    GPT_ONLINE_SEARCH = os.getenv("ENABLE_GPT_ONLINE_SEARCH", "1").lower() in ["1", "y", "yes", "t", "true", "on"]
     INSTAGRAM = os.getenv("ENABLE_INSTAGRAM", "1").lower() in ["1", "y", "yes", "t", "true", "on"]
     OCR = os.getenv("ENABLE_OCR", "1").lower() in ["1", "y", "yes", "t", "true", "on"]
     PRICE = os.getenv("ENABLE_PRICE", "1").lower() in ["1", "y", "yes", "t", "true", "on"]
@@ -132,6 +133,7 @@ class GPT:
     TEXT_MODEL = os.getenv("GPT_TEXT_MODEL", "gpt-4o")
     IMAGE_MODEL = os.getenv("GPT_IMAGE_MODEL", "gpt-4o")
     VIDEO_MODEL = os.getenv("GPT_VIDEO_MODEL", "glm-4v-plus")
+    FALLBACK_MODELS = os.getenv("GPT_FALLBACK_MODELS", "")  # comma separated fallback models for OpenRouter (e.g. openai/gpt-4o,anthropic/claude-3.5-sonnet)
     TEXT_MODEL_NAME = os.getenv("GPT_TEXT_MODEL_NAME", "gpt-4o")  # custom name
     IMAGE_MODEL_NAME = os.getenv("GPT_IMAGE_MODEL_NAME", "gpt-4o")
     VIDEO_MODEL_NAME = os.getenv("GPT_VIDEO_MODEL_NAME", "glm-4v-plus")