Commit b245389

benny-dou <60535774+benny-dou@users.noreply.github.com>
2025-02-15 03:29:12
refactor(gpt): improve reasoning model
1 parent ecbfb32
src/llm/contexts.py
@@ -10,6 +10,7 @@ from pyrogram.client import Client
 from pyrogram.types import Message
 
 from config import GPT, PREFIX
+from llm.prompts import refine_prompts
 from llm.utils import BOT_TIPS
 from messages.parser import parse_msg
 
@@ -47,8 +48,7 @@ async def get_conversation_contexts(client: Client, conversations: list[Message]
     # parse context for each message
     contexts = [await single_context(client, message) for message in conversations]
     contexts = [x for x in contexts if x]  # filter out empty context
-    contexts = combine_consecutive_role_contexts(contexts)
-    contexts = simplify_text_contents(contexts)
+    contexts = refine_prompts(contexts)
 
     return contexts[: int(GPT.HISTORY_CONTEXT)]
 
@@ -121,73 +121,3 @@ async def single_context(client: Client, message: Message) -> dict:
             logger.warning(f"Download media from message failed: {e}")
             continue
     return {"role": role, "content": media}
-
-
-def combine_consecutive_role_contexts(contexts: list[dict]) -> list[dict]:
-    """Combine consecutive user and assistant contexts into one message.
-
-    Some GPT models don't support consecutive user and assistant contexts. (e.g. Hunyuan)
-
-    Args:
-        contexts (list[dict]): [
-            {
-                "role": "user or assistant",
-                "content": [
-                    {'type': 'text', 'text': 'caption this img'},
-                    {'type': 'image_url', 'image_url': {'url': 'data:image/jpeg;base64,base64_image'}},
-                    {'type': 'image_url', 'image_url': {'url': 'https://server.com/dir/image.jpg'}},
-                ]
-            }
-        ]
-    """
-    combined_contexts = []
-    for i, msg in enumerate(contexts):
-        if i == 0:
-            combined_contexts.append(msg)
-            continue
-        if msg["role"] == combined_contexts[-1]["role"]:
-            combined_contexts[-1]["content"].extend(msg["content"])
-        else:
-            combined_contexts.append(msg)
-    return combined_contexts
-
-
-def simplify_text_contents(contexts: list[dict]) -> list[dict]:
-    """Simplify the plain text content format.
-
-    Some models do not support this format:
-        [{'text': 'hi', 'type': 'text'}], 'role': 'user'}]
-
-    It only supports:
-        [{'content': 'hi', 'role': 'user'}]
-
-    Args:
-        contexts (list[dict]): [
-            {
-                "role": "user or assistant",
-                "content": [
-                    {'type': 'text', 'text': 'caption this img'},
-                ]
-            }
-        ]
-
-    Returns:
-        list[dict]: [
-            {
-                "role": "user or assistant",
-                "content": "caption this img"
-            }
-        ]
-    """
-    fixed_contexts = []
-    for msg in contexts:
-        if not msg.get("content") or not isinstance(msg.get("content"), list):
-            fixed_contexts.append(msg)
-            continue
-        contents = msg.get("content", [])
-        if all(x.get("type") == "text" for x in contents):
-            msg["content"] = "\n".join([x.get("text") for x in contents])
-            fixed_contexts.append(msg)
-        else:
-            fixed_contexts.append(msg)
-    return fixed_contexts
src/llm/models.py
@@ -6,7 +6,7 @@ from openai import DefaultAsyncHttpxClient
 from pyrogram.types import Message
 
 from config import GPT, PROXY
-from llm.utils import change_system_prompt
+from llm.prompts import force_reasoning, refine_prompts
 from messages.parser import parse_msg
 
 
@@ -45,7 +45,6 @@ def get_model_config_with_contexts(model_type: str, contexts: list[dict], force_
     """
     models = {"text": GPT.TEXT_MODEL, "image": GPT.IMAGE_MODEL, "video": GPT.VIDEO_MODEL}
     model_names = {"text": GPT.TEXT_MODEL_NAME, "image": GPT.IMAGE_MODEL_NAME, "video": GPT.VIDEO_MODEL_NAME}
-    timeouts = {"text": GPT.TEXT_TIMEOUT, "image": GPT.IMAGE_TIMEOUT, "video": GPT.VIDEO_TIMEOUT}
     apis = {"text": GPT.TEXT_API_KEY, "image": GPT.IMAGE_API_KEY, "video": GPT.VIDEO_API_KEY}
     urls = {"text": GPT.TEXT_BASE_URL, "image": GPT.IMAGE_BASE_URL, "video": GPT.VIDEO_BASE_URL}
 
@@ -56,7 +55,7 @@ def get_model_config_with_contexts(model_type: str, contexts: list[dict], force_
     client = {
         "api_key": apis[model_type],
         "base_url": urls[model_type],
-        "timeout": round(float(timeouts[model_type])),
+        "timeout": round(float(GPT.TIMEOUT)),
         "http_client": DefaultAsyncHttpxClient(proxy=PROXY.GPT),
     }
 
@@ -77,6 +76,7 @@ def get_model_config_with_contexts(model_type: str, contexts: list[dict], force_
     completions = {"model": model, "messages": contexts, "temperature": float(GPT.TEMPERATURE)}
     completions = model_hook(completions)
+    completions |= openrouter_hook(client["base_url"])  # this must run after `force_model` is set
+    completions["messages"] = refine_prompts(completions["messages"])  # final refine after hooks
 
     if force_model != "N/A" and completions.get("extra_body"):  # remove models fallback
         completions["extra_body"].pop("models", None)  # should be after hooks
@@ -104,15 +104,4 @@ def openrouter_hook(base_url: str, *, for_tools: bool = False) -> dict:
 
 def model_hook(params: dict) -> dict:
     """Add parameters for special models."""
-    # hook for Reasoning Models.
-    # Ref: https://github.com/deepseek-ai/DeepSeek-R1/tree/97612c28d06139aa25bb8bca5d632e1fccd70ffd?tab=readme-ov-file#usage-recommendations
-    # Ref: https://linux.do/t/topic/408247
-    model = params.get("model", "").lower()
-    reasoning_models = [x.strip() for x in GPT.REASONING_MODELS.split(",") if x.strip()]
-    if any(x in model for x in reasoning_models):
-        params["messages"] = change_system_prompt(
-            context=params.get("messages", []),
-            prompt="In every output, response using the following format:\n<think>\n{reasoning_content}\n</think>\n\n{content}",
-            method="prepend",
-        )
-    return params
+    return force_reasoning(params)
src/llm/prompts.py
@@ -0,0 +1,209 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+from loguru import logger
+
+from config import GPT, TZ
+from utils import nowdt
+
+REASONING_PROMPT = "In every output, respond using the following format:\n<think>\n{reasoning_content}\n</think>\n\n{content}"
+
+
+# ruff: noqa: RUF001
+def modify_prompts(context: list[dict], prompt: str, role: str = "system", method: str = "overwrite") -> list[dict]:
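+    """Insert or replace a prompt message in the contexts according to `method`.
+
+    `method` is one of "overwrite" (replace the first message of `role`, or insert
+    one at the front), "prepend", or "append".
+
+    Example (illustrative):
+        >>> modify_prompts([], "You are a helpful bot.")
+        [{'role': 'system', 'content': 'You are a helpful bot.'}]
+    """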
+    if role not in ["system", "user", "assistant"]:
+        logger.warning(f"Invalid method of `modify_prompts`: {method}")
+        return context
+    if not context:
+        return [{"role": role, "content": prompt}]
+    if method not in ["overwrite", "prepend", "append"]:
+        logger.warning(f"Invalid method of `modify_prompts`: {method}")
+        return context
+    if method == "overwrite":
+        if context[0].get("role") == role:
+            context[0]["content"] = prompt
+        else:
+            context.insert(0, {"role": role, "content": prompt})
+    elif method == "prepend":
+        context.insert(0, {"role": role, "content": prompt})
+    elif method == "append":
+        context.append({"role": role, "content": prompt})
+    return context
+
+
+def force_reasoning(params: dict) -> dict:
+    """Force model reasoning.
+
+    Although the official instructions do not add a system prompt [1],
+    we found that adding the following system prompt works much better [2].
+
+    # Ref-1: https://github.com/deepseek-ai/DeepSeek-R1/tree/ef99616?tab=readme-ov-file#usage-recommendations
+    # Ref-2: https://linux.do/t/topic/408247
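+
+    Example (illustrative, assuming "deepseek-r1" is listed in GPT.REASONING_MODELS):
+        params = {"model": "deepseek-r1", "messages": [{"role": "user", "content": "hi"}]}
+        force_reasoning(params)["messages"][0]
+        # -> {"role": "system", "content": REASONING_PROMPT}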
+    """
+    model = params.get("model", "").lower()
+    reasoning_models = [x.strip() for x in GPT.REASONING_MODELS.split(",") if x.strip()]
+    # remove previous thinking prompt
+    # params["messages"] = remove_prompt_from_contexts(params["messages"], REASONING_PROMPT)
+
+    if any(x in model for x in reasoning_models):
+        params["messages"] = modify_prompts(
+            params["messages"],
+            prompt=REASONING_PROMPT,
+            role="system",
+            method="prepend",
+        )
+    return params
+
+
+def add_search_results_to_prompts(search_results: list[dict], params: dict) -> dict:
+    """Add search results to contexts.
+
+    # Template: https://github.com/deepseek-ai/DeepSeek-R1/tree/ef99616?tab=readme-ov-file#usage-recommendations
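+
+    Example (illustrative; the shape of `search_results` entries is assumed):
+        params = {"messages": [{"role": "user", "content": "any news today?"}]}
+        add_search_results_to_prompts([{"title": "t", "link": "https://e.com"}], params)
+        # each result is wrapped in [webpage N begin]...[webpage N end] markers and
+        # the whole search template is prepended to the last user message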
+    """
+    search_msg = ""
+    for idx, result in enumerate(search_results):
+        search_msg += f"[webpage {idx + 1} begin] {result} [webpage {idx + 1} end]\n"
+
+    # modified from DeepSeek's official instructions
+    prompt = f"""# 以下内容是基于用户发送的消息的搜索结果:
+{search_msg}
+在我给你的搜索结果中,每个结果都是[webpage X begin]...[webpage X end]格式的,X代表每篇文章的数字索引。
+在回答时,请注意以下几点:
+- 今天是{nowdt(TZ):%Y-%m-%d}。
+- 并非搜索结果的所有内容都与用户的问题密切相关,你需要结合问题,对搜索结果进行甄别、筛选。
+- 对于列举类的问题(如列举所有航班信息),尽量将答案控制在10个要点以内,并告诉用户可以查看搜索来源、获得完整信息。优先提供信息完整、最相关的列举项;如非必要,不要主动告诉用户搜索结果未提供的内容。
+- 对于创作类的问题(如写论文),你需要解读并概括用户的题目要求,选择合适的格式,充分利用搜索结果并抽取重要信息,生成符合用户要求、极具思想深度、富有创造力与专业性的答案。你的创作篇幅需要尽可能延长,对于每一个要点的论述要推测用户的意图,给出尽可能多角度的回答要点,且务必信息量大、论述详尽。
+- 如果回答很长,请尽量结构化、分段落总结。如果需要分点作答,尽量控制在5个点以内,并合并相关的内容。
+- 对于客观类的问答,如果问题的答案非常简短,可以适当补充一到两句相关信息,以丰富内容。
+- 你需要根据用户要求和回答内容选择合适、美观的回答格式,确保可读性强。
+- 请在适当的情况下在句子末尾引用上下文。请按照引用编号 [[X]](url) 的格式在答案中对应部分引用上下文。
+- 如果一句话源自多个上下文,请列出所有相关的引用编号,例如[[1]](url1) [[2]](url2),切记不要将引用集中在最后返回,而是在答案对应部分列出。
+- 你的回答应该综合多个相关网页来回答,不能重复引用一个网页。
+- 除非用户要求,否则你回答的语言需要和用户提问的语言保持一致。
+
+# 用户消息为:
+"""
+    contexts = params["messages"]
+    # last context is text
+    if isinstance(contexts[-1]["content"], str):
+        contexts[-1]["content"] = f"{prompt}{contexts[-1]['content']}"
+    else:  # list of multi-modality content parts
+        contexts[-1]["content"].insert(0, {"type": "text", "text": prompt})
+    params["messages"] = contexts
+    params["messages"] = refine_prompts(params["messages"])
+    return params
+
+
+def refine_prompts(contexts: list[dict]) -> list[dict]:
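+    """Merge consecutive same-role contexts, then flatten text-only content to plain strings.
+
+    Example (illustrative):
+        >>> refine_prompts([{"role": "user", "content": [{"type": "text", "text": "hi"}]}])
+        [{'role': 'user', 'content': 'hi'}]
+    """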
+    contexts = combine_consecutive_role(contexts)
+    return simplify_text_prompts(contexts)
+
+
+def combine_consecutive_role(contexts: list[dict]) -> list[dict]:
+    """Combine consecutive user and assistant contexts into one message.
+
+    Some GPT models don't support consecutive user and assistant contexts. (e.g. Hunyuan)
+
+    Args:
+        contexts (list[dict]): [
+            {
+                "role": "user or assistant",
+                "content": [
+                    {'type': 'text', 'text': 'caption this img'},
+                    {'type': 'image_url', 'image_url': {'url': 'data:image/jpeg;base64,base64_image'}},
+                    {'type': 'image_url', 'image_url': {'url': 'https://server.com/dir/image.jpg'}},
+                ]
+            }
+        ]
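+
+    Example (illustrative):
+        >>> combine_consecutive_role(
+        ...     [{"role": "user", "content": "hi"}, {"role": "user", "content": "again"}]
+        ... )
+        [{'role': 'user', 'content': [{'type': 'text', 'text': 'hi'}, {'type': 'text', 'text': 'again'}]}]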
+    """
+    contexts = convert_content_to_list_dict(contexts)
+    combined_contexts = []
+    for i, msg in enumerate(contexts):
+        if i == 0:
+            combined_contexts.append(msg)
+            continue
+        if msg["role"] == combined_contexts[-1]["role"]:
+            combined_contexts[-1]["content"].extend(msg["content"])
+        else:
+            combined_contexts.append(msg)
+    return combined_contexts
+
+
+def simplify_text_prompts(contexts: list[dict]) -> list[dict]:
+    """Simplify the plain text content format.
+
+    Some models do not support this format:
+        [{'content': [{'text': 'hi', 'type': 'text'}], 'role': 'user'}]
+
+    It only supports:
+        [{'content': 'hi', 'role': 'user'}]
+
+    Args:
+        contexts (list[dict]): [
+            {
+                "role": "user or assistant",
+                "content": [
+                    {'type': 'text', 'text': 'caption this img'},
+                ]
+            }
+        ]
+
+    Returns:
+        list[dict]: [
+            {
+                "role": "user or assistant",
+                "content": "caption this img"
+            }
+        ]
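+
+    Example (illustrative):
+        >>> simplify_text_prompts([{"role": "user", "content": [{"type": "text", "text": "hi"}]}])
+        [{'role': 'user', 'content': 'hi'}]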
+    """
+    fixed_contexts = []
+    for msg in contexts:
+        contents = msg.get("content")
+        if isinstance(contents, list) and contents and all(x.get("type") == "text" for x in contents):
+            msg["content"] = "\n".join(x.get("text") for x in contents)
+        fixed_contexts.append(msg)
+    return fixed_contexts
+
+
+def convert_content_to_list_dict(contexts: list[dict]) -> list[dict]:
+    """Reverse `simplify_text_prompts` function.
+
+    Returns:
+        contexts (list[dict]): [
+            {
+                "role": "user or assistant",
+                "content": [
+                    {'type': 'text', 'text': 'caption this img'},
+                ]
+            }
+        ]
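+
+    Example (illustrative):
+        >>> convert_content_to_list_dict([{"role": "user", "content": "hi"}])
+        [{'role': 'user', 'content': [{'type': 'text', 'text': 'hi'}]}]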
+    """
+    fixed_contexts = []
+    for msg in contexts:
+        content = msg.get("content")
+        if isinstance(content, str) and content:
+            msg["content"] = [{"type": "text", "text": content}]
+        fixed_contexts.append(msg)
+    return fixed_contexts
+
+
+def remove_prompt_from_contexts(contexts: list[dict], prompt: str) -> list[dict]:
+    """Remove the prompt from the contexts."""
+    for msg in contexts:
+        if isinstance(msg.get("content"), str):
+            msg["content"] = msg["content"].replace(prompt, "").strip()
+        elif isinstance(msg.get("content"), list):
+            for content in msg["content"]:
+                if content.get("type") == "text":
+                    content["text"] = content["text"].replace(prompt, "").strip()
+    return contexts
src/llm/response.py
@@ -8,12 +8,13 @@ from glom import glom
 from loguru import logger
 from openai import AsyncOpenAI
 
-from config import ENABLE, GPT, TZ
+from config import ENABLE, GPT
 from llm.models import openrouter_hook
+from llm.prompts import add_search_results_to_prompts
 from llm.tools import add_tools, get_online_search_result
-from llm.utils import beautify_model_name, change_system_prompt, extract_reasoning
+from llm.utils import beautify_model_name, extract_reasoning
 from messages.progress import modify_progress
-from utils import nowdt
+from utils import number_to_emoji
 
 
 async def merge_tools_response(config: dict, **kwargs) -> dict:
@@ -42,12 +43,8 @@ async def merge_tools_response(config: dict, **kwargs) -> dict:
         logger.debug(f"Online search tool call args: {args}")
         await modify_progress(text=f"正在联网搜索信息:\n{args.get('query', '')}", force_update=True, **kwargs)
         if tool_result := await get_online_search_result(**args):
-            current_time = nowdt(TZ).replace(microsecond=0).isoformat()
-            config["completions"]["messages"] = change_system_prompt(
-                config["completions"]["messages"],
-                f"于{current_time}进行了一次网络搜索, 网络搜索结果如下:\n{tool_result}\n请以Markdown格式给出回复时的参考链接\n请注意, 现在的真实时间就是{current_time}, 以上搜索结果并非预测中的结果",
-                method="append",
-            )
+            config["completions"] = add_search_results_to_prompts(tool_result, config["completions"])
+            config["search_results"] = tool_result  # save search results for future use
             return config
     except Exception as e:
         logger.error(f"Tools_response failed: {e}")
@@ -123,6 +120,7 @@ async def parse_response(config: dict, response: dict) -> dict[str, str]:
         return response | {"content": "", "reasoning": "", "reasoning_model": ""}
     try:
         content = glom(choice, "message.content", default="") or ""
+        content = add_search_results_to_response(config.get("search_results", []), content)
         reasoning, content = extract_reasoning(content)  # extract reasoning from content (<think>...</think>)
         if not reasoning:
             reasoning = glom(choice, "message.reasoning", default="") or ""
@@ -139,3 +137,16 @@ async def parse_response(config: dict, response: dict) -> dict[str, str]:
         logger.error(f"Parse  GPT response failed: {e}")
         raise
     return response
+
+
+def add_search_results_to_response(search_results: list[dict], response: str) -> str:
+    """Add search results to response."""
+    if not search_results or not response:
+        return response
+    response = response.strip()
+    for idx, result in enumerate(search_results):
+        title = result.get("title", "")[:15]
+        link = result.get("link", "")
+        if link.startswith("http"):
+            response += f"\n{number_to_emoji(idx + 1)} [{title}]({link})"
+    return response.strip()
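+
+
+# Illustrative output (the `search_results` entry shape is assumed):
+# add_search_results_to_response([{"title": "Example", "link": "https://e.com"}], "Done.")
+# -> "Done.\n1️⃣ [Example](https://e.com)"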
src/llm/summary.py
@@ -9,8 +9,8 @@ from pyrogram.client import Client
 from pyrogram.types import Message
 
 from config import ENABLE, GPT, MAX_MESSAGE_SUMMARY, PREFIX, PROXY, cache
-from llm.contexts import combine_consecutive_role_contexts, simplify_text_contents
 from llm.models import openrouter_hook
+from llm.prompts import refine_prompts
 from llm.response import send_to_gpt
 from messages.chat_history import get_parsed_chat_history
 from messages.parser import parse_msg
@@ -194,7 +194,6 @@ async def get_contexts(client: Client, history: list[dict]) -> list[dict]:  # no
         #     content = {"user": info["full_name"], "time": f"{info['datetime']:%H:%M:%S}", "message": f"[{info['mtype']}] {info['text']}".strip()}
         #     user_contexts.append({"type": "text", "text": json.dumps(content, ensure_ascii=False)})
     contexts.append({"role": "user", "content": user_contexts})
-    contexts = combine_consecutive_role_contexts(contexts)
-    contexts = simplify_text_contents(contexts)
+    contexts = refine_prompts(contexts)
     logger.trace(contexts)
     return contexts
src/llm/tools.py
@@ -7,8 +7,8 @@ from loguru import logger
 from openai import AsyncOpenAI, DefaultAsyncHttpxClient
 
 from config import ENABLE, GPT, PROXY, TOKEN, TZ
+from llm.prompts import modify_prompts
 from llm.tool_scheme import ONLINE_SEARCH
-from llm.utils import change_system_prompt
 from networking import hx_req
 from utils import nowdt
 
@@ -92,7 +92,8 @@ def add_tools(params: dict) -> dict:
     if ENABLE.GPT_ONLINE_SEARCH:
         tools = [ONLINE_SEARCH]
         system_prompt = f"你是一个具备网络访问能力的智能助手. 在需要时可以访问互联网进行相关搜索获取信息以确保用户得到最新、准确的帮助。当前日期是 {nowdt(TZ):%Y-%m-%d}"
-        params["messages"] = change_system_prompt(params["messages"], system_prompt, method="overwrite")
+        params["messages"] = modify_prompts(params["messages"], system_prompt, method="overwrite")
+
     if tools:
         params["tools"] = tools
         params["tool_choice"] = "auto"
src/llm/utils.py
@@ -55,27 +55,6 @@ def count_tokens(string: str, encoding_name: str | None = None) -> int:
         return 0
 
 
-def change_system_prompt(context: list[dict], prompt: str, method: str = "overwrite") -> list[dict]:
-    if not context:
-        return [{"role": "system", "content": prompt}]
-    if method not in ["overwrite", "prepend", "append"]:
-        logger.warning(f"Invalid method of `change_system_prompt`: {method}")
-        return context
-    if method == "overwrite":
-        if context[0].get("role") == "system":
-            context[0]["content"] = prompt
-        else:
-            context.insert(0, {"role": "system", "content": prompt})
-    elif method == "prepend":
-        context.insert(0, {"role": "system", "content": prompt})
-    elif method == "append":  # append to the end of the system prompts
-        for idx, item in enumerate(context):
-            if item.get("role") != "system":
-                context.insert(idx, {"role": "system", "content": prompt})
-                break
-    return context
-
-
 def beautify_model_name(name: str) -> str:
     """Beautify model name.
 
src/config.py
@@ -151,9 +151,7 @@ class GPT:  # see `llm/README.md`
     GLM_BASE_URL = os.getenv("GPT_GLM_BASE_URL", "https://open.bigmodel.cn/api/paas/v4")
     SEARCH_NUM_RESULTS = os.getenv("GPT_SEARCH_NUM_RESULTS", "5")
     PRIMARY_SEARCH_ENGINE = os.getenv("GPT_PRIMARY_SEARCH_ENGINE", "google")  # google or glm
-    TEXT_TIMEOUT = os.getenv("GPT_TEXT_TIMEOUT", "120")
-    IMAGE_TIMEOUT = os.getenv("GPT_IMAGE_TIMEOUT", "120")
-    VIDEO_TIMEOUT = os.getenv("GPT_VIDEO_TIMEOUT", "120")
+    TIMEOUT = os.getenv("GPT_TIMEOUT", "300")
     TEMPERATURE = os.getenv("GPT_TEMPERATURE", "1.0")
     HISTORY_CONTEXT = os.getenv("GPT_HISTORY_CONTEXT", "20")  # max number of history messages to include
     MEDIA_FORMAT = os.getenv("GPT_MEDIA_FORMAT", "base64")  # base64 or http
src/utils.py
@@ -136,6 +136,12 @@ def soup_to_text(soup: PageElement) -> str:
     return text
 
 
+def number_to_emoji(num: int | str) -> str:
+    """Convert a number to an emoji."""
+    num = str(num)
+    return {"1": "1️⃣", "2": "2️⃣", "3": "3️⃣", "4": "4️⃣", "5": "5️⃣", "6": "6️⃣", "7": "7️⃣", "8": "8️⃣", "9": "9️⃣", "10": "🔟"}.get(num, "🔢")
+
+
 def stringfy(d: dict) -> dict:
     """Convert dict values to string.