Commit 0dfd493

benny-dou <60535774+benny-dou@users.noreply.github.com>
2025-02-09 02:42:23
refactor(gpt): refactor model config and contexts
1 parent 3f219dd
src/llm/contexts.py
@@ -48,6 +48,8 @@ async def get_conversation_contexts(client: Client, conversations: list[Message]
     contexts = [await single_context(client, message) for message in conversations]
     contexts = [x for x in contexts if x]  # filter out empty context
     contexts = combine_consecutive_role_contexts(contexts)
+    contexts = simplify_text_contents(contexts)
+
     return contexts[: int(GPT.HISTORY_CONTEXT)]
 
 
@@ -147,3 +149,44 @@ def combine_consecutive_role_contexts(contexts: list[dict]) -> list[dict]:
         else:
             combined_contexts.append(msg)
     return combined_contexts
+
+
+def simplify_text_contents(contexts: list[dict]) -> list[dict]:
+    """Simplify the plain text content format.
+
+    Some models do not support this format:
+        [{'text': 'hi', 'type': 'text'}], 'role': 'user'}]
+
+    It only supports:
+        [{'content': 'hi', 'role': 'user'}]
+
+    Args:
+        contexts (list[dict]): [
+            {
+                "role": "user or assistant",
+                "content": [
+                    {'type': 'text', 'text': 'caption this img'},
+                ]
+            }
+        ]
+
+    Returns:
+        list[dict]: [
+            {
+                "role": "user or assistant",
+                "content": "caption this img"
+            }
+        ]
+    """
+    fixed_contexts = []
+    for msg in contexts:
+        if not msg.get("content") or not isinstance(msg.get("content"), list):
+            fixed_contexts.append(msg)
+            continue
+        contents = msg.get("content", [])
+        if all(x.get("type") == "text" for x in contents):
+            msg["content"] = "\n".join([x.get("text") for x in contents])
+            fixed_contexts.append(msg)
+        else:
+            fixed_contexts.append(msg)
+    return fixed_contexts
src/llm/gpt.py
@@ -6,7 +6,7 @@ from pyrogram.types import Message
 
 from config import ENABLE, GPT, PREFIX, cache
 from llm.contexts import get_conversation_contexts, get_conversations
-from llm.models import get_model_type, get_model_with_contexts
+from llm.models import get_model_config_with_contexts, get_model_type
 from llm.response import get_gpt_response
 from llm.utils import llm_cleanup_files
 from messages.parser import parse_msg
@@ -71,14 +71,14 @@ async def gpt_response(client: Client, message: Message, **kwargs):
         await send2tg(client, message, texts=model_type, **kwargs)
         return
     contexts = await get_conversation_contexts(client, conversations)
-    model_conf, contexts = get_model_with_contexts(model_type, contexts)
-    msg = f"🤖{model_conf['friendly_name']}: 思考中..."
+    config = get_model_config_with_contexts(model_type, contexts)
+    msg = f"🤖{config['friendly_name']}: 思考中..."
     if kwargs.get("show_progress"):
         res = await send2tg(client, message, texts=msg, **kwargs)
         kwargs["progress"] = res[0]
-    response = await get_gpt_response(model_conf, contexts, **kwargs)
-    llm_cleanup_files(contexts)
-    texts = f"{model_conf['bot_msg_prefix']}\n\n{response}"
+    response = await get_gpt_response(config, **kwargs)
+    llm_cleanup_files(config["completions"]["messages"])
+    texts = f"{config['bot_msg_prefix']}\n\n{response}"
     logger.debug(texts)
     await send2tg(client, message, texts=texts, **kwargs)
     await modify_progress(del_status=True, **kwargs)
src/llm/models.py
@@ -1,11 +1,12 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
-
 from loguru import logger
+from openai import DefaultAsyncHttpxClient
 from pyrogram.types import Message
 
-from config import GPT
-from llm.utils import BOT_TIPS, simplify_text_contents
+from config import GPT, PROXY
+from llm.tool_scheme import get_tools
+from llm.utils import BOT_TIPS
 from messages.parser import parse_msg
 
 
@@ -27,7 +28,7 @@ def get_model_type(conversations: list[Message]) -> str:
     return model_type
 
 
-def get_model_with_contexts(model_type: str, contexts: list[dict]) -> tuple[dict, list[dict]]:
+def get_model_config_with_contexts(model_type: str, contexts: list[dict]) -> dict:
     """Get GPT model config based on contexts, and return the config and adjusted contexts.
 
     contexts:
@@ -48,23 +49,39 @@ def get_model_with_contexts(model_type: str, contexts: list[dict]) -> tuple[dict
     apis = {"text": GPT.TEXT_API_KEY, "image": GPT.IMAGE_API_KEY, "video": GPT.VIDEO_API_KEY}
     urls = {"text": GPT.TEXT_BASE_URL, "image": GPT.IMAGE_BASE_URL, "video": GPT.VIDEO_BASE_URL}
     model = models[model_type]
+
+    # setup configs
+    # params for OpenAI client
+    client = {
+        "api_key": apis[model_type],
+        "base_url": urls[model_type],
+        "timeout": round(float(timeouts[model_type])),
+        "http_client": DefaultAsyncHttpxClient(proxy=PROXY.GPT),
+    }
+    # get tool call params, and adjust contexts if nessary
+    tools, contexts = get_tools(contexts)
+
+    # params for `openai.chat.completions.create()`
+    completions = {"model": model, "messages": contexts, "temperature": float(GPT.TEMPERATURE)}
+    completions |= tools
+    completions |= openrouter_hook(client["base_url"])
+
     config = {
-        "model": model,
         "friendly_name": model_names[model_type],
-        "timeout": round(float(timeouts[model_type])),
-        "base_url": urls[model_type],
-        "key": apis[model_type],
-        "temperature": float(GPT.TEMPERATURE),
         "bot_msg_prefix": f"🤖**{model_names[model_type]}**: ({BOT_TIPS})",
+        "client": client,
+        "completions": completions,
     }
-    contexts = simplify_text_contents(contexts)
+
     logger.trace(config)
-    return config, contexts
+    return config
 
 
-def get_fallback_models(base_url: str) -> dict:
-    """Get fallback models for OpenRouter."""
-    models = []
-    if "openrouter.ai" in base_url:
-        models = [x.strip() for x in GPT.FALLBACK_MODELS.split(",") if x.strip()]
-    return {"extra_body": {"models": models}} if models else {}
+def openrouter_hook(base_url: str) -> dict:
+    """Add special parameters for OpenRouter."""
+    if "openrouter.ai" not in base_url:
+        return {}
+    params = {}
+    if models := [x.strip() for x in GPT.FALLBACK_MODELS.split(",") if x.strip()]:
+        params |= {"extra_body": {"models": models}}
+    return params
src/llm/response.py
@@ -4,32 +4,22 @@ import json
 
 from glom import glom
 from loguru import logger
-from openai import AsyncOpenAI, DefaultAsyncHttpxClient
+from openai import AsyncOpenAI
 
-from config import PROXY, TZ
-from llm.models import get_fallback_models
+from config import TZ
 from llm.tool_call import get_online_search_result
-from llm.tool_scheme import get_tools
 from llm.utils import change_system_prompt
 from messages.progress import modify_progress
 from utils import nowdt
 
 
-async def get_gpt_response(config: dict, contexts: list[dict], **kwargs) -> str:
+async def get_gpt_response(config: dict, **kwargs) -> str:
     """Get GPT response for text model."""
     try:
-        openai = AsyncOpenAI(
-            api_key=config["key"],
-            base_url=config["base_url"],
-            timeout=config["timeout"],
-            http_client=DefaultAsyncHttpxClient(proxy=PROXY.GPT),
-        )
-        params = {"model": config["model"], "messages": contexts, "temperature": config["temperature"]}
-        params |= get_tools()
-        params |= get_fallback_models(config["base_url"])
-        logger.trace(params)
-        resp = await openai.chat.completions.create(**params)
-        return await parse_tool_call(openai, config, resp.model_dump(), contexts, **kwargs)
+        openai = AsyncOpenAI(**config["client"])
+        logger.trace(config)
+        resp = await openai.chat.completions.create(**config["completions"])
+        return await parse_tool_call(openai, config, resp.model_dump(), **kwargs)
     except Exception as e:
         error = f"🤖{config['friendly_name']}请求失败, 请稍后重试.\n{e}"
         logger.error(f"GPT request failed: {e}")
@@ -37,16 +27,11 @@ async def get_gpt_response(config: dict, contexts: list[dict], **kwargs) -> str:
         return error
 
 
-async def parse_tool_call(
-    openai: AsyncOpenAI,
-    config: dict,
-    response: dict,
-    contexts: list[dict],
-    **kwargs,
-) -> str:
+async def parse_tool_call(openai: AsyncOpenAI, config: dict, response: dict, **kwargs) -> str:
     choice = glom(response, "choices.0", default=[])
     if not choice:
         return ""
+    logger.debug(response)
     try:
         if tool_call := glom(choice, "message.tool_calls.0", default={}):
             args = json.loads(glom(tool_call, "function.arguments", default="{}"))
@@ -56,15 +41,16 @@ async def parse_tool_call(
                 await modify_progress(text=f"正在联网搜索信息:\n{args.get('query', '')}", force_update=True, **kwargs)
                 tool_result = await get_online_search_result(**args)
                 contexts = change_system_prompt(
-                    contexts,
+                    config["completions"]["messages"],
                     f"于{nowdt(TZ):%Y-%m-%d}日进行了一次网络搜索, 请参考以下网络搜索结果进行回答. 如果结果中包含原始链接, 请以源域名附带上Markdown格式的链接. 例如: [nytimes.com](https://nytimes.com/some-page)\n网络搜索结果如下:\n{json.dumps(tool_result, ensure_ascii=False)}",
                 )
                 resp = await openai.chat.completions.create(
-                    model=config["model"],
+                    model=config["completions"]["model"],
                     messages=contexts,  # type: ignore
-                    temperature=config["temperature"],
+                    temperature=config["completions"]["temperature"],
                 )
                 response = resp.model_dump()
+                logger.debug(response)
         return glom(response, "choices.0.message.content", default="")
     except Exception as e:
         logger.error(f"GPT failed: {e}")
src/llm/summary.py
@@ -4,12 +4,14 @@ import json
 import re
 
 from loguru import logger
+from openai import DefaultAsyncHttpxClient
 from pyrogram.client import Client
 from pyrogram.types import Message
 
-from config import ENABLE, GPT, MAX_MESSAGE_SUMMARY, PREFIX, cache
+from config import ENABLE, GPT, MAX_MESSAGE_SUMMARY, PREFIX, PROXY, cache
+from llm.contexts import combine_consecutive_role_contexts, simplify_text_contents
+from llm.models import openrouter_hook
 from llm.response import get_gpt_response
-from llm.utils import simplify_text_contents
 from messages.chat_history import get_parsed_chat_history
 from messages.parser import parse_msg
 from messages.progress import modify_progress
@@ -81,22 +83,19 @@ async def ai_summary(client: Client, message: Message, **kwargs):
     if not history:
         await send2tg(client, message, texts=f"最近{num_history}条消息中未找到符合条件的消息", **kwargs)
         return
-
-    model_conf = get_summay_model(history)
     contexts = await get_contexts(client, history)
-    if model_conf["friendly_name"].startswith("豆包"):
-        contexts = simplify_text_contents(contexts)
-    msg = f"🤖{model_conf['friendly_name']}: 总结中..."
+    config = get_summay_model(contexts)
+    msg = f"🤖{config['friendly_name']}: 总结中..."
     if kwargs.get("show_progress"):
         res = await send2tg(client, message, texts=msg, **kwargs)
         kwargs["progress"] = res[0]
-    response = await get_gpt_response(model_conf, contexts, **kwargs)
+    response = await get_gpt_response(config, **kwargs)
     logger.debug(response)
     await send2tg(client, message, texts=response.strip("`"), **kwargs)
     await modify_progress(del_status=True, **kwargs)
 
 
-def get_summay_model(history: list[dict]) -> dict:  # noqa: ARG001
+def get_summay_model(contexts: list[dict]) -> dict:
     """Get the model for the summary."""
     models = {"text": GPT.TEXT_MODEL, "image": GPT.IMAGE_MODEL}
     model_names = {"text": GPT.TEXT_MODEL_NAME, "image": GPT.IMAGE_MODEL_NAME}
@@ -114,6 +113,20 @@ def get_summay_model(history: list[dict]) -> dict:  # noqa: ARG001
         "key": apis[model_type],
         "temperature": float(GPT.TEMPERATURE),
     }
+    completions = {"model": model, "messages": contexts, "temperature": float(GPT.TEMPERATURE)}
+    completions |= openrouter_hook(base_url=urls[model_type])
+
+    config = {
+        "friendly_name": model_names[model_type],
+        "client": {
+            "api_key": apis[model_type],
+            "base_url": urls[model_type],
+            "timeout": round(float(timeouts[model_type])),
+            "http_client": DefaultAsyncHttpxClient(proxy=PROXY.GPT),
+        },
+        "completions": completions,
+    }
+
     logger.trace(config)
     return config
 
@@ -180,5 +193,7 @@ async def get_contexts(client: Client, history: list[dict]) -> list[dict]:  # no
         #     content = {"user": info["full_name"], "time": f"{info['datetime']:%H:%M:%S}", "message": f"[{info['mtype']}] {info['text']}".strip()}
         #     user_contexts.append({"type": "text", "text": json.dumps(content, ensure_ascii=False)})
     contexts.append({"role": "user", "content": user_contexts})
+    contexts = combine_consecutive_role_contexts(contexts)
+    contexts = simplify_text_contents(contexts)
     logger.trace(contexts)
     return contexts
src/llm/tool_scheme.py
@@ -1,13 +1,15 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
 
-from config import ENABLE
+from config import ENABLE, TZ
+from llm.utils import change_system_prompt
+from utils import nowdt
 
 ONLINE_SEARCH = {
     "type": "function",
     "function": {
         "name": "get_online_search_result",
-        "description": "",
+        "description": "获取联网搜索结果",
         "parameters": {
             "type": "object",
             "properties": {
@@ -19,10 +21,12 @@ ONLINE_SEARCH = {
 }
 
 
-def get_tools() -> dict:
+def get_tools(contexts: list[dict]) -> tuple[dict, list[dict]]:
     """Get tools for GPT.
 
-    Returns: {
+    Returns: (tools_params, contexts)
+        tools_params:
+            {
                 "tools": [{tool_1}, {tool_2}, ...],  # list of dict
                 "tool_choice": "auto",
             }
@@ -30,4 +34,6 @@ def get_tools() -> dict:
     tools = []
     if ENABLE.GPT_ONLINE_SEARCH:
         tools = [ONLINE_SEARCH]
-    return {"tools": tools, "tool_choice": "auto"} if tools else {}
+        contexts = change_system_prompt(contexts, f"你是一个具备网络访问能力的智能助手. 在需要时可以访问互联网进行相关搜索获取信息以确保用户得到最新、准确的帮助。当前日期是 {nowdt(TZ):%Y-%m-%d}")
+    tools_params = {"tools": tools, "tool_choice": "auto"} if tools else {}
+    return tools_params, contexts
src/llm/utils.py
@@ -10,47 +10,6 @@ from config import DOWNLOAD_DIR, GPT
 BOT_TIPS = "回复此消息以继续对话"
 
 
-def simplify_text_contents(contexts: list[dict]) -> list[dict]:
-    """Simplify the plain text content format.
-
-    Some models do not support this format:
-        [{'text': 'hi', 'type': 'text'}], 'role': 'user'}]
-
-    It only supports:
-        [{'content': 'hi', 'role': 'user'}]
-
-    Args:
-        contexts (list[dict]): [
-            {
-                "role": "user or assistant",
-                "content": [
-                    {'type': 'text', 'text': 'caption this img'},
-                ]
-            }
-        ]
-
-    Returns:
-        list[dict]: [
-            {
-                "role": "user or assistant",
-                "content": "caption this img"
-            }
-        ]
-    """
-    fixed_contexts = []
-    for msg in contexts:
-        if not msg.get("content") or not isinstance(msg.get("content"), list):
-            fixed_contexts.append(msg)
-            continue
-        contents = msg.get("content", [])
-        if all(x.get("type") == "text" for x in contents):
-            msg["content"] = "\n".join([x.get("text") for x in contents])
-            fixed_contexts.append(msg)
-        else:
-            fixed_contexts.append(msg)
-    return fixed_contexts
-
-
 def llm_cleanup_files(messages: list[dict]):
     """Clean downloaded files.