Commit 36e783e

benny-dou <60535774+benny-dou@users.noreply.github.com>
2025-02-20 16:57:09
refactor(summary): include the reply message in the summary context
add support for specifying the model for the `/summary` command
1 parent 1b4378d
Changed files (4)
src/llm/models.py
@@ -72,7 +72,14 @@ def get_model_config_with_contexts(model_type: str, contexts: list[dict], force_
         client["api_key"] = GPT.DEEPSEEK_API_KEY
         client["base_url"] = GPT.DEEPSEEK_BASE_URL
         model_name = GPT.DEEPSEEK_MODEL_NAME
-
+    elif force_model == GPT.SUMMARY_MODEL:
+        client["api_key"] = GPT.SUMMARY_API_KEY
+        client["base_url"] = GPT.SUMMARY_BASE_URL
+        model_name = GPT.SUMMARY_MODEL_NAME
+    elif force_model == GPT.LONG_MODEL:
+        client["api_key"] = GPT.LONG_API_KEY
+        client["base_url"] = GPT.LONG_BASE_URL
+        model_name = GPT.LONG_MODEL_NAME
     client = helicone_hook(client, message_info)  # this line should be after setting `force_model``
 
     # params for `openai.chat.completions.create()`
src/llm/summary.py
@@ -1,17 +1,16 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
-import json
 import re
 
 from loguru import logger
-from openai import DefaultAsyncHttpxClient
 from pyrogram.client import Client
 from pyrogram.types import Message
 
-from config import ENABLE, GPT, MAX_MESSAGE_SUMMARY, PREFIX, PROXY, cache
-from llm.models import openrouter_hook
+from config import ENABLE, GPT, MAX_MESSAGE_SUMMARY, PREFIX, cache
+from llm.models import get_model_config_with_contexts
 from llm.prompts import refine_prompts
 from llm.response import send_to_gpt
+from llm.utils import count_tokens
 from messages.chat_history import get_parsed_chat_history
 from messages.parser import parse_msg
 from messages.progress import modify_progress
@@ -66,7 +65,6 @@ async def ai_summary(client: Client, message: Message, **kwargs):
         filter_user = ""
     else:
         return
-
     # reply a message with /summary
     offset_id = info["mid"]
     if message.reply_to_message:
@@ -79,16 +77,42 @@ async def ai_summary(client: Client, message: Message, **kwargs):
         info["mid"] = int(matched.group(1))
         offset_id = info["mid"] + 1  # include this message
 
+    if kwargs.get("show_progress") and "progress" not in kwargs:
+        res = await send2tg(client, message, texts=f"📝正在获取{num_history}条历史消息...", **kwargs)
+        kwargs["progress"] = res[0]
+
     history = await get_parsed_chat_history(client, info["cid"], offset_id, num_history, filter_user)
     if not history:
         await send2tg(client, message, texts=f"最近{num_history}条消息中未找到符合条件的消息", **kwargs)
+        await modify_progress(del_status=True, **kwargs)
         return
-    contexts = await get_contexts(client, history)
-    config = get_summay_model(contexts)
-    msg = f"🤖{config['friendly_name']}: 总结中..."
-    if kwargs.get("show_progress"):
-        res = await send2tg(client, message, texts=msg, **kwargs)
-        kwargs["progress"] = res[0]
+
+    # parse the history contexts
+    parsed = await get_contexts(client, history, **kwargs)
+    contexts = refine_prompts(parsed["system_context"] + [{"role": "user", "content": parsed["user_context"]}])
+    sysmtem_tokens = count_tokens(contexts[0]["content"])
+    user_tokens = count_tokens(contexts[-1]["content"])
+    total_tokens = sysmtem_tokens + user_tokens
+    if total_tokens < int(GPT.SUMMARY_MODEL_MAX_INPUT_LENGTH):
+        summary_model = GPT.SUMMARY_MODEL
+        summary_model_name = GPT.SUMMARY_MODEL_NAME
+        max_tokens = int(GPT.SUMMARY_MODEL_MAX_OUTPUT_LENGTH)
+    else:
+        summary_model = GPT.LONG_MODEL
+        summary_model_name = GPT.LONG_MODEL_NAME
+        max_tokens = int(GPT.LONG_MODEL_MAX_OUTPUT_LENGTH)
+    msg = f"🤖**{summary_model_name}**总结中...\n"
+    msg += f"🔢有效消息条数: {len(parsed['user_context'])}\n"
+    msg += f"🔠总Token数量: {total_tokens}"
+    await modify_progress(text=msg, force_update=True, **kwargs)
+    config = get_model_config_with_contexts(model_type="text", contexts=contexts, force_model=summary_model, message_info=info)
+
+    # set max_tokens for the model
+    if "o1" in summary_model or "o3" in summary_model:  # o1 or newer models use `max_completion_tokens`
+        config["completions"]["max_completion_tokens"] = max_tokens
+    else:
+        config["completions"]["max_tokens"] = max_tokens
+
     response = await send_to_gpt(config, **kwargs)
     if texts := response.get("content"):
         logger.debug(response)
@@ -96,47 +120,12 @@ async def ai_summary(client: Client, message: Message, **kwargs):
         await modify_progress(del_status=True, **kwargs)
 
 
-def get_summay_model(contexts: list[dict]) -> dict:
-    """Get the model for the summary."""
-    models = {"text": GPT.TEXT_MODEL, "image": GPT.IMAGE_MODEL}
-    model_names = {"text": GPT.TEXT_MODEL_NAME, "image": GPT.IMAGE_MODEL_NAME}
-    apis = {"text": GPT.TEXT_API_KEY, "image": GPT.IMAGE_API_KEY}
-    urls = {"text": GPT.TEXT_BASE_URL, "image": GPT.IMAGE_BASE_URL}
-    # model_type = "image" if "photo" in {x["mtype"] for x in history} else "text"
-    model_type = "text"  # only text model for now
-    model = models[model_type]
-    config = {
-        "model": model,
-        "friendly_name": model_names[model_type],
-        "timeout": round(float(GPT.TIMEOUT)),
-        "base_url": urls[model_type],
-        "key": apis[model_type],
-        "temperature": float(GPT.TEMPERATURE),
-    }
-    completions = {"model": model, "messages": contexts, "temperature": float(GPT.TEMPERATURE)}
-    completions |= openrouter_hook(base_url=urls[model_type])
-
-    config = {
-        "friendly_name": model_names[model_type],
-        "client": {
-            "api_key": apis[model_type],
-            "base_url": urls[model_type],
-            "timeout": round(float(GPT.TIMEOUT)),
-            "http_client": DefaultAsyncHttpxClient(proxy=PROXY.GPT),
-        },
-        "completions": completions,
-    }
-
-    logger.trace(config)
-    return config
-
-
-async def get_contexts(client: Client, history: list[dict]) -> list[dict]:  # noqa: ARG001
+async def get_contexts(client: Client, history: list[dict], **kwargs) -> dict:  # noqa: ARG001
     """Get GPT contexts based on parsed chat history.
 
     Currently, we only summarize text contents.
     """
-    contexts = [
+    system_context = [
         {
             "role": "system",  # system prompt
             "content": [
@@ -171,28 +160,37 @@ async def get_contexts(client: Client, history: list[dict]) -> list[dict]:  # no
             ],
         }
     ]
-    user_contexts = []
+    user_context = []
     for info in history:
         if info["text"].startswith("/"):  # commands
             continue
-        if info["mtype"] == "text" and info["text"]:
-            content = {"user": info["full_name"], "time": f"{info['datetime']:%H:%M:%S}", "message": info["text"]}
-            user_contexts.append({"type": "text", "text": json.dumps(content, ensure_ascii=False)})
-        #     continue
-        # if info["mtype"] == "photo":
-        #     content = {"user": info["full_name"], "time": f"{info['datetime']:%H:%M:%S}", "message": "[image]如下"}
-        #     user_contexts.append({"type": "text", "text": json.dumps(content, ensure_ascii=False)})
-        #     res: BytesIO = await client.download_media(info["file_id"], in_memory=True)  # type: ignore
-        #     ext = Path(res.name).suffix.removeprefix(".").replace("jpg", "jpeg")
-        #     b64 = base64.b64encode(res.getvalue()).decode("utf-8")
-        #     user_contexts.append({"type": "image_url", "image_url": {"url": f"data:image/{ext};base64,{b64}"}})
-        #     if info["text"]:
-        #         content = {"user": info["full_name"], "time": f"{info['datetime']:%H:%M:%S}", "message": info["text"]}
-        #         user_contexts.append({"type": "text", "text": json.dumps(content, ensure_ascii=False)})
-        # else:
-        #     content = {"user": info["full_name"], "time": f"{info['datetime']:%H:%M:%S}", "message": f"[{info['mtype']}] {info['text']}".strip()}
-        #     user_contexts.append({"type": "text", "text": json.dumps(content, ensure_ascii=False)})
-    contexts.append({"role": "user", "content": user_contexts})
-    contexts = refine_prompts(contexts)
-    logger.trace(contexts)
-    return contexts
+
+        if info["text"].startswith("👤"):  # social media
+            continue
+
+        if info["text"]:  # currently, we only include texts
+            content = {
+                "message_id": info["mid"],
+                "time": f"{info['datetime']:%H:%M:%S}",
+                "username": info["full_name"],
+                "content": info["text"],
+            }
+            if (reply_to_message_id := info.get("reply_to_message_id")) and (reply_msg_content := get_message_by_id(reply_to_message_id, history)):
+                content["reply_to_message"] = reply_msg_content
+            user_context.append({"type": "text", "text": str(content)})
+
+    return {"system_context": system_context, "user_context": user_context}
+
+
+def get_message_by_id(message_id: int, history: list[dict]) -> dict:
+    """Get message by id."""
+    info = next((info for info in history if info["mid"] == message_id), {})
+    if not info:
+        return {}
+
+    return {
+        "message_id": info["mid"],
+        "time": f"{info['datetime']:%H:%M:%S}",
+        "username": info["full_name"],
+        "content": info["text"],
+    }
src/messages/chat_history.py
@@ -33,6 +33,8 @@ async def get_parsed_chat_history(
         if msg.empty:
             break
         info = parse_msg(msg, silent=True)
+        if msg.reply_to_message_id:
+            info["reply_to_message_id"] = msg.reply_to_message_id
         if not user:
             history.append(info)
             continue
src/config.py
@@ -19,7 +19,7 @@ CAPTION_LENGTH = int(os.getenv("CAPTION_LENGTH", "1024"))  # 4096 for Premium us
 MAX_FILE_BYTES = int(os.getenv("MAX_FILE_BYTES", "2000")) * 1024 * 1024  # 4000 MB for Premium user
 ASR_MAX_DURATION = int(os.getenv("ASR_MAX_DURATION", "600"))
 MAX_MESSAGE_COMBINATION = int(os.getenv("MAX_MESSAGE_COMBINATION", "5000"))  # Maximum number of messages to combine
-MAX_MESSAGE_SUMMARY = int(os.getenv("MAX_MESSAGE_SUMMARY", "1000"))  # Maximum number of messages to summay
+MAX_MESSAGE_SUMMARY = int(os.getenv("MAX_MESSAGE_SUMMARY", "5000"))  # Maximum number of messages to summarize
 READING_SPEED = int(os.getenv("READING_SPEED", "300"))  # words per minute
 DAILY_MESSAGES = os.getenv("DAILY_MESSAGES", "{}")  # Useful for daily checkin for some services. Should be a json string: '{"chat-1": "msg-1", "chat-2": "msg-2"}'
 # For ytdlp downloaded video, re-encoding to H264 format. This set the max file size for re-encoding. 0 means no limit
@@ -185,6 +185,20 @@ class GPT:  # see `llm/README.md`
     DEEPSEEK_MODEL_NAME = os.getenv("GPT_DEEPSEEK_MODEL_NAME", "DeepSeek-R1")
     DEEPSEEK_API_KEY = os.getenv("GPT_DEEPSEEK_API_KEY", "")
     DEEPSEEK_BASE_URL = os.getenv("GPT_DEEPSEEK_BASE_URL", "https://api.deepseek.com/v1")
+    # /summary command
+    SUMMARY_MODEL = os.getenv("GPT_SUMMARY_MODEL", "gpt-4o")
+    SUMMARY_MODEL_NAME = os.getenv("GPT_SUMMARY_MODEL_NAME", "GPT-4o")
+    SUMMARY_MODEL_MAX_INPUT_LENGTH = os.getenv("GPT_SUMMARY_MODEL_MAX_INPUT_LENGTH", "57344")  # 56K
+    SUMMARY_MODEL_MAX_OUTPUT_LENGTH = os.getenv("GPT_SUMMARY_MODEL_MAX_OUTPUT_LENGTH", "8192")  # 8K
+    SUMMARY_API_KEY = os.getenv("GPT_SUMMARY_API_KEY", "")
+    SUMMARY_BASE_URL = os.getenv("GPT_SUMMARY_BASE_URL", "https://api.openai.com/v1")
+    # long context model
+    LONG_MODEL = os.getenv("GPT_LONG_MODEL", "gemini-1.5-pro")
+    LONG_MODEL_NAME = os.getenv("GPT_LONG_MODEL_NAME", "Gemini-1.5-Pro")
+    LONG_MODEL_MAX_INPUT_LENGTH = os.getenv("GPT_LONG_MODEL_MAX_INPUT_LENGTH", "2097152")  # 2M
+    LONG_MODEL_MAX_OUTPUT_LENGTH = os.getenv("GPT_LONG_MODEL_MAX_OUTPUT_LENGTH", "8192")  # 8K
+    LONG_API_KEY = os.getenv("GPT_LONG_API_KEY", "")
+    LONG_BASE_URL = os.getenv("GPT_LONG_BASE_URL", "https://generativelanguage.googleapis.com/v1beta/openai")
 
 
 class TID:  # see more TID usecase in `src/permission.py`