Commit d5c7a89

benny-dou <60535774+benny-dou@users.noreply.github.com>
2025-03-17 03:40:31
refactor(gpt): split hooks
1 parent fb434ec
src/llm/contexts.py
@@ -11,7 +11,6 @@ from pyrogram.client import Client
 from pyrogram.types import Message
 
 from config import GPT, PREFIX
-from llm.prompts import refine_prompts
 from llm.utils import BOT_TIPS
 from messages.parser import parse_msg
 
@@ -48,8 +47,7 @@ async def get_conversation_contexts(client: Client, conversations: list[Message]
     """
     # parse context for each message
     contexts = [await single_context(client, message) for message in conversations]
-    contexts = [x for x in contexts if x]  # filter out empty context
-    contexts = refine_prompts(contexts)
+    contexts = [x for x in contexts if x.get("content")]  # filter out empty context
 
     return contexts[: int(GPT.HISTORY_CONTEXT)]
 
src/llm/gpt.py
@@ -7,7 +7,7 @@ from pyrogram.types import Message
 
 from config import GPT, PREFIX, TEXT_LENGTH, cache
 from llm.contexts import get_conversation_contexts, get_conversations
-from llm.models import get_context_type, get_model_config_with_contexts
+from llm.models import get_context_type, get_gpt_config
 from llm.response import send_to_gpt
 from llm.response_stream import send_to_gpt_stream
 from llm.tools import merge_tools_response
@@ -44,10 +44,9 @@ def is_gpt_conversation(message: Message) -> bool:
     # is replying to gpt-bot response message?
     if not message.reply_to_message:
         return False
-
     reply_msg = message.reply_to_message
     reply_info = parse_msg(reply_msg, silent=True)
-    model_names = [GPT.OPENAI_MODEL_NAME, GPT.GEMINI_MODEL_NAME, GPT.DEEPSEEK_MODEL_NAME, GPT.QWEN_MODEL_NAME, GPT.DOUBAO_MODEL_NAME]
+    model_names = [GPT.OPENAI_MODEL_NAME, GPT.GEMINI_MODEL_NAME, GPT.DEEPSEEK_MODEL_NAME, GPT.QWEN_MODEL_NAME, GPT.DOUBAO_MODEL_NAME, GPT.TEXT_MODEL_NAME, GPT.IMAGE_MODEL_NAME]
     return startswith_prefix(reply_info["text"], prefix=[f"🤖{x}".lower() for x in model_names])
 
 
@@ -65,12 +64,11 @@ async def gpt_response(client: Client, message: Message, *, gpt_stream: bool = G
     if equal_prefix(info["text"], prefix=[PREFIX.GPT, "/gpt", "/gemini", "/ds", "/qwen", "/doubao"]) and not message.reply_to_message:
         await send2tg(client, message, texts=HELP, **kwargs)
         return
-
     if not is_gpt_conversation(message):
         return
 
     # /gpt = OpenAI, /gemini = Gemini, /ds = DeepSeek, /qwen = Qwen, /doubao = Doubao
-    force_model = "N/A"
+    force_model = "NOT_SET"
     reply_text = ""
     if message.reply_to_message:
         reply_info = parse_msg(message.reply_to_message, silent=True)
@@ -101,10 +99,11 @@ async def gpt_response(client: Client, message: Message, *, gpt_stream: bool = G
         if cache.get(f"gpt-{info['cid']}-{media_group_id}"):
             return
         cache.set(f"gpt-{info['cid']}-{media_group_id}", "1", ttl=120)
+    kwargs["message_info"] = info  # save trigger message info
     conversations = get_conversations(message)
     context_type = get_context_type(conversations)
     contexts = await get_conversation_contexts(client, conversations)
-    config = get_model_config_with_contexts(context_type["type"], contexts, force_model, info)
+    config = get_gpt_config(context_type["type"], contexts, force_model)
     msg = f"🤖**{config['friendly_name']}**: 思考中..."
     status_msg = (await send2tg(client, message, texts=msg, **kwargs))[0]
     kwargs["progress"] = status_msg
src/llm/hooks.py
@@ -0,0 +1,35 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+from config import GPT
+from llm.prompts import refine_prompts
+from utils import unicode_to_ascii
+
+
+def pre_hooks(client: dict, completions: dict, message_info: dict | None = None) -> None:
+    pre_openrouter_hook(client, completions)
+    pre_helicone_hook(client, message_info)
+    completions["messages"] = refine_prompts(completions["messages"])
+
+
+def pre_openrouter_hook(client: dict, completions: dict) -> None:
+    """Add special parameters for OpenRouter."""
+    if "openrouter" not in client["base_url"]:
+        return
+    if models := [x.strip() for x in GPT.FALLBACK_MODELS.split(",") if x.strip()]:
+        completions["extra_body"] = {"models": models}
+
+
+def pre_helicone_hook(client: dict, message_info: dict | None) -> None:
+    """Add special parameters for helicone gateway."""
+    if not GPT.HELICONE_API_KEY:
+        return
+    headers = client.get("default_headers", {})
+    headers |= {
+        "Helicone-Auth": f"Bearer {GPT.HELICONE_API_KEY}",
+    }
+    message_info = message_info or {}
+    if chat_title := message_info.get("ctitle"):
+        headers |= {"Helicone-Property-Chat": unicode_to_ascii(chat_title), "Helicone-Property-ChatID": str(message_info["cid"])}
+    if user_name := message_info.get("full_name"):
+        headers |= {"Helicone-User-Id": unicode_to_ascii(user_name), "Helicone-Property-User": str(message_info["uid"])}
+    client |= {"default_headers": headers}
src/llm/models.py
@@ -5,9 +5,7 @@ from openai import DefaultAsyncHttpxClient
 from pyrogram.types import Message
 
 from config import GPT, PREFIX, PROXY
-from llm.prompts import refine_prompts
 from messages.parser import parse_msg
-from utils import unicode_to_ascii
 
 
 def get_context_type(conversations: list[Message]) -> dict:
@@ -28,8 +26,8 @@ def get_context_type(conversations: list[Message]) -> dict:
     return res
 
 
-def get_model_config_with_contexts(model_type: str, contexts: list[dict], force_model: str = "N/A", message_info: dict | None = None) -> dict:
-    """Get GPT model config based on contexts, and return the config and adjusted contexts.
+def get_gpt_config(model_type: str, contexts: list[dict], force_model: str = "NOT_SET") -> dict:
+    """Get GPT configurations.
 
     contexts:
     [
@@ -42,56 +40,6 @@ def get_model_config_with_contexts(model_type: str, contexts: list[dict], force_
       }
     ]
     """
-    client, model, model_name = align_with_force_model(model_type, force_model)
-
-    # params for `openai.chat.completions.create()`
-    completions = {"model": model, "messages": contexts, "temperature": float(GPT.TEMPERATURE)}
-    hooks(client, completions, message_info)  # this line should be after setting `force_model`
-    completions["messages"] = refine_prompts(completions["messages"])  # final refine after hooks
-
-    if force_model != "N/A" and completions.get("extra_body"):  # remove models fallback
-        completions["extra_body"].pop("models", None)  # should be after hooks
-    return {
-        "friendly_name": model_name,
-        "client": client,
-        "completions": completions,
-    }
-
-
-def hooks(client: dict, completions: dict, message_info: dict | None = None):
-    openrouter_hook(client, completions)
-    helicone_hook(client, message_info)
-
-
-def openrouter_hook(client: dict, completions: dict) -> None:
-    """Add special parameters for OpenRouter."""
-    if "openrouter" not in client["base_url"]:
-        return
-    if models := [x.strip() for x in GPT.FALLBACK_MODELS.split(",") if x.strip()]:
-        completions["extra_body"] = {"models": models}
-
-
-def helicone_hook(client: dict, message_info: dict | None) -> None:
-    """Add special parameters for helicone gateway."""
-    if not GPT.HELICONE_API_KEY:
-        return
-    headers = client.get("default_headers", {})
-    headers |= {
-        "Helicone-Auth": f"Bearer {GPT.HELICONE_API_KEY}",
-    }
-    message_info = message_info or {}
-    if chat_title := message_info.get("ctitle"):
-        headers |= {"Helicone-Property-Chat": unicode_to_ascii(chat_title), "Helicone-Property-ChatID": str(message_info["cid"])}
-    if user_name := message_info.get("full_name"):
-        headers |= {"Helicone-User-Id": unicode_to_ascii(user_name), "Helicone-Property-User": str(message_info["uid"])}
-    client |= {"default_headers": headers}
-
-
-def align_with_force_model(model_type: str, force_model: str = "N/A") -> tuple[dict, str, str]:
-    """Align the model with the modalities if force_model is specified.
-
-    For example, a user may use `/ds` to reply to an image, but that model only supports text, so we need to switch to the image model.
-    """
     models = {"text": GPT.TEXT_MODEL, "image": GPT.IMAGE_MODEL, "video": GPT.VIDEO_MODEL}
     model_names = {"text": GPT.TEXT_MODEL_NAME, "image": GPT.IMAGE_MODEL_NAME, "video": GPT.VIDEO_MODEL_NAME}
     apis = {"text": GPT.TEXT_API_KEY, "image": GPT.IMAGE_API_KEY, "video": GPT.VIDEO_API_KEY}
@@ -99,16 +47,17 @@ def align_with_force_model(model_type: str, force_model: str = "N/A") -> tuple[d
 
     model = models[model_type]
     model_name = model_names[model_type]
-    if force_model == "N/A":
-        force_model = model
+    force_model = model if force_model == "NOT_SET" else force_model
+
     # params for OpenAI client
-    client = {
+    client = {  # this config is based on model type (text, image, or video)
         "api_key": apis[model_type],
         "base_url": urls[model_type],
         "timeout": round(float(GPT.TIMEOUT)),
         "http_client": DefaultAsyncHttpxClient(proxy=PROXY.GPT),
     }
 
+    # align with force model
     model_factory = {
         GPT.OPENAI_MODEL: {"api_key": GPT.OPENAI_API_KEY, "base_url": GPT.OPENAI_BASE_URL, "model_name": GPT.OPENAI_MODEL_NAME},
         GPT.GEMINI_MODEL: {"api_key": GPT.GEMINI_API_KEY, "base_url": GPT.GEMINI_BASE_URL, "model_name": GPT.GEMINI_MODEL_NAME},
@@ -122,19 +71,29 @@ def align_with_force_model(model_type: str, force_model: str = "N/A") -> tuple[d
 
     force_model_name = force_model_config.get("model_name", model_name)
     force_model_config.pop("model_name", None)
-    if model_type == "text":  # respect the force model
-        client |= force_model_config
-        return client, force_model, force_model_name
-
-    if model_type == "image" and (  # check capabilities
-        (force_model == GPT.OPENAI_MODEL and GPT.OPENAI_IMAGE_CAPABILITY)
-        or (force_model == GPT.GEMINI_MODEL and GPT.GEMINI_IMAGE_CAPABILITY)
-        or (force_model == GPT.DEEPSEEK_MODEL and GPT.DEEPSEEK_IMAGE_CAPABILITY)
-        or (force_model == GPT.QWEN_MODEL and GPT.QWEN_IMAGE_CAPABILITY)
-        or (force_model == GPT.DOUBAO_MODEL and GPT.DOUBAO_IMAGE_CAPABILITY)
-        or (force_model == GPT.SUMMARY_MODEL and GPT.SUMMARY_IMAGE_CAPABILITY)
-        or (force_model == GPT.LONG_MODEL and GPT.LONG_IMAGE_CAPABILITY)
+    # merge force model config
+    if model_type == "text" or (
+        model_type == "image"  # check capabilities
+        and (
+            (force_model == GPT.OPENAI_MODEL and GPT.OPENAI_IMAGE_CAPABILITY)
+            or (force_model == GPT.GEMINI_MODEL and GPT.GEMINI_IMAGE_CAPABILITY)
+            or (force_model == GPT.DEEPSEEK_MODEL and GPT.DEEPSEEK_IMAGE_CAPABILITY)
+            or (force_model == GPT.QWEN_MODEL and GPT.QWEN_IMAGE_CAPABILITY)
+            or (force_model == GPT.DOUBAO_MODEL and GPT.DOUBAO_IMAGE_CAPABILITY)
+            or (force_model == GPT.SUMMARY_MODEL and GPT.SUMMARY_IMAGE_CAPABILITY)
+            or (force_model == GPT.LONG_MODEL and GPT.LONG_IMAGE_CAPABILITY)
+        )
     ):
         client |= force_model_config
-        return client, force_model, force_model_name
-    return client, model, model_name
+        model = force_model
+        model_name = force_model_name
+
+    return {
+        "friendly_name": model_name,
+        "client": client,
+        "completions": {
+            "model": model,
+            "messages": contexts,
+            "temperature": float(GPT.TEMPERATURE),
+        },
+    }
src/llm/prompts.py
@@ -1,6 +1,7 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
 
+
 from loguru import logger
 
 from config import TZ
@@ -64,7 +65,6 @@ def add_search_results_to_prompts(search_results: list[dict], params: dict) -> d
     else:  # list, multi-modality
         contexts[-1]["content"].insert(0, {"type": "text", "text": prompt})
     params["messages"] = contexts
-    params["messages"] = refine_prompts(params["messages"])
     return params
 
 
src/llm/README.md
@@ -1,13 +0,0 @@
-# GPT Call Flow
-
-The main entry point is the `gpt_response` function in `llm/gpt.py`.
-
-When a message arrives, we first parse all of the messages it replies to into the history context, then decide from the message content which type of GPT to call (text or image).
-
-We currently use the OpenRouter gateway, with `deepseek-r1` as the primary model and `gpt-4o` as the fallback.
-
-Since `deepseek-r1` does not support `function call`, fetching the latest information via online search requires splitting our call flow into two stages.
-
-1. Stage one: send the original prompt, with the `function call` schema attached, to a model that supports `function call` (TOOL_MODEL). That model returns whether the `get_online_search_result` function should be called and with what `query`. We ignore the `content` it returns and only care about whether `get_online_search_result` is called and what the `query` is. TOOL_MODEL only needs to be fast and cheap.
-
-2. Stage two: based on the TOOL_MODEL result, fetch the online search results, then send the updated context together with the original prompt to the primary model `deepseek-r1` for the conversation.
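
The removed README described the two-stage tool-call flow that `merge_tools_response` in `src/llm/tools.py` still implements. A rough sketch of that flow, under the assumption that `ONLINE_SEARCH` is a single tool schema and with a stubbed-in `get_online_search_result` (the real helper is not shown in this diff):

```python
import json

from glom import glom

from llm.prompts import add_search_results_to_prompts
from llm.response import send_to_gpt
from llm.tool_scheme import ONLINE_SEARCH


async def get_online_search_result(query: str) -> list[dict]:
    """Stub for the project's search helper; the real implementation is not in this diff."""
    return [{"title": "example", "snippet": f"results for {query}"}]


async def two_stage_response(config: dict) -> dict:
    """Sketch of the two-stage flow from the removed README."""
    # Stage 1: send the prompt plus the function-call schema to a fast,
    # cheap TOOL_MODEL; only the tool call matters, not the returned content.
    tool_config = config | {"completions": config["completions"] | {"tools": [ONLINE_SEARCH]}}
    response = await send_to_gpt(tool_config, retry=0)
    tool_call = glom(response, "choices.0.message.tool_calls.0", default={})

    if glom(tool_call, "function.name", default="") == "get_online_search_result":
        # Stage 2: fetch search results and merge them into the prompts
        # before the primary model (e.g. deepseek-r1) answers.
        query = json.loads(tool_call["function"]["arguments"]).get("query", "")
        results = await get_online_search_result(query)
        config["completions"] = add_search_results_to_prompts(results, config["completions"])
    return await send_to_gpt(config, retry=0)
```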
src/llm/response.py
@@ -9,6 +9,7 @@ from openai import AsyncOpenAI
 from pyrogram.parser.markdown import BLOCKQUOTE_EXPANDABLE_DELIM, BLOCKQUOTE_EXPANDABLE_END_DELIM
 
 from config import GPT
+from llm.hooks import pre_hooks
 from llm.utils import add_search_results_to_response, beautify_llm_response, beautify_model_name, extract_reasoning
 from messages.progress import modify_progress
 
@@ -26,6 +27,7 @@ async def send_to_gpt(config: dict, retry: int = 0, **kwargs) -> dict[str, str]:
         {"content": str, "reasoning": str, "model": str}
     """
     try:
+        pre_hooks(config["client"], config["completions"], message_info=kwargs.get("message_info"))
         openai = AsyncOpenAI(**config["client"])
         logger.trace(config)
         resp = await openai.chat.completions.create(**config["completions"])
src/llm/response_stream.py
@@ -12,6 +12,7 @@ from pyrogram.parser.markdown import BLOCKQUOTE_EXPANDABLE_DELIM, BLOCKQUOTE_EXP
 from pyrogram.types import Message
 
 from config import GPT, TEXT_LENGTH
+from llm.hooks import pre_hooks
 from llm.utils import BOT_TIPS, add_search_results_to_response, beautify_llm_response
 from messages.progress import modify_progress
 from messages.utils import count_without_entities, smart_split
@@ -26,6 +27,7 @@ async def send_to_gpt_stream(client: Client, status: Message, config: dict, retr
     # ruff: noqa: RUF001, RUF003
     prefix = f"🤖**{config['friendly_name']}**: ({BOT_TIPS})\n"
     try:
+        pre_hooks(config["client"], config["completions"], message_info=kwargs.get("message_info"))
         openai = AsyncOpenAI(**config["client"])
         logger.trace(config)
         answers = prefix
src/llm/summary.py
@@ -10,7 +10,7 @@ from pyrogram.client import Client
 from pyrogram.types import Chat, Message
 
 from config import GPT, MAX_MESSAGE_SUMMARY, PREFIX, TID, TZ
-from llm.models import get_model_config_with_contexts
+from llm.models import get_gpt_config
 from llm.prompts import refine_prompts
 from llm.response import send_to_gpt
 from llm.utils import count_tokens
@@ -139,7 +139,7 @@ async def ai_summary(client: Client, message: Message, summary_prefix: str | Non
     msg += f"🔢有效消息: {len(parsed['user_context'])}\n"
     msg += f"🔠总Token: {total_tokens}"
     await modify_progress(text=msg, force_update=True, **kwargs)
-    config = get_model_config_with_contexts(model_type="text", contexts=contexts, force_model=summary_model, message_info=info)
+    config = get_gpt_config(model_type="text", contexts=contexts, force_model=summary_model)
 
     # set max_tokens for the model
     if "o1" in summary_model or "o3" in summary_model:  # o1 or newer models use `max_completion_tokens`
src/llm/tools.py
@@ -8,7 +8,6 @@ from loguru import logger
 from openai import AsyncOpenAI, DefaultAsyncHttpxClient
 
 from config import GPT, PROXY, TOKEN, TZ
-from llm.models import hooks
 from llm.prompts import add_search_results_to_prompts, modify_prompts
 from llm.response import send_to_gpt
 from llm.tool_scheme import ONLINE_SEARCH
@@ -144,14 +143,13 @@ async def merge_tools_response(config: dict, **kwargs) -> tuple[dict, dict]:
     }
     tool_completions = add_tools(tool_completions)
     tool_client = {k: v for k, v in config["client"].items() if k != "http_client"} | {"base_url": GPT.TOOLS_BASE_URL, "api_key": GPT.TOOLS_API_KEY}
-    hooks(tool_client, tool_completions)
     tools_config = {
         "friendly_name": config["friendly_name"],
         "client": tool_client,
         "completions": tool_completions,
     }
     try:
-        response = await send_to_gpt(tools_config, retry=0)
+        response = await send_to_gpt(tools_config, retry=0, **kwargs)
         tool_call = glom(response, "choices.0.message.tool_calls.0", default={})
         if not tool_call or glom(tool_call, "function.name", default="") != "get_online_search_result":
             return config, response