Commit 03f6c39

benny-dou <60535774+benny-dou@users.noreply.github.com>
2025-02-11 16:56:57
feat(gpt): make `deepseek-r1` always reasoning
1 parent 12d096e
src/llm/contexts.py
@@ -73,10 +73,11 @@ async def single_context(client: Client, message: Message) -> dict:
         return re.sub(rf"(.*?){BOT_TIPS}\)", "", text.removeprefix(PREFIX.GPT), flags=re.DOTALL).strip()
 
     info = parse_msg(message, silent=True)
-    role = "assistant" if BOT_TIPS in info["text"] else "user"
+    role = "assistant" if f"{BOT_TIPS})" in info["text"] else "user"
+    texts = clean_text(info["text"])
     # only text
-    if info["mtype"] == "text" and (text := clean_text(info["text"])):
-        return {"role": role, "content": [{"type": "text", "text": text}]}
+    if info["mtype"] == "text" and texts:
+        return {"role": role, "content": [{"type": "text", "text": texts}]}
 
     if info["mtype"] not in ["photo", "voice", "video", "document"]:
         return {}
@@ -99,7 +100,7 @@ async def single_context(client: Client, message: Message) -> dict:
                 elif info["mtype"] == "document" and info["mime_type"] == "text/plain" and not info["file_name"].startswith("GPT-Reasoning"):  # skip GPT reasoning
                     media.append({"type": "text", "text": res.getvalue().decode("utf-8")})
                 else:
-                    logger.warning(f"Unsupported message type: {info['mtype']}")
+                    logger.warning(f"Skip message type: {info['mtype']}")
             else:
                 path: str = await client.download_media(msg)  # type: ignore
                 logger.debug(f"Downloaded GPT media: {path}")
@@ -112,8 +113,8 @@ async def single_context(client: Client, message: Message) -> dict:
                     Path(path).unlink(missing_ok=True)
                 else:
                     logger.warning(f"Unsupported message type: {info['mtype']}")
-            if caption := info["text"]:
-                media.append({"type": "text", "text": caption})
+            if texts:
+                media.append({"type": "text", "text": texts})
         except Exception as e:
             logger.warning(f"Download media from message failed: {e}")
             continue
src/llm/models.py
@@ -6,7 +6,7 @@ from openai import DefaultAsyncHttpxClient
 from pyrogram.types import Message
 
 from config import GPT, PROXY
-from llm.utils import BOT_TIPS
+from llm.utils import BOT_TIPS, change_system_prompt
 from messages.parser import parse_msg
 
 
@@ -48,7 +48,6 @@ def get_model_config_with_contexts(model_type: str, contexts: list[dict]) -> dic
     timeouts = {"text": GPT.TEXT_TIMEOUT, "image": GPT.IMAGE_TIMEOUT, "video": GPT.VIDEO_TIMEOUT}
     apis = {"text": GPT.TEXT_API_KEY, "image": GPT.IMAGE_API_KEY, "video": GPT.VIDEO_API_KEY}
     urls = {"text": GPT.TEXT_BASE_URL, "image": GPT.IMAGE_BASE_URL, "video": GPT.VIDEO_BASE_URL}
-    model = models[model_type]
 
     # setup configs
     # params for OpenAI client
@@ -60,9 +59,9 @@ def get_model_config_with_contexts(model_type: str, contexts: list[dict]) -> dic
     }
 
     # params for `openai.chat.completions.create()`
-    completions = {"model": model, "messages": contexts, "temperature": float(GPT.TEMPERATURE)}
+    completions = {"model": models[model_type], "messages": contexts, "temperature": float(GPT.TEMPERATURE)}
+    completions = model_hook(completions)
     completions |= openrouter_hook(client["base_url"])
-
     return {
         "friendly_name": model_names[model_type],
         "bot_msg_prefix": f"🤖**{model_names[model_type]}**: ({BOT_TIPS})",
@@ -84,3 +83,17 @@ def openrouter_hook(base_url: str, *, for_tools: bool = False) -> dict:
         if models := [x.strip() for x in GPT.FALLBACK_MODELS.split(",") if x.strip()]:
             params["extra_body"] |= {"models": models}
     return params
+
+
+def model_hook(params: dict) -> dict:
+    """Add parameters for special models."""
+    # hook for deepseek-r1.
+    # Ref: https://github.com/deepseek-ai/DeepSeek-R1/tree/97612c28d06139aa25bb8bca5d632e1fccd70ffd?tab=readme-ov-file#usage-recommendations
+    # Ref: https://linux.do/t/topic/408247
+    if "deepseek-r1" in params.get("model", "").lower():
+        params["messages"] = change_system_prompt(
+            context=params.get("messages", []),
+            prompt="In every output, response using the following format:\n<think>\n{reasoning_content}\n</think>\n\n{content}",
+            method="prepend",
+        )
+    return params
src/llm/response.py
@@ -43,9 +43,11 @@ async def merge_tools_response(config: dict, **kwargs) -> dict:
         logger.debug(f"Online search tool call args: {args}")
         await modify_progress(text=f"正在联网搜索信息:\n{args.get('query', '')}", force_update=True, **kwargs)
         if tool_result := await get_online_search_result(**args):
+            current_time = nowdt(TZ).replace(microsecond=0).isoformat()
             config["completions"]["messages"] = change_system_prompt(
                 config["completions"]["messages"],
-                f"于{nowdt(TZ):%Y-%m-%d}日进行了一次网络搜索, 请参考以下网络搜索结果进行回答. 如果结果中包含原始链接, 请以源域名附带上Markdown格式的链接. 例如: [exaplme.com](https://www.exaplme.com/some-page)\n网络搜索结果如下:\n{json.dumps(tool_result, ensure_ascii=False)}\n请再次注意, 当前的现实日期就是{nowdt(TZ):%Y-%m-%d}, 以上搜索结果并非预测中的结果",
+                f"于{current_time}进行了一次网络搜索, 网络搜索结果如下:\n{tool_result}\n请以Markdown格式给出回复时的参考链接\n请注意, 现在的真实时间就是{current_time}, 以上搜索结果并非预测中的结果",
+                method="append",
             )
             return config
     except Exception as e:
src/llm/tools.py
@@ -92,7 +92,7 @@ def add_tools(params: dict) -> dict:
     if ENABLE.GPT_ONLINE_SEARCH:
         tools = [ONLINE_SEARCH]
         system_prompt = f"你是一个具备网络访问能力的智能助手. 在需要时可以访问互联网进行相关搜索获取信息以确保用户得到最新、准确的帮助。当前日期是 {nowdt(TZ):%Y-%m-%d}"
-        params["messages"] = change_system_prompt(params["messages"], system_prompt)
+        params["messages"] = change_system_prompt(params["messages"], system_prompt, method="overwrite")
     if tools:
         params["tools"] = tools
         params["tool_choice"] = "auto"
src/llm/utils.py
@@ -54,13 +54,24 @@ def count_tokens(string: str, encoding_name: str | None = None) -> int:
         return 0
 
 
-def change_system_prompt(context: list[dict], prompt: str) -> list[dict]:
+def change_system_prompt(context: list[dict], prompt: str, method: str = "overwrite") -> list[dict]:
     if not context:
         return [{"role": "system", "content": prompt}]
-    if context[0].get("role") == "system":
-        context[0]["content"] = prompt
+    if method not in ["overwrite", "prepend", "append"]:
+        logger.warning(f"Invalid method of `change_system_prompt`: {method}")
         return context
-    context.insert(0, {"role": "system", "content": prompt})
+    if method == "overwrite":
+        if context[0].get("role") == "system":
+            context[0]["content"] = prompt
+        else:
+            context.insert(0, {"role": "system", "content": prompt})
+    elif method == "prepend":
+        context.insert(0, {"role": "system", "content": prompt})
+    elif method == "append":  # append to the end of the system prompts
+        for idx, item in enumerate(context):
+            if item.get("role") != "system":
+                context.insert(idx, {"role": "system", "content": prompt})
+                break
     return context