Commit 36e783e
Changed files (4)
src
src/llm/models.py
@@ -72,7 +72,14 @@ def get_model_config_with_contexts(model_type: str, contexts: list[dict], force_
client["api_key"] = GPT.DEEPSEEK_API_KEY
client["base_url"] = GPT.DEEPSEEK_BASE_URL
model_name = GPT.DEEPSEEK_MODEL_NAME
-
+ elif force_model == GPT.SUMMARY_MODEL:
+ client["api_key"] = GPT.SUMMARY_API_KEY
+ client["base_url"] = GPT.SUMMARY_BASE_URL
+ model_name = GPT.SUMMARY_MODEL_NAME
+ elif force_model == GPT.LONG_MODEL:
+ client["api_key"] = GPT.LONG_API_KEY
+ client["base_url"] = GPT.LONG_BASE_URL
+ model_name = GPT.LONG_MODEL_NAME
client = helicone_hook(client, message_info) # this line should be after setting `force_model`
# params for `openai.chat.completions.create()`
src/llm/summary.py
@@ -1,17 +1,16 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
-import json
import re
from loguru import logger
-from openai import DefaultAsyncHttpxClient
from pyrogram.client import Client
from pyrogram.types import Message
-from config import ENABLE, GPT, MAX_MESSAGE_SUMMARY, PREFIX, PROXY, cache
-from llm.models import openrouter_hook
+from config import ENABLE, GPT, MAX_MESSAGE_SUMMARY, PREFIX, cache
+from llm.models import get_model_config_with_contexts
from llm.prompts import refine_prompts
from llm.response import send_to_gpt
+from llm.utils import count_tokens
from messages.chat_history import get_parsed_chat_history
from messages.parser import parse_msg
from messages.progress import modify_progress
@@ -66,7 +65,6 @@ async def ai_summary(client: Client, message: Message, **kwargs):
filter_user = ""
else:
return
-
# reply a message with /summary
offset_id = info["mid"]
if message.reply_to_message:
@@ -79,16 +77,42 @@ async def ai_summary(client: Client, message: Message, **kwargs):
info["mid"] = int(matched.group(1))
offset_id = info["mid"] + 1 # include this message
+ if kwargs.get("show_progress") and "progress" not in kwargs:
+ res = await send2tg(client, message, texts=f"📝正在获取{num_history}条历史消息...", **kwargs)
+ kwargs["progress"] = res[0]
+
history = await get_parsed_chat_history(client, info["cid"], offset_id, num_history, filter_user)
if not history:
await send2tg(client, message, texts=f"最近{num_history}条消息中未找到符合条件的消息", **kwargs)
+ await modify_progress(del_status=True, **kwargs)
return
- contexts = await get_contexts(client, history)
- config = get_summay_model(contexts)
- msg = f"🤖{config['friendly_name']}: 总结中..."
- if kwargs.get("show_progress"):
- res = await send2tg(client, message, texts=msg, **kwargs)
- kwargs["progress"] = res[0]
+
+ # parse the history contexts
+ parsed = await get_contexts(client, history, **kwargs)
+ contexts = refine_prompts(parsed["system_context"] + [{"role": "user", "content": parsed["user_context"]}])
+ system_tokens = count_tokens(contexts[0]["content"])
+ user_tokens = count_tokens(contexts[-1]["content"])
+ total_tokens = system_tokens + user_tokens
+ if total_tokens < int(GPT.SUMMARY_MODEL_MAX_INPUT_LENGTH):
+ summary_model = GPT.SUMMARY_MODEL
+ summary_model_name = GPT.SUMMARY_MODEL_NAME
+ max_tokens = int(GPT.SUMMARY_MODEL_MAX_OUTPUT_LENGTH)
+ else:
+ summary_model = GPT.LONG_MODEL
+ summary_model_name = GPT.LONG_MODEL_NAME
+ max_tokens = int(GPT.LONG_MODEL_MAX_OUTPUT_LENGTH)
+ msg = f"🤖**{summary_model_name}**总结中...\n"
+ msg += f"🔢有效消息条数: {len(parsed['user_context'])}\n"
+ msg += f"🔠总Token数量: {total_tokens}"
+ await modify_progress(text=msg, force_update=True, **kwargs)
+ config = get_model_config_with_contexts(model_type="text", contexts=contexts, force_model=summary_model, message_info=info)
+
+ # set max_tokens for the model
+ if "o1" in summary_model or "o3" in summary_model: # o1 or newer models use `max_completion_tokens`
+ config["completions"]["max_completion_tokens"] = max_tokens
+ else:
+ config["completions"]["max_tokens"] = max_tokens
+
response = await send_to_gpt(config, **kwargs)
if texts := response.get("content"):
logger.debug(response)
@@ -96,47 +120,12 @@ async def ai_summary(client: Client, message: Message, **kwargs):
await modify_progress(del_status=True, **kwargs)
-def get_summay_model(contexts: list[dict]) -> dict:
- """Get the model for the summary."""
- models = {"text": GPT.TEXT_MODEL, "image": GPT.IMAGE_MODEL}
- model_names = {"text": GPT.TEXT_MODEL_NAME, "image": GPT.IMAGE_MODEL_NAME}
- apis = {"text": GPT.TEXT_API_KEY, "image": GPT.IMAGE_API_KEY}
- urls = {"text": GPT.TEXT_BASE_URL, "image": GPT.IMAGE_BASE_URL}
- # model_type = "image" if "photo" in {x["mtype"] for x in history} else "text"
- model_type = "text" # only text model for now
- model = models[model_type]
- config = {
- "model": model,
- "friendly_name": model_names[model_type],
- "timeout": round(float(GPT.TIMEOUT)),
- "base_url": urls[model_type],
- "key": apis[model_type],
- "temperature": float(GPT.TEMPERATURE),
- }
- completions = {"model": model, "messages": contexts, "temperature": float(GPT.TEMPERATURE)}
- completions |= openrouter_hook(base_url=urls[model_type])
-
- config = {
- "friendly_name": model_names[model_type],
- "client": {
- "api_key": apis[model_type],
- "base_url": urls[model_type],
- "timeout": round(float(GPT.TIMEOUT)),
- "http_client": DefaultAsyncHttpxClient(proxy=PROXY.GPT),
- },
- "completions": completions,
- }
-
- logger.trace(config)
- return config
-
-
-async def get_contexts(client: Client, history: list[dict]) -> list[dict]: # noqa: ARG001
+async def get_contexts(client: Client, history: list[dict], **kwargs) -> dict: # noqa: ARG001
"""Get GPT contexts based on parsed chat history.
Currently, we only summarize text contents.
"""
- contexts = [
+ system_context = [
{
"role": "system", # system prompt
"content": [
@@ -171,28 +160,37 @@ async def get_contexts(client: Client, history: list[dict]) -> list[dict]: # no
],
}
]
- user_contexts = []
+ user_context = []
for info in history:
if info["text"].startswith("/"): # commands
continue
- if info["mtype"] == "text" and info["text"]:
- content = {"user": info["full_name"], "time": f"{info['datetime']:%H:%M:%S}", "message": info["text"]}
- user_contexts.append({"type": "text", "text": json.dumps(content, ensure_ascii=False)})
- # continue
- # if info["mtype"] == "photo":
- # content = {"user": info["full_name"], "time": f"{info['datetime']:%H:%M:%S}", "message": "[image]如下"}
- # user_contexts.append({"type": "text", "text": json.dumps(content, ensure_ascii=False)})
- # res: BytesIO = await client.download_media(info["file_id"], in_memory=True) # type: ignore
- # ext = Path(res.name).suffix.removeprefix(".").replace("jpg", "jpeg")
- # b64 = base64.b64encode(res.getvalue()).decode("utf-8")
- # user_contexts.append({"type": "image_url", "image_url": {"url": f"data:image/{ext};base64,{b64}"}})
- # if info["text"]:
- # content = {"user": info["full_name"], "time": f"{info['datetime']:%H:%M:%S}", "message": info["text"]}
- # user_contexts.append({"type": "text", "text": json.dumps(content, ensure_ascii=False)})
- # else:
- # content = {"user": info["full_name"], "time": f"{info['datetime']:%H:%M:%S}", "message": f"[{info['mtype']}] {info['text']}".strip()}
- # user_contexts.append({"type": "text", "text": json.dumps(content, ensure_ascii=False)})
- contexts.append({"role": "user", "content": user_contexts})
- contexts = refine_prompts(contexts)
- logger.trace(contexts)
- return contexts
+
+ if info["text"].startswith("👤"): # social media
+ continue
+
+ if info["text"]: # currently, we only include texts
+ content = {
+ "message_id": info["mid"],
+ "time": f"{info['datetime']:%H:%M:%S}",
+ "username": info["full_name"],
+ "content": info["text"],
+ }
+ if (reply_to_message_id := info.get("reply_to_message_id")) and (reply_msg_content := get_message_by_id(reply_to_message_id, history)):
+ content["reply_to_message"] = reply_msg_content
+ user_context.append({"type": "text", "text": str(content)})
+
+ return {"system_context": system_context, "user_context": user_context}
+
+
+def get_message_by_id(message_id: int, history: list[dict]) -> dict:
+ """Get message by id."""
+ info = next((info for info in history if info["mid"] == message_id), {})
+ if not info:
+ return {}
+
+ return {
+ "message_id": info["mid"],
+ "time": f"{info['datetime']:%H:%M:%S}",
+ "username": info["full_name"],
+ "content": info["text"],
+ }
src/messages/chat_history.py
@@ -33,6 +33,8 @@ async def get_parsed_chat_history(
if msg.empty:
break
info = parse_msg(msg, silent=True)
+ if msg.reply_to_message_id:
+ info["reply_to_message_id"] = msg.reply_to_message_id
if not user:
history.append(info)
continue
src/config.py
@@ -19,7 +19,7 @@ CAPTION_LENGTH = int(os.getenv("CAPTION_LENGTH", "1024")) # 4096 for Premium us
MAX_FILE_BYTES = int(os.getenv("MAX_FILE_BYTES", "2000")) * 1024 * 1024 # 4000 MB for Premium user
ASR_MAX_DURATION = int(os.getenv("ASR_MAX_DURATION", "600"))
MAX_MESSAGE_COMBINATION = int(os.getenv("MAX_MESSAGE_COMBINATION", "5000")) # Maximum number of messages to combine
-MAX_MESSAGE_SUMMARY = int(os.getenv("MAX_MESSAGE_SUMMARY", "1000")) # Maximum number of messages to summay
+MAX_MESSAGE_SUMMARY = int(os.getenv("MAX_MESSAGE_SUMMARY", "5000")) # Maximum number of messages to summarize
READING_SPEED = int(os.getenv("READING_SPEED", "300")) # words per minute
DAILY_MESSAGES = os.getenv("DAILY_MESSAGES", "{}") # Useful for daily checkin for some services. Should be a json string: '{"chat-1": "msg-1", "chat-2": "msg-2"}'
# For ytdlp downloaded video, re-encoding to H264 format. This set the max file size for re-encoding. 0 means no limit
@@ -185,6 +185,20 @@ class GPT: # see `llm/README.md`
DEEPSEEK_MODEL_NAME = os.getenv("GPT_DEEPSEEK_MODEL_NAME", "DeepSeek-R1")
DEEPSEEK_API_KEY = os.getenv("GPT_DEEPSEEK_API_KEY", "")
DEEPSEEK_BASE_URL = os.getenv("GPT_DEEPSEEK_BASE_URL", "https://api.deepseek.com/v1")
+ # /summary command
+ SUMMARY_MODEL = os.getenv("GPT_SUMMARY_MODEL", "gpt-4o")
+ SUMMARY_MODEL_NAME = os.getenv("GPT_SUMMARY_MODEL_NAME", "GPT-4o")
+ SUMMARY_MODEL_MAX_INPUT_LENGTH = os.getenv("GPT_SUMMARY_MODEL_MAX_INPUT_LENGTH", "57344") # 56K
+ SUMMARY_MODEL_MAX_OUTPUT_LENGTH = os.getenv("GPT_SUMMARY_MODEL_MAX_OUTPUT_LENGTH", "8192") # 8K
+ SUMMARY_API_KEY = os.getenv("GPT_SUMMARY_API_KEY", "")
+ SUMMARY_BASE_URL = os.getenv("GPT_SUMMARY_BASE_URL", "https://api.openai.com/v1")
+ # long context model
+ LONG_MODEL = os.getenv("GPT_LONG_MODEL", "gemini-1.5-pro")
+ LONG_MODEL_NAME = os.getenv("GPT_LONG_MODEL_NAME", "Gemini-1.5-Pro")
+ LONG_MODEL_MAX_INPUT_LENGTH = os.getenv("GPT_LONG_MODEL_MAX_INPUT_LENGTH", "2097152") # 2M
+ LONG_MODEL_MAX_OUTPUT_LENGTH = os.getenv("GPT_LONG_MODEL_MAX_OUTPUT_LENGTH", "8192") # 8K
+ LONG_API_KEY = os.getenv("GPT_LONG_API_KEY", "")
+ LONG_BASE_URL = os.getenv("GPT_LONG_BASE_URL", "https://generativelanguage.googleapis.com/v1beta/openai")
class TID: # see more TID usecase in `src/permission.py`