Commit 285cd5b

benny-dou <60535774+benny-dou@users.noreply.github.com>
2025-02-13 14:04:17
feat(gpt): support custom reasoning models
1 parent 23c1a05
Changed files (3)
src/llm/models.py
@@ -104,11 +104,12 @@ def openrouter_hook(base_url: str, *, for_tools: bool = False) -> dict:
 
 def model_hook(params: dict) -> dict:
     """Add parameters for special models."""
-    # hook for deepseek-r1.
+    # Hook for reasoning models.
     # Ref: https://github.com/deepseek-ai/DeepSeek-R1/tree/97612c28d06139aa25bb8bca5d632e1fccd70ffd?tab=readme-ov-file#usage-recommendations
     # Ref: https://linux.do/t/topic/408247
     model = params.get("model", "").lower()
-    if any(x in model for x in ["deepseek-r1", "think", "o1", "o3"]):
+    reasoning_models = [x.strip() for x in GPT.REASONING_MODELS.split(",") if x.strip()]
+    if any(x in model for x in reasoning_models):
         params["messages"] = change_system_prompt(
             context=params.get("messages", []),
             prompt="In every output, response using the following format:\n<think>\n{reasoning_content}\n</think>\n\n{content}",
src/llm/utils.py
@@ -117,4 +117,10 @@ def extract_reasoning(text: str) -> tuple[str, str]:
     if matched := re.search(r"<thinking>(.*?)</thinking>", text, re.DOTALL):
         reasoning = matched.group(1)
         text = re.sub(r"<thinking>(.*?)</thinking>", "", text, count=1, flags=re.DOTALL)
+
+    # Reverse-engineered web API: split the inline reasoning banner from the answer text.
+    if matched := re.search(r"^>?正在推理(.*?)(已推理,持续.*?)\n\n(.*)", text, re.DOTALL):  # noqa: RUF001
+        reasoning = matched.group(1)
+        text = matched.group(3)
+
     return reasoning.strip(), text.strip().removeprefix("{content}").strip()
src/config.py
@@ -167,6 +167,8 @@ class GPT:  # see `llm/README.md`
     TOOLS_BASE_URL = os.getenv("GPT_TOOLS_BASE_URL", "https://api.openai.com/v1")
     TOKEN_ENCODING = os.getenv("GPT_TOKEN_ENCODING", "o200k_base")  # https://github.com/openai/tiktoken
     MAX_RETRY = int(os.getenv("GPT_MAX_RETRY", "2"))
+    # Comma-separated list of reasoning-model name fragments; a system prompt is added for matching models to enforce the expected output format.
+    REASONING_MODELS = os.getenv("GPT_REASONING_MODELS", "deepseek-r1,o1,o3")
     # /gemini command
     GEMINI_MODEL = os.getenv("GPT_GEMINI_MODEL", "gemini-2.0-flash")
     GEMINI_MODEL_NAME = os.getenv("GPT_GEMINI_MODEL_NAME", "Gemini-2.0-Flash")