Commit d5c7a89
Changed files (10)
src/llm/contexts.py
@@ -11,7 +11,6 @@ from pyrogram.client import Client
from pyrogram.types import Message
from config import GPT, PREFIX
-from llm.prompts import refine_prompts
from llm.utils import BOT_TIPS
from messages.parser import parse_msg
@@ -48,8 +47,7 @@ async def get_conversation_contexts(client: Client, conversations: list[Message]
"""
# parse context for each message
contexts = [await single_context(client, message) for message in conversations]
- contexts = [x for x in contexts if x] # filter out empty context
- contexts = refine_prompts(contexts)
+ contexts = [x for x in contexts if x.get("content")] # filter out empty context
return contexts[: int(GPT.HISTORY_CONTEXT)]
src/llm/gpt.py
@@ -7,7 +7,7 @@ from pyrogram.types import Message
from config import GPT, PREFIX, TEXT_LENGTH, cache
from llm.contexts import get_conversation_contexts, get_conversations
-from llm.models import get_context_type, get_model_config_with_contexts
+from llm.models import get_context_type, get_gpt_config
from llm.response import send_to_gpt
from llm.response_stream import send_to_gpt_stream
from llm.tools import merge_tools_response
@@ -44,10 +44,9 @@ def is_gpt_conversation(message: Message) -> bool:
# is replying to gpt-bot response message?
if not message.reply_to_message:
return False
-
reply_msg = message.reply_to_message
reply_info = parse_msg(reply_msg, silent=True)
- model_names = [GPT.OPENAI_MODEL_NAME, GPT.GEMINI_MODEL_NAME, GPT.DEEPSEEK_MODEL_NAME, GPT.QWEN_MODEL_NAME, GPT.DOUBAO_MODEL_NAME]
+ model_names = [GPT.OPENAI_MODEL_NAME, GPT.GEMINI_MODEL_NAME, GPT.DEEPSEEK_MODEL_NAME, GPT.QWEN_MODEL_NAME, GPT.DOUBAO_MODEL_NAME, GPT.TEXT_MODEL_NAME, GPT.IMAGE_MODEL_NAME]
return startswith_prefix(reply_info["text"], prefix=[f"🤖{x}".lower() for x in model_names])
@@ -65,12 +64,11 @@ async def gpt_response(client: Client, message: Message, *, gpt_stream: bool = G
if equal_prefix(info["text"], prefix=[PREFIX.GPT, "/gpt", "/gemini", "/ds", "/qwen", "/doubao"]) and not message.reply_to_message:
await send2tg(client, message, texts=HELP, **kwargs)
return
-
if not is_gpt_conversation(message):
return
# /gpt = OpenAI, /gemini = Gemini, /ds = DeepSeek, /qwen = Qwen, /doubao = Doubao
- force_model = "N/A"
+ force_model = "NOT_SET"
reply_text = ""
if message.reply_to_message:
reply_info = parse_msg(message.reply_to_message, silent=True)
@@ -101,10 +99,11 @@ async def gpt_response(client: Client, message: Message, *, gpt_stream: bool = G
if cache.get(f"gpt-{info['cid']}-{media_group_id}"):
return
cache.set(f"gpt-{info['cid']}-{media_group_id}", "1", ttl=120)
+ kwargs["message_info"] = info # save trigger message info
conversations = get_conversations(message)
context_type = get_context_type(conversations)
contexts = await get_conversation_contexts(client, conversations)
- config = get_model_config_with_contexts(context_type["type"], contexts, force_model, info)
+ config = get_gpt_config(context_type["type"], contexts, force_model)
msg = f"🤖**{config['friendly_name']}**: 思考中..."
status_msg = (await send2tg(client, message, texts=msg, **kwargs))[0]
kwargs["progress"] = status_msg
src/llm/hooks.py
@@ -0,0 +1,35 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+from config import GPT
+from llm.prompts import refine_prompts
+from utils import unicode_to_ascii
+
+
+def pre_hooks(client: dict, completions: dict, message_info: dict | None = None):
+ pre_openrouter_hook(client, completions)
+ pre_helicone_hook(client, message_info)
+ completions["messages"] = refine_prompts(completions["messages"])
+
+
+def pre_openrouter_hook(client: dict, completions: dict) -> None:
+ """Add special parameters for OpenRouter."""
+ if "openrouter" not in client["base_url"]:
+ return
+ if models := [x.strip() for x in GPT.FALLBACK_MODELS.split(",") if x.strip()]:
+ completions["extra_body"] = {"models": models}
+
+
+def pre_helicone_hook(client: dict, message_info: dict | None) -> None:
+ """Add special parameters for helicone gateway."""
+ if not GPT.HELICONE_API_KEY:
+ return
+ headers = client.get("default_headers", {})
+ headers |= {
+ "Helicone-Auth": f"Bearer {GPT.HELICONE_API_KEY}",
+ }
+ message_info = message_info or {}
+ if chat_title := message_info.get("ctitle"):
+ headers |= {"Helicone-Property-Chat": unicode_to_ascii(chat_title), "Helicone-Property-ChatID": str(message_info["cid"])}
+ if user_name := message_info.get("full_name"):
+ headers |= {"Helicone-User-Id": unicode_to_ascii(user_name), "Helicone-Property-User": str(message_info["uid"])}
+ client |= {"default_headers": headers}
src/llm/models.py
@@ -5,9 +5,7 @@ from openai import DefaultAsyncHttpxClient
from pyrogram.types import Message
from config import GPT, PREFIX, PROXY
-from llm.prompts import refine_prompts
from messages.parser import parse_msg
-from utils import unicode_to_ascii
def get_context_type(conversations: list[Message]) -> dict:
@@ -28,8 +26,8 @@ def get_context_type(conversations: list[Message]) -> dict:
return res
-def get_model_config_with_contexts(model_type: str, contexts: list[dict], force_model: str = "N/A", message_info: dict | None = None) -> dict:
- """Get GPT model config based on contexts, and return the config and adjusted contexts.
+def get_gpt_config(model_type: str, contexts: list[dict], force_model: str = "NOT_SET") -> dict:
+ """Get GPT configurations.
contexts:
[
@@ -42,56 +40,6 @@ def get_model_config_with_contexts(model_type: str, contexts: list[dict], force_
}
]
"""
- client, model, model_name = align_with_force_model(model_type, force_model)
-
- # params for `openai.chat.completions.create()`
- completions = {"model": model, "messages": contexts, "temperature": float(GPT.TEMPERATURE)}
- hooks(client, completions, message_info) # this line should be after setting `force_model``
- completions["messages"] = refine_prompts(completions["messages"]) # final refine after hooks
-
- if force_model != "N/A" and completions.get("extra_body"): # remove models fallback
- completions["extra_body"].pop("models", None) # should be after hooks
- return {
- "friendly_name": model_name,
- "client": client,
- "completions": completions,
- }
-
-
-def hooks(client: dict, completions: dict, message_info: dict | None = None):
- openrouter_hook(client, completions)
- helicone_hook(client, message_info)
-
-
-def openrouter_hook(client: dict, completions: dict) -> None:
- """Add special parameters for OpenRouter."""
- if "openrouter" not in client["base_url"]:
- return
- if models := [x.strip() for x in GPT.FALLBACK_MODELS.split(",") if x.strip()]:
- completions["extra_body"] = {"models": models}
-
-
-def helicone_hook(client: dict, message_info: dict | None) -> None:
- """Add special parameters for helicone gateway."""
- if not GPT.HELICONE_API_KEY:
- return
- headers = client.get("default_headers", {})
- headers |= {
- "Helicone-Auth": f"Bearer {GPT.HELICONE_API_KEY}",
- }
- message_info = message_info or {}
- if chat_title := message_info.get("ctitle"):
- headers |= {"Helicone-Property-Chat": unicode_to_ascii(chat_title), "Helicone-Property-ChatID": str(message_info["cid"])}
- if user_name := message_info.get("full_name"):
- headers |= {"Helicone-User-Id": unicode_to_ascii(user_name), "Helicone-Property-User": str(message_info["uid"])}
- client |= {"default_headers": headers}
-
-
-def align_with_force_model(model_type: str, force_model: str = "N/A") -> tuple[dict, str, str]:
- """Align the model with the modalities if force_model is specified.
-
- For example, user use `/ds` to reply an image, but the model only support text, so we need to use switch to image model.
- """
models = {"text": GPT.TEXT_MODEL, "image": GPT.IMAGE_MODEL, "video": GPT.VIDEO_MODEL}
model_names = {"text": GPT.TEXT_MODEL_NAME, "image": GPT.IMAGE_MODEL_NAME, "video": GPT.VIDEO_MODEL_NAME}
apis = {"text": GPT.TEXT_API_KEY, "image": GPT.IMAGE_API_KEY, "video": GPT.VIDEO_API_KEY}
@@ -99,16 +47,17 @@ def align_with_force_model(model_type: str, force_model: str = "N/A") -> tuple[d
model = models[model_type]
model_name = model_names[model_type]
- if force_model == "N/A":
- force_model = model
+ force_model = model if force_model == "NOT_SET" else force_model
+
# params for OpenAI client
- client = {
+ client = { # this config is based on model type (text, image, or video)
"api_key": apis[model_type],
"base_url": urls[model_type],
"timeout": round(float(GPT.TIMEOUT)),
"http_client": DefaultAsyncHttpxClient(proxy=PROXY.GPT),
}
+ # align with force model
model_factory = {
GPT.OPENAI_MODEL: {"api_key": GPT.OPENAI_API_KEY, "base_url": GPT.OPENAI_BASE_URL, "model_name": GPT.OPENAI_MODEL_NAME},
GPT.GEMINI_MODEL: {"api_key": GPT.GEMINI_API_KEY, "base_url": GPT.GEMINI_BASE_URL, "model_name": GPT.GEMINI_MODEL_NAME},
@@ -122,19 +71,29 @@ def align_with_force_model(model_type: str, force_model: str = "N/A") -> tuple[d
force_model_name = force_model_config.get("model_name", model_name)
force_model_config.pop("model_name", None)
- if model_type == "text": # respect the force model
- client |= force_model_config
- return client, force_model, force_model_name
-
- if model_type == "image" and ( # check capabilities
- (force_model == GPT.OPENAI_MODEL and GPT.OPENAI_IMAGE_CAPABILITY)
- or (force_model == GPT.GEMINI_MODEL and GPT.GEMINI_IMAGE_CAPABILITY)
- or (force_model == GPT.DEEPSEEK_MODEL and GPT.DEEPSEEK_IMAGE_CAPABILITY)
- or (force_model == GPT.QWEN_MODEL and GPT.QWEN_IMAGE_CAPABILITY)
- or (force_model == GPT.DOUBAO_MODEL and GPT.DOUBAO_IMAGE_CAPABILITY)
- or (force_model == GPT.SUMMARY_MODEL and GPT.SUMMARY_IMAGE_CAPABILITY)
- or (force_model == GPT.LONG_MODEL and GPT.LONG_IMAGE_CAPABILITY)
+ # merge force model config
+ if model_type == "text" or (
+ model_type == "image" # check capabilities
+ and (
+ (force_model == GPT.OPENAI_MODEL and GPT.OPENAI_IMAGE_CAPABILITY)
+ or (force_model == GPT.GEMINI_MODEL and GPT.GEMINI_IMAGE_CAPABILITY)
+ or (force_model == GPT.DEEPSEEK_MODEL and GPT.DEEPSEEK_IMAGE_CAPABILITY)
+ or (force_model == GPT.QWEN_MODEL and GPT.QWEN_IMAGE_CAPABILITY)
+ or (force_model == GPT.DOUBAO_MODEL and GPT.DOUBAO_IMAGE_CAPABILITY)
+ or (force_model == GPT.SUMMARY_MODEL and GPT.SUMMARY_IMAGE_CAPABILITY)
+ or (force_model == GPT.LONG_MODEL and GPT.LONG_IMAGE_CAPABILITY)
+ )
):
client |= force_model_config
- return client, force_model, force_model_name
- return client, model, model_name
+ model = force_model
+ model_name = force_model_name
+
+ return {
+ "friendly_name": model_name,
+ "client": client,
+ "completions": {
+ "model": model,
+ "messages": contexts,
+ "temperature": float(GPT.TEMPERATURE),
+ },
+ }
src/llm/prompts.py
@@ -1,6 +1,7 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
+
from loguru import logger
from config import TZ
@@ -64,7 +65,6 @@ def add_search_results_to_prompts(search_results: list[dict], params: dict) -> d
else: # list, multi-modality
contexts[-1]["content"].insert(0, {"type": "text", "text": prompt})
params["messages"] = contexts
- params["messages"] = refine_prompts(params["messages"])
return params
src/llm/README.md
@@ -1,13 +0,0 @@
-# GPT调用流程
-
-程序主入口为 `llm/gpt.py` 的 `gpt_response` 函数。
-
-接到消息后, 首先解析出本消息的所有回复消息组成历史上下文,然后根据消息内容判断判断调用哪种类型的GPT。(文本 or 图片)
-
-目前我们使用OpenRouter接口站, 主model为 `deepseek-r1`, 备用model为 `gpt-4o`。
-
-由于`deepseek-r1` 不支持 `function call` 功能,为了联网搜索最新消息,所有我们的调用流程分为两个阶段。
-
-1. 第一阶段, 将附带`function call`的原始prompt发送给一个支持`function call`的模型 (TOOL_MODEL), 此模型会返回是否需要调用`get_online_search_result`函数以及`query`内容。我们并不关心此模型返回的`content`, 只关心是否调用`get_online_search_result`函数以及`query`内容。TOOL_MODEL模型只需要速度快且价格便宜。
-
-2. 第二阶段, 根据TOOL_MODEL的结果, 获取联网搜索结果, 将更新后的上下文和原始prompt发送给主模型 `deepseek-r1` 进行对话。
src/llm/response.py
@@ -9,6 +9,7 @@ from openai import AsyncOpenAI
from pyrogram.parser.markdown import BLOCKQUOTE_EXPANDABLE_DELIM, BLOCKQUOTE_EXPANDABLE_END_DELIM
from config import GPT
+from llm.hooks import pre_hooks
from llm.utils import add_search_results_to_response, beautify_llm_response, beautify_model_name, extract_reasoning
from messages.progress import modify_progress
@@ -26,6 +27,7 @@ async def send_to_gpt(config: dict, retry: int = 0, **kwargs) -> dict[str, str]:
{"content": str, "reasoning": str, "model": str}
"""
try:
+ pre_hooks(config["client"], config["completions"], message_info=kwargs.get("message_info"))
openai = AsyncOpenAI(**config["client"])
logger.trace(config)
resp = await openai.chat.completions.create(**config["completions"])
src/llm/response_stream.py
@@ -12,6 +12,7 @@ from pyrogram.parser.markdown import BLOCKQUOTE_EXPANDABLE_DELIM, BLOCKQUOTE_EXP
from pyrogram.types import Message
from config import GPT, TEXT_LENGTH
+from llm.hooks import pre_hooks
from llm.utils import BOT_TIPS, add_search_results_to_response, beautify_llm_response
from messages.progress import modify_progress
from messages.utils import count_without_entities, smart_split
@@ -26,6 +27,7 @@ async def send_to_gpt_stream(client: Client, status: Message, config: dict, retr
# ruff: noqa: RUF001, RUF003
prefix = f"🤖**{config['friendly_name']}**: ({BOT_TIPS})\n"
try:
+ pre_hooks(config["client"], config["completions"], message_info=kwargs.get("message_info"))
openai = AsyncOpenAI(**config["client"])
logger.trace(config)
answers = prefix
src/llm/summary.py
@@ -10,7 +10,7 @@ from pyrogram.client import Client
from pyrogram.types import Chat, Message
from config import GPT, MAX_MESSAGE_SUMMARY, PREFIX, TID, TZ
-from llm.models import get_model_config_with_contexts
+from llm.models import get_gpt_config
from llm.prompts import refine_prompts
from llm.response import send_to_gpt
from llm.utils import count_tokens
@@ -139,7 +139,7 @@ async def ai_summary(client: Client, message: Message, summary_prefix: str | Non
msg += f"🔢有效消息: {len(parsed['user_context'])}\n"
msg += f"🔠总Token: {total_tokens}"
await modify_progress(text=msg, force_update=True, **kwargs)
- config = get_model_config_with_contexts(model_type="text", contexts=contexts, force_model=summary_model, message_info=info)
+ config = get_gpt_config(model_type="text", contexts=contexts, force_model=summary_model)
# set max_tokens for the model
if "o1" in summary_model or "o3" in summary_model: # o1 or newer models use `max_completion_tokens`
src/llm/tools.py
@@ -8,7 +8,6 @@ from loguru import logger
from openai import AsyncOpenAI, DefaultAsyncHttpxClient
from config import GPT, PROXY, TOKEN, TZ
-from llm.models import hooks
from llm.prompts import add_search_results_to_prompts, modify_prompts
from llm.response import send_to_gpt
from llm.tool_scheme import ONLINE_SEARCH
@@ -144,14 +143,13 @@ async def merge_tools_response(config: dict, **kwargs) -> tuple[dict, dict]:
}
tool_completions = add_tools(tool_completions)
tool_client = {k: v for k, v in config["client"].items() if k != "http_client"} | {"base_url": GPT.TOOLS_BASE_URL, "api_key": GPT.TOOLS_API_KEY}
- hooks(tool_client, tool_completions)
tools_config = {
"friendly_name": config["friendly_name"],
"client": tool_client,
"completions": tool_completions,
}
try:
- response = await send_to_gpt(tools_config, retry=0)
+ response = await send_to_gpt(tools_config, retry=0, **kwargs)
tool_call = glom(response, "choices.0.message.tool_calls.0", default={})
if not tool_call or glom(tool_call, "function.name", default="") != "get_online_search_result":
return config, response