Commit 3e2fb8b

benny-dou <60535774+benny-dou@users.noreply.github.com>
2025-02-10 09:46:58
feat(gpt): add reasoning response and retry function
1 parent f319436
src/llm/contexts.py
@@ -96,7 +96,7 @@ async def single_context(client: Client, message: Message) -> dict:
                     media.append({"type": "image_url", "image_url": {"url": f"data:image/{ext};base64,{b64}"}})
                 # elif info["mtype"] == "video":
                 #     media.append({"type": "video_url", "video_url": {"url": b64}})
-                elif info["mtype"] == "document" and info["mime_type"] == "text/plain":
+                elif info["mtype"] == "document" and info["mime_type"] == "text/plain" and not info["file_name"].startswith("GPT-Reasoning"):  # skip GPT reasoning
                     media.append({"type": "text", "text": res.getvalue().decode("utf-8")})
                 else:
                     logger.warning(f"Unsupported message type: {info['mtype']}")
@@ -107,7 +107,7 @@ async def single_context(client: Client, message: Message) -> dict:
                     media.append({"type": "image_url", "image_url": {"url": f"{GPT.MEDIA_SERVER}/{Path(path).name}"}})
                 # elif info["mtype"] == "video":
                 #     media.append({"type": "video_url", "video_url": {"url": f"{GPT.MEDIA_SERVER}/{Path(path).name}"}})
-                elif info["mtype"] == "document" and info["mime_type"] == "text/plain":
+                elif info["mtype"] == "document" and info["mime_type"] == "text/plain" and not info["file_name"].startswith("GPT-Reasoning"):  # skip GPT reasoning
                     media.append({"type": "text", "text": Path(path).read_text()})
                     Path(path).unlink(missing_ok=True)
                 else:
src/llm/gpt.py
@@ -1,10 +1,12 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
+from pathlib import Path
+
 from loguru import logger
 from pyrogram.client import Client
 from pyrogram.types import Message
 
-from config import ENABLE, GPT, PREFIX, cache
+from config import DOWNLOAD_DIR, ENABLE, GPT, PREFIX, cache
 from llm.contexts import get_conversation_contexts, get_conversations
 from llm.models import get_model_config_with_contexts, get_model_type
 from llm.response import get_gpt_response
@@ -13,6 +15,7 @@ from messages.parser import parse_msg
 from messages.progress import modify_progress
 from messages.sender import send2tg
 from messages.utils import equal_prefix, startswith_prefix
+from utils import rand_number, save_txt
 
 HELP = f"""🤖**GPT对话**
 当前模型:
@@ -77,8 +80,12 @@ async def gpt_response(client: Client, message: Message, **kwargs):
         res = await send2tg(client, message, texts=msg, **kwargs)
         kwargs["progress"] = res[0]
     response = await get_gpt_response(config, **kwargs)
+    media = [{"document": save_txt(reasoning, f"{DOWNLOAD_DIR}/GPT-Reasoning-{rand_number()}.txt")}] if (reasoning := response.get("reasoning")) else []
+    if content := response.get("content"):
+        texts = f"{config['bot_msg_prefix']}\n\n{content}"
+        logger.debug(texts)
+        await send2tg(client, message, texts=texts, media=media, **kwargs)
+        await modify_progress(del_status=True, **kwargs)
     llm_cleanup_files(config["completions"]["messages"])
-    texts = f"{config['bot_msg_prefix']}\n\n{response}"
-    logger.debug(texts)
-    await send2tg(client, message, texts=texts, **kwargs)
-    await modify_progress(del_status=True, **kwargs)
+    if media:
+        Path(media[0]["document"]).unlink(missing_ok=True)
src/llm/models.py
@@ -79,9 +79,10 @@ def get_model_config_with_contexts(model_type: str, contexts: list[dict]) -> dic
 
 def openrouter_hook(base_url: str) -> dict:
     """Add special parameters for OpenRouter."""
-    if "openrouter.ai" not in base_url:
+    if "openrouter" not in base_url:
         return {}
     params = {}
+    params |= {"extra_body": {"include_reasoning": True}}
     if models := [x.strip() for x in GPT.FALLBACK_MODELS.split(",") if x.strip()]:
         params |= {"extra_body": {"models": models}}
     return params
src/llm/response.py
@@ -1,36 +1,74 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
+import contextlib
 import json
 
 from glom import glom
 from loguru import logger
 from openai import AsyncOpenAI
 
-from config import TZ
+from config import GPT, TZ
 from llm.tool_call import get_online_search_result
 from llm.utils import change_system_prompt
 from messages.progress import modify_progress
 from utils import nowdt
 
 
-async def get_gpt_response(config: dict, **kwargs) -> str:
-    """Get GPT response for text model."""
async def get_gpt_response(config: dict, retry: int = 0, **kwargs) -> dict[str, str]:
    """Get GPT response for text model, retrying up to GPT.MAX_RETRY extra attempts.

    Args:
        config: model config with "client" (AsyncOpenAI kwargs), "completions"
            (chat.completions.create kwargs) and "friendly_name" entries.
        retry: current 0-based attempt number; callers normally omit it.
        **kwargs: forwarded to parse_error / parse_tool_call / modify_progress.

    Returns:
        {"content": str, "reasoning": str}; both empty when every attempt failed
        or the API returned a non-retryable error payload.
    """
    # Iterative retry loop instead of self-recursion: one retry counter, no
    # stacked frames, and no reuse of one name for both an error dict and an
    # error message string.
    while True:
        try:
            openai = AsyncOpenAI(**config["client"])
            logger.trace(config)
            resp = await openai.chat.completions.create(**config["completions"])
            resp = resp.model_dump()
            error = await parse_error(resp, retry, **kwargs)
            if error["retry"]:
                retry += 1
                continue
            if not error["error"]:
                return await parse_tool_call(openai, config, resp, **kwargs)
            # Non-retryable API error payload: give up with an empty result.
        except Exception as e:
            msg = f"🤖{config['friendly_name']}请求失败, 重试次数: {retry + 1}/{GPT.MAX_RETRY + 1}\n{e}"
            logger.error(msg)
            await modify_progress(text=msg, force_update=True, **kwargs)
            if retry < GPT.MAX_RETRY:
                retry += 1
                continue
        return {"content": "", "reasoning": ""}
+
+
async def parse_error(resp: dict, retry: int, **kwargs) -> dict:
    """Parse an error payload out of a GPT completion response.

    Args:
        resp: raw response dict (model_dump of the completion).
        retry: current 0-based attempt number, used for the progress message.
        **kwargs: forwarded to modify_progress.

    Returns:
        {"error": bool, "retry": bool} — "retry" is only True while attempts
        remain (retry < GPT.MAX_RETRY).
    """
    error_result = {"error": False, "retry": False}
    error_code = glom(resp, "error.code", default=0)
    error_msg = ""
    # error.metadata.raw carries the upstream provider error as a JSON string;
    # default to "{}" (a valid JSON document) — a dict default would make
    # json.loads raise TypeError on every response without metadata.
    with contextlib.suppress(Exception):
        metadata = glom(resp, "error.metadata.raw", default="{}")
        error_msg = glom(json.loads(metadata), "error.message", default="")
    if error_code != 0:
        logger.warning(resp)
        error_result["error"] = True
        await modify_progress(text=f"[{error_code}] {error_msg}\n重试次数: {retry + 1}/{GPT.MAX_RETRY + 1}", force_update=True, **kwargs)
        if retry < GPT.MAX_RETRY:
            error_result["retry"] = True
    return error_result
+
 
+async def parse_tool_call(openai: AsyncOpenAI, config: dict, response: dict, **kwargs) -> dict[str, str]:
+    """Parse tool call.
 
-async def parse_tool_call(openai: AsyncOpenAI, config: dict, response: dict, **kwargs) -> str:
+    Returns:
+        {"content": str, "reasoning": str}
+    """
     choice = glom(response, "choices.0", default=[])
     if not choice:
-        return ""
+        return {"content": "", "reasoning": ""}
     logger.debug(response)
     try:
         if tool_call := glom(choice, "message.tool_calls.0", default={}):
@@ -51,7 +89,9 @@ async def parse_tool_call(openai: AsyncOpenAI, config: dict, response: dict, **k
                 )
                 response = resp.model_dump()
                 logger.debug(response)
-        return glom(response, "choices.0.message.content", default="")
+        content = glom(response, "choices.0.message.content", default="") or ""
+        reasoning = glom(response, "choices.0.message.reasoning", default="") or ""
     except Exception as e:
         logger.error(f"GPT failed: {e}")
         raise
+    return {"content": content.strip(), "reasoning": reasoning.strip()}
src/llm/summary.py
@@ -90,9 +90,10 @@ async def ai_summary(client: Client, message: Message, **kwargs):
         res = await send2tg(client, message, texts=msg, **kwargs)
         kwargs["progress"] = res[0]
     response = await get_gpt_response(config, **kwargs)
-    logger.debug(response)
-    await send2tg(client, message, texts=response.strip("`"), **kwargs)
-    await modify_progress(del_status=True, **kwargs)
+    if texts := response.get("content"):
+        logger.debug(response)
+        await send2tg(client, message, texts=texts.strip("`"), **kwargs)
+        await modify_progress(del_status=True, **kwargs)
 
 
 def get_summay_model(contexts: list[dict]) -> dict:
src/config.py
@@ -142,9 +142,9 @@ class GPT:
     SEARCH_API_KEY = os.getenv("GPT_SEARCH_API_KEY", "")  # online search (currently, we use GLM)
     SEARCH_BASE_URL = os.getenv("GPT_SEARCH_BASE_URL", "https://open.bigmodel.cn/api/paas/v4")
     SEARCH_MODEL = os.getenv("GPT_SEARCH_MODEL", "web-search-pro")
-    TEXT_TIMEOUT = os.getenv("GPT_TEXT_TIMEOUT", "15")
-    IMAGE_TIMEOUT = os.getenv("GPT_IMAGE_TIMEOUT", "30")
-    VIDEO_TIMEOUT = os.getenv("GPT_VIDEO_TIMEOUT", "30")
+    TEXT_TIMEOUT = os.getenv("GPT_TEXT_TIMEOUT", "120")
+    IMAGE_TIMEOUT = os.getenv("GPT_IMAGE_TIMEOUT", "120")
+    VIDEO_TIMEOUT = os.getenv("GPT_VIDEO_TIMEOUT", "120")
     TEMPERATURE = os.getenv("GPT_TEMPERATURE", "1.0")
     HISTORY_CONTEXT = os.getenv("GPT_HISTORY_CONTEXT", "20")  # 最多携带多少条历史消息
     MEDIA_FORMAT = os.getenv("GPT_MEDIA_FORMAT", "base64")  # base64 or http
@@ -156,6 +156,7 @@ class GPT:
     VIDEO_API_KEY = os.getenv("GPT_VIDEO_API_KEY", "")
     VIDEO_BASE_URL = os.getenv("GPT_VIDEO_BASE_URL", "https://open.bigmodel.cn/api/paas/v4")
     TOKEN_ENCODING = os.getenv("GPT_TOKEN_ENCODING", "o200k_base")  # https://github.com/openai/tiktoken
+    MAX_RETRY = int(os.getenv("GPT_MAX_RETRY", "2"))
 
 
 class TID:
src/utils.py
@@ -254,6 +254,13 @@ def ascii_to_unicode(text: str) -> str:
     return bytes(str(text), "ascii").decode("unicode_escape")
 
 
+def save_txt(text: str, path: Path | str | None = None) -> str:
+    if path is None:
+        path = Path(DOWNLOAD_DIR) / f"{rand_string()}.txt"
+    Path(path).write_text(text)
+    return Path(path).as_posix()
+
+
 def check_data(text: str, check_keys: list[str] | None = None, check_kv: dict | None = None):
     """Check if data contains required keys and key-value pairs.