Commit a268126

benny-dou <60535774+benny-dou@users.noreply.github.com>
2025-06-15 05:24:14
refactor(summary): refactor AI summary command
1 parent 439b4f3
Changed files (5)
src/llm/gemini.py
@@ -55,6 +55,7 @@ async def gemini_response(
     disable_thinking: bool = False,
     include_thoughts: bool = True,
     system_prompt: str | None = None,
+    silent: bool = False,
     **kwargs,
 ) -> dict:
     r"""Get Gemini response.
@@ -64,9 +65,12 @@ async def gemini_response(
         message (Message): The trigger message object.
         conversations (list[Message]): list of chat conversations.
         modality (str): response modality
+        enable_tools (bool, optional): Whether to enable tools. Defaults to True.
         append_grounding (bool, optional): Whether to append grounding to the response. Defaults to True.
         disable_thinking (bool, optional): Whether to disable thinking. Defaults to False.
         include_thoughts (bool, optional): Whether to include thoughts. Defaults to True.
+        system_prompt (str | None, optional): System prompt. Defaults to None.
+        silent (bool, optional): Whether to suppress progress messages. Defaults to False.
     """
     info = parse_msg(message)
     model = GEMINI.TEXT_MODEL if modality == "text" else GEMINI.IMG_MODEL
@@ -84,8 +88,8 @@ async def gemini_response(
     try:
         real_prompt = clean_cmd_prefix(info["text"]) or clean_cmd_prefix(info["reply_text"])
         msg = f"🤖**{model_name}**: 思考中...\n👤**[{info['full_name'] or info['ctitle']}](tg://user?id={info['uid']})**: “{real_prompt}”"[:TEXT_LENGTH]
-        status_msg = (await send2tg(client, message, texts=msg, **kwargs))[0]
-        kwargs["progress"] = status_msg
+        if not silent and kwargs.get("show_progress"):
+            kwargs["progress"] = (await send2tg(client, message, texts=msg, **kwargs))[0]
         genconfig |= {"response_modalities": response_modalities}
         if enable_tools and modality == "text":
             genconfig |= {"tools": tools}
@@ -101,7 +105,7 @@ async def gemini_response(
         logger.trace(params)
         if modality == "image":
             return await gemini_nonstream(client, message, model_name, params, clean_marks=True, append_grounding=append_grounding, **kwargs)
-        return await gemini_stream(client, message, model_name, params, append_grounding=append_grounding, **kwargs)
+        return await gemini_stream(client, message, model_name, params, append_grounding=append_grounding, silent=silent, **kwargs)
     except Exception as e:
         logger.error(e)
     return {}
src/llm/summary.py
@@ -7,20 +7,19 @@ from datetime import datetime, timedelta
 from zoneinfo import ZoneInfo
 
 from loguru import logger
-from openai import DefaultAsyncHttpxClient
 from pyrogram.client import Client
 from pyrogram.types import Chat, Message
+from pyrogram.types.messages_and_media.message import Str
 
-from config import GPT, MAX_MESSAGE_SUMMARY, PREFIX, PROXY, TID, TZ
-from llm.prompts import refine_prompts
-from llm.response import send_to_gpt
-from llm.utils import count_tokens, sample_key
-from messages.chat_history import get_parsed_chat_history
+from config import GPT, MAX_MESSAGE_SUMMARY, PREFIX, TID, TZ
+from llm.gpt import gpt_response
+from llm.utils import BOT_TIPS, count_tokens
+from messages.chat_history import get_history_info_list
 from messages.parser import parse_msg
 from messages.progress import modify_progress
 from messages.sender import send2tg
 from messages.utils import equal_prefix, startswith_prefix, to_int
-from utils import nowdt
+from utils import nowdt, rand_number
 
 HELP = f"""🤖**AI总结历史消息** (最多{MAX_MESSAGE_SUMMARY}条)
 ⚠️使用`{PREFIX.COMBINATION}`命令只生成聊天记录文件, 不进行AI总结
@@ -59,6 +58,36 @@ HELP = f"""🤖**AI总结历史消息** (最多{MAX_MESSAGE_SUMMARY}条)
 - 用上述各种`{PREFIX.AI_SUMMARY}`命令回复消息M, 视为将截止时间设为消息M的发送时间
 - 如果用户名中有空格, 请去除空格。例如: 想指定用户为John Doe请使用 `@JohnDoe`
 """
+
+SYSTEM_PROMPT = """总结以下网络聊天记录, 识别关键主题、争议话题以及重要观点。提供一个简明的总结, 保留原始意图和上下文。如有必要, 引用原始用户名和时间戳。
+每一条消息的格式如下:
+{
+    "username": "消息发送者",
+    "time": "消息发送时间",
+    "url": "消息链接",
+    "message": "本条消息内容",
+    "reply_to_message": "被此条消息回复的消息"
+}
+
+# 步骤
+1. 阅读聊天记录: 仔细查看对话内容, 了解讨论的流程和上下文。
+2. 识别关键主题: 提取整个聊天中讨论的主要话题。
+3. 忽略废话及无关内容, 专注于关键信息。
+4. 突出争议话题: 记录任何分歧或意见不同的地方。
+5. 识别重要观点: 捕捉参与者提出的重要观点或论点。
+6. 保留意图和上下文: 确保总结反映对话的原始意义和上下文。
+7. 引用用户名和时间戳: 在适当情况下, 引用用户名和时间戳以为某些陈述提供上下文。
+8. 撰写总结: 以简洁的语言编写总结, 同时包含必要的引用。
+
+# 输出格式
+- 使用中文撰写总结。
+- 简明扼要地总结聊天记录的内容。
+- 在必要时引用用户名和时间。
+- 保持清晰和简洁的表达。
+- 引用用户名时, 请使用 **username** 格式。如: **username**
+- 引用时间时, 请使用 [HH:MM:SS](url) 格式。如: [12:30:00](https://t.me/c/1234/56789)
+"""
+
 DAILY_SUMMARY_PREFIX = "🏪**#爬楼助手**\n"
 CONTEXT_FILENAME = "聊天记录.txt"
 
@@ -122,126 +151,96 @@ async def ai_summary(client: Client, message: Message, summary_prefix: str | Non
     if kwargs.get("show_progress") and "progress" not in kwargs:
         res = await send2tg(client, message, texts=f"📝正在获取历史消息...\n⏩开始时间: {begin_time:%m-%d %H:%M:%S}\n⏯️结束时间: {end_time:%m-%d %H:%M:%S}", **kwargs)
         kwargs["progress"] = res[0]
-    history = await get_parsed_chat_history(client, info["cid"], offset_id, num_history, begin_time, end_time, filter_users)
+    history_list = await get_history_info_list(client, info["cid"], offset_id, num_history, begin_time, end_time, filter_users)
     # parse the history contexts
-    parsed = await get_contexts(history)
-    if not parsed["txt_format"]:
+    parsed = await parse_history_list(history_list)
+    if parsed["num_message"] == 0:
         await send2tg(client, message, texts=f"{num_history}条历史消息中未找到符合条件的消息", **kwargs)
         await modify_progress(del_status=True, **kwargs)
         return
-    contexts = refine_prompts(parsed["system_context"] + [{"role": "user", "content": parsed["user_context"]}])
-    sysmtem_tokens = count_tokens(contexts[0]["content"])
-    user_tokens = count_tokens(contexts[-1]["content"])
-    total_tokens = sysmtem_tokens + user_tokens
-    summary_model = GPT.SUMMARY_MODEL
-    summary_model_name = GPT.SUMMARY_MODEL_NAME
-    max_tokens = int(GPT.SUMMARY_MODEL_MAX_OUTPUT_LENGTH)
+    num_tokens = count_tokens(SYSTEM_PROMPT + parsed["history"])
     msg = f"⏩开始时间: {parsed['begin_time']:%m-%d %H:%M:%S}\n"
     msg += f"⏯️结束时间: {parsed['end_time']:%m-%d %H:%M:%S}\n"
-    msg += f"🔢消息条数: {len(parsed['user_context'])}\n"
-    msg += f"🔠Token数: {total_tokens}"
+    msg += f"🔢消息条数: {parsed['num_message']}\n"
+    msg += f"🔠Token数: {num_tokens}"
     # send contexts as txt file
     with io.BytesIO(parsed["txt_format"].encode("utf-8")) as f:
         await client.send_document(to_int(message.chat.id), f, file_name=CONTEXT_FILENAME, caption=msg)
     if not need_summay:
         await modify_progress(del_status=True, **kwargs)
         return
-    await modify_progress(text=f"🤖**{summary_model_name}**总结中...\n{msg}", force_update=True, **kwargs)
-    config = {
-        "friendly_name": summary_model_name,
-        "client": {"api_key": sample_key(GPT.SUMMARY_API_KEY), "base_url": GPT.SUMMARY_BASE_URL, "http_client": DefaultAsyncHttpxClient(proxy=PROXY.GPT)},
-        "completions": {"model": summary_model},
-    }
-    config["completions"]["messages"] = contexts
-    # set max_tokens for the model
-    if "o1" in summary_model or "o3" in summary_model:  # o1 or newer models use `max_completion_tokens`
-        config["completions"]["max_completion_tokens"] = max_tokens
-    else:
-        config["completions"]["max_tokens"] = max_tokens
-    config["client"]["timeout"] = int(GPT.SUMMARY_TIMEOUT)
-    response = await send_to_gpt(config, **kwargs)
-    if texts := response.get("content"):
-        texts = texts.strip("`")
+    await modify_progress(text=f"🤖AI总结中...\n{msg}", force_update=True, **kwargs)
+    # Construct a message to call GPT
+    ai_msg = Message(
+        id=rand_number(),
+        chat=message.chat,
+        text=Str(GPT.SUMMARY_CMD),
+        reply_to_message=Message(id=rand_number(), chat=message.chat, text=Str(parsed["history"])),
+    )
+    response = await gpt_response(
+        client,
+        ai_msg,
+        system_prompt=SYSTEM_PROMPT,
+        enable_gpt_tools=False,
+        include_thoughts=False,
+        append_grounding=False,
+        silent=True,
+    )
+    if texts := response.get("texts"):
         if summary_prefix is None:
-            summary_prefix = f"🤖**{summary_model_name}**:\n"
+            model_name = response.get("model_name", "AI总结")
+            summary_prefix = f"🤖**{model_name}**:\n"
         kwargs["reply_msg_id"] = -1  # DO NOT send as a reply message
         await send2tg(client, message, texts=f"{summary_prefix}⏩开始时间: {begin_time:%m-%d %H:%M:%S}\n⏯️结束时间: {end_time:%m-%d %H:%M:%S}\n{texts}", **kwargs)
         await modify_progress(del_status=True, **kwargs)
 
 
-async def get_contexts(history: list[dict]) -> dict:
-    """Get GPT contexts based on parsed chat history.
+async def parse_history_list(info_list: list[dict]) -> dict:
+    """Parse chat history info list.
 
     Currently, we only summarize text contents.
     """
-    system_context = [
-        {
-            "role": "system",  # system prompt
-            "content": [
-                {
-                    "type": "text",
-                    "text": """总结以下网络聊天记录, 识别关键主题、争议话题以及重要观点。提供一个简明的总结, 保留原始意图和上下文。如有必要, 引用原始用户名和时间戳, 并使用清晰的语言。
-每一条消息的格式如下:
-{
-    "time": 消息发送时间,
-    "url": 消息链接,
-    "username": 消息发送者,
-    "message": 本条消息内容,
-    "reply_to_message": 回复消息的原始内容, 如果本消息并不回复其他消息, 则不存在该字段
-}
-
-# 步骤
-1. 阅读聊天记录: 仔细查看对话内容, 了解讨论的流程和上下文。
-2. 识别关键主题: 提取整个聊天中讨论的主要话题。
-3. 忽略废话及无关内容, 专注于关键信息。
-4. 突出争议话题: 记录任何分歧或意见不同的地方。
-5. 识别重要观点: 捕捉参与者提出的重要观点或论点。
-6. 保留意图和上下文: 确保总结反映对话的原始意义和上下文。
-7. 引用用户名和时间戳: 在适当情况下, 引用用户名和时间戳以为某些陈述提供上下文。
-8. 撰写总结: 以简洁的语言编写总结, 同时包含必要的引用。
-
-# 输出格式
-- 使用中文撰写总结。
-- 简明扼要地总结聊天记录的内容。
-- 在必要时引用用户名和时间。
-- 保持清晰和简洁的表达。
-- 引用用户名时, 请使用 **username** 格式。如: **username**
-- 引用时间时, 请使用 [HH:MM:SS](url) 格式。如: [12:30:00](https://t.me/c/1234/56789)
-""",
-                }
-            ],
-        }
-    ]
-    user_context = []
-    for info in history:
+    messages: list[dict] = []  # hold user messages
+    for info in info_list:
         if info["file_name"] == CONTEXT_FILENAME:
             continue
         if info["is_bot"]:  # bots
             continue
         if info["text"]:  # currently, we only include texts
-            if len(user_context) == 0:
+            if len(messages) == 0:
                 begin_time = info["datetime"]
             end_time = info["datetime"]
             media_type = f"[{info['mtype']}] " if info["mtype"] != "text" else ""
             content = {
+                "username": info["full_name"],
                 "time": f"{info['datetime']:%H:%M:%S}",
                 "url": info["message_url"],
-                "username": info["full_name"],
                 "message": media_type + info["text"],
             }
-            if reply_msg_content := get_message_by_id(history, info.get("reply_to_message_id")):
+            if reply_msg_content := get_message_by_id(info_list, info.get("reply_to_message_id")):
                 content["reply_to_message"] = reply_msg_content
-            user_context.append({"type": "text", "text": str(content)})
-    if not user_context:
+            messages.append(content)
+    if not messages:
         return {}
-    return {"system_context": system_context, "user_context": user_context, "txt_format": get_txt_format(history), "begin_time": begin_time, "end_time": end_time}
 
+    history = json.dumps(messages, ensure_ascii=False)
+    """IMPORTANT: We need to remove `BOT_TIPS` in the history!
 
-def get_txt_format(history: list[dict]) -> str:
+    Because we call the `gpt_response` function, which uses `BOT_TIPS`
+    to check whether a message was produced by the GPT model.
+
+    If the history contains `BOT_TIPS`, that message's role becomes `model` (not `user`).
+    But a `model`-only message is not allowed, so we need to remove `BOT_TIPS`.
+    """
+    history = history.replace(BOT_TIPS, "")
+    return {"history": history, "num_message": len(messages), "txt_format": get_txt_format(info_list), "begin_time": begin_time, "end_time": end_time}
+
+
+def get_txt_format(info_list: list[dict]) -> str:
     """Format the history as plaintext."""
     txt_format = ""
     txt_mediagroup_ids = set()  # record processed mediagroup messages
-    for info in history:
+    for info in info_list:
         if info["file_name"] == CONTEXT_FILENAME:
             continue
         if info["media_group_id"] in txt_mediagroup_ids:
@@ -251,40 +250,39 @@ def get_txt_format(history: list[dict]) -> str:
         if info["mtype"] != "text":  # not plaintext message
             # media group
             if info["media_group_id"] > 0:
-                media_types = [f"[{x['mtype']}]" for x in history if x["media_group_id"] == info["media_group_id"]]
+                media_types = [f"[{x['mtype']}]" for x in info_list if x["media_group_id"] == info["media_group_id"]]
                 txt_format += " ".join(media_types)
                 txt_mediagroup_ids.add(info["media_group_id"])
             else:
                 txt_format += f"[{info['mtype']}]"
         txt_format += info["text"]
         # append quote msg
-        reply_msg_content = get_message_by_id(history, info.get("reply_to_message_id"))
-        if reply_msg_content:
+        if reply_msg_content := get_message_by_id(info_list, info.get("reply_to_message_id")):
             txt_format += f"\n<quote>{reply_msg_content['username']}: {reply_msg_content['message']}</quote>"
         txt_format += "\n\n"
     return txt_format
 
 
-def get_message_by_id(history: list[dict], message_id: int | None = None) -> dict:
+def get_message_by_id(info_list: list[dict], message_id: int | None = None) -> dict:
     """Get message by id."""
     if not message_id:
         return {}
-    info = next((info for info in history if info["mid"] == message_id), {})
+    info = next((info for info in info_list if info["mid"] == message_id), {})
     if not info:
         return {}
     media_type = f"[{info['mtype']}] " if info["mtype"] != "text" else ""
     return {
+        "username": info["full_name"],
         "time": f"{info['datetime']:%H:%M:%S}",
         "url": info["message_url"],
-        "username": info["full_name"],
         "message": media_type + info["text"],
     }
 
 
-def get_media_group_by_id(history: list[dict], media_group_id: int | None = None) -> list[dict]:
+def get_media_group_by_id(info_list: list[dict], media_group_id: int | None = None) -> list[dict]:
     if not media_group_id:
         return []
-    return [x for x in history if x["media_group_id"] == media_group_id]
+    return [x for x in info_list if x["media_group_id"] == media_group_id]
 
 
 async def daily_summary(client: Client):
@@ -310,7 +308,7 @@ async def daily_summary(client: Client):
         logger.info(f"Summary chat {source_chat_id}, send results to {target_chat_id}")
         # fake message
         message = Message(
-            id=0,
+            id=rand_number(),
             chat=Chat(id=target_chat_id),
             text=f"/summary #{duration}h cid={to_int(source_chat_id)}",  # type: ignore
         )
src/llm/utils.py
@@ -15,7 +15,7 @@ from markitdown import MarkItDown
 from pyrogram.parser.markdown import BLOCKQUOTE_EXPANDABLE_DELIM, BLOCKQUOTE_EXPANDABLE_END_DELIM
 
 from config import DOWNLOAD_DIR, GEMINI, GPT, PREFIX, cache
-from utils import nowdt, number_to_emoji, read_text, remove_consecutive_newlines, remove_dash, remove_pound, zhcn
+from utils import nowdt, number_to_emoji, read_text, remove_consecutive_newlines, remove_dash, remove_pound, strings_list, zhcn
 
 BOT_TIPS = "(回复以继续)"  # noqa: RUF001
 REASONING_BEGIN = "🤔"  # use emoji to separate model reasoning and content
@@ -178,7 +178,7 @@ def image_emoji(capability: bool) -> str:  # noqa: FBT001
 
 
 def clean_cmd_prefix(text: str) -> str:
-    for prefix in [x.strip() for x in PREFIX.GPT.split(",") if x.strip()] + [PREFIX.GENIMG]:
+    for prefix in [*strings_list(PREFIX.GPT), PREFIX.GENIMG]:
         text = text.removeprefix(prefix).lstrip()
     return text
 
@@ -186,13 +186,13 @@ def clean_cmd_prefix(text: str) -> str:
 def clean_bot_tips(text: str) -> str:
     if not text:
         return ""
-    return re.sub(rf"(.*?){BOT_TIPS}", "", text, flags=re.DOTALL).strip()
+    return re.sub(rf"^🤖(.*?){BOT_TIPS}", "", text, flags=re.DOTALL).strip()
 
 
 def clean_reasoning(text: str) -> str:
     if not text:
         return ""
-    text = re.sub(rf"{REASONING_BEGIN}(.*?){REASONING_END}", "", text, flags=re.DOTALL).strip()
+    text = re.sub(rf"^{REASONING_BEGIN}(.*?){REASONING_END}", "", text.strip(), flags=re.DOTALL).strip()
     text = text.removeprefix(BLOCKQUOTE_EXPANDABLE_DELIM).lstrip()
     return text.removeprefix(BLOCKQUOTE_EXPANDABLE_END_DELIM).lstrip()
 
src/messages/chat_history.py
@@ -10,7 +10,7 @@ from config import MAX_MESSAGE_RETRIEVED, TZ
 from messages.parser import parse_msg
 
 
-async def get_parsed_chat_history(
+async def get_history_info_list(
     client: Client,
     chat_id: int | str,
     offset_id: int = 0,
@@ -43,7 +43,7 @@ async def get_parsed_chat_history(
             break
         if msg.empty:
             break
-        info = parse_msg(msg, silent=True)
+        info = parse_msg(msg, silent=True, use_cache=False)
         if info["datetime"] < begin_time:
             break
         if info["datetime"] > end_time:
src/config.py
@@ -336,14 +336,8 @@ class GPT:
     DOUBAO_BASE_URL = os.getenv("GPT_DOUBAO_BASE_URL", "https://ark.cn-beijing.volces.com/api/v3")
     DOUBAO_ACCEPT_IMAGE = os.getenv("GPT_DOUBAO_ACCEPT_IMAGE", "1").lower() in ["1", "y", "yes", "t", "true", "on"]
 
-    # /summary command
-    SUMMARY_MODEL = os.getenv("GPT_SUMMARY_MODEL", "")
-    SUMMARY_MODEL_NAME = os.getenv("GPT_SUMMARY_MODEL_NAME", "")
-    SUMMARY_MODEL_MAX_OUTPUT_LENGTH = os.getenv("GPT_SUMMARY_MODEL_MAX_OUTPUT_LENGTH", "8192")  # 8K
-    SUMMARY_API_KEY = os.getenv("GPT_SUMMARY_API_KEY", "")
-    SUMMARY_BASE_URL = os.getenv("GPT_SUMMARY_BASE_URL", "https://api.openai.com/v1")
-    SUMMARY_TIMEOUT = os.getenv("GPT_SUMMARY_TIMEOUT", "600")  # should be larger than default timeout
-    SUMMARY_ACCEPT_IMAGE = os.getenv("GPT_SUMMARY_ACCEPT_IMAGE", "1").lower() in ["1", "y", "yes", "t", "true", "on"]
+    # AI summary (/summary)
+    SUMMARY_CMD = os.getenv("GPT_SUMMARY_CMD", "/gemini")  # add this command prefix to call AI summary
 
     # For tool_call. Some models doesn't support tool call, so we use this model to do the tool_call first.
     # Then construct the new questions for the original model.