Commit a13d2f0

benny-dou <60535774+benny-dou@users.noreply.github.com>
2025-03-25 11:43:22
feat(summary): add multiple user filter support
1 parent f360e11
Changed files (3)
src/llm/summary.py
@@ -1,5 +1,6 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
+import io
 import json
 import re
 from datetime import datetime, timedelta
@@ -23,15 +24,20 @@ from utils import nowdt
 
 HELP = f"""🤖**GPT总结历史消息** (最多{MAX_MESSAGE_SUMMARY}条)
 使用说明:
+# 后跟消息数量或时间范围
+@ 后跟用户名 (可多次使用@)
+
 
 **1️⃣指定条目数**
 - `{PREFIX.AI_SUMMARY} #N`: 总结最近的N条历史消息
 - `{PREFIX.AI_SUMMARY} #N @User`: 总结最近只属于User的N条消息
+- `{PREFIX.AI_SUMMARY} #N @User @User2`: 总结最近只属于User和User2的N条消息
 
 示例:
 - `{PREFIX.AI_SUMMARY} #10`: 总结最近的10条历史消息
 - `{PREFIX.AI_SUMMARY} #20 @123456`: 总结最近UID为123456的20条消息
 - `{PREFIX.AI_SUMMARY} #20 @John`: 总结最近用户John(大小写均可)的20条消息
+- `{PREFIX.AI_SUMMARY} #20 @John @Bob`: 总结最近用户John和Bob的20条消息
 
 **2️⃣指定最近时间段**
 - `{PREFIX.AI_SUMMARY} #interval`: 总结最近interval时段内的消息
@@ -54,6 +60,7 @@ HELP = f"""🤖**GPT总结历史消息** (最多{MAX_MESSAGE_SUMMARY}条)
 - 3️⃣的时间格式中没有任何分隔符, 必须为YYYYMMDDHHMMSS (14位纯数字)
 """
 DAILY_SUMMARY_PREFIX = "🏪**#爬楼助手**\n"
+CONTEXT_FILENAME = "聊天记录上下文.txt"
 
 
 async def ai_summary(client: Client, message: Message, summary_prefix: str | None = None, **kwargs):
@@ -72,7 +79,7 @@ async def ai_summary(client: Client, message: Message, summary_prefix: str | Non
     # get the number of messages to combine
     info = parse_msg(message)
     num_history = MAX_MESSAGE_SUMMARY
-    filter_user = ""
+    filter_users = []
     begin_time = datetime.fromtimestamp(0, tz=ZoneInfo(TZ))
     end_time = nowdt(tz=TZ)
     # reply to a message with /summary
@@ -87,12 +94,12 @@ async def ai_summary(client: Client, message: Message, summary_prefix: str | Non
     if matched := re.match(r"^" + PREFIX.AI_SUMMARY + r"\s+#(\d{14})-?(\d{14})?(\s+)?(@\w+)?", info["text"]):
         begin_time = datetime.strptime(matched.group(1), "%Y%m%d%H%M%S").replace(tzinfo=ZoneInfo(TZ))
         end_time = datetime.strptime(matched.group(2) or end_time.strftime("%Y%m%d%H%M%S"), "%Y%m%d%H%M%S").replace(tzinfo=ZoneInfo(TZ))
-        filter_user = matched.group(4) or ""
+        filter_users = re.findall(r"@([^\s]+)", info["text"])
     # 2️⃣ /summary #interval @user  (/summary #4h @user)
     elif matched := re.match(r"^" + PREFIX.AI_SUMMARY + r"\s+#(\d+)([mMhHdD])(\s+)?(@\w+)?", info["text"]):
         interval = int(matched.group(1))
         unit = matched.group(2).lower()
-        filter_user = matched.group(4) or ""
+        filter_users = re.findall(r"@([^\s]+)", info["text"])
         if unit == "m":
             begin_time = end_time - timedelta(minutes=interval)
         elif unit == "h":
@@ -102,7 +109,7 @@ async def ai_summary(client: Client, message: Message, summary_prefix: str | Non
     # 1️⃣ /summary #N @user
     elif matched := re.match(r"^" + PREFIX.AI_SUMMARY + r"\s+#(\d+)(\s+)?(@\w+)?", info["text"]):
         num_history = min(int(matched.group(1)), MAX_MESSAGE_SUMMARY)
-        filter_user = matched.group(3) or ""
+        filter_users = re.findall(r"@([^\s]+)", info["text"])
     else:
         return
     # set custom chat_id and message_id (useful for debug)
@@ -113,8 +120,7 @@ async def ai_summary(client: Client, message: Message, summary_prefix: str | Non
     if kwargs.get("show_progress") and "progress" not in kwargs:
         res = await send2tg(client, message, texts=f"📝正在获取历史消息...\n⏩开始时间: {begin_time:%m-%d %H:%M:%S}\n⏯️结束时间: {end_time:%m-%d %H:%M:%S}", **kwargs)
         kwargs["progress"] = res[0]
-    history = await get_parsed_chat_history(client, info["cid"], offset_id, num_history, begin_time, end_time, filter_user.removeprefix("@"))
-
+    history = await get_parsed_chat_history(client, info["cid"], offset_id, num_history, begin_time, end_time, filter_users)
     # parse the history contexts
     parsed = await get_contexts(history)
     if not history:
@@ -154,6 +160,8 @@ async def ai_summary(client: Client, message: Message, summary_prefix: str | Non
             summary_prefix = f"🤖**{response['model']}**:\n"
         await send2tg(client, message, texts=f"{summary_prefix}⏩开始时间: {begin_time:%m-%d %H:%M:%S}\n⏯️结束时间: {end_time:%m-%d %H:%M:%S}\n{texts}", **kwargs)
         await modify_progress(del_status=True, **kwargs)
+        with io.BytesIO(parsed["txt_format"].encode("utf-8")) as f:
+            await client.send_document(to_int(message.chat.id), f, file_name=CONTEXT_FILENAME)
 
 
 async def get_contexts(history: list[dict]) -> dict:
@@ -202,11 +210,12 @@ async def get_contexts(history: list[dict]) -> dict:
         }
     ]
     user_context = []
+    txt_format = ""  # simplified format, send as txt file
     for info in history:
         if info["text"].startswith("/"):  # commands
             continue
 
-        if info["text"].startswith(DAILY_SUMMARY_PREFIX):  # daily summary
+        if info["file_name"] == CONTEXT_FILENAME:
             continue
 
         if info["text"].startswith("👤"):  # social media
@@ -229,9 +238,10 @@ async def get_contexts(history: list[dict]) -> dict:
             if reply_msg_content := get_message_by_id(history, info.get("reply_to_message_id")):
                 content["reply_to_message"] = reply_msg_content
             user_context.append({"type": "text", "text": str(content)})
+            txt_format += f"[{content['time']}]{content['username']}:\n{content['message']}\n\n"
     if not user_context:
         return {}
-    return {"system_context": system_context, "user_context": user_context, "begin_time": begin_time, "end_time": end_time}
+    return {"system_context": system_context, "user_context": user_context, "txt_format": txt_format, "begin_time": begin_time, "end_time": end_time}
 
 
 def get_message_by_id(history: list[dict], message_id: int | None = None) -> dict:
src/messages/chat_history.py
@@ -17,7 +17,7 @@ async def get_parsed_chat_history(
     num: int = 0,
     begin_time: datetime | None = None,
     end_time: datetime | None = None,
-    user: str = "",
+    users: str | list[str] | None = None,
 ) -> list[dict]:
     """Get given number of chat history from old to new in parsed json format.
 
@@ -29,7 +29,11 @@ async def get_parsed_chat_history(
         end_time = datetime.now(tz=ZoneInfo(TZ))
     history = []
     retrieved = 0
-    user = user.replace(" ", "").lower()
+    if users is None:
+        users = []
+    if isinstance(users, str):
+        users = [users]
+    users = [x.replace(" ", "").lower() for x in users]
     async for msg in client.get_chat_history(chat_id=chat_id, offset_id=offset_id):  # type: ignore
         # iterate messages from new to old
         retrieved += 1
@@ -46,10 +50,10 @@ async def get_parsed_chat_history(
             continue
         if msg.reply_to_message_id:
             info["reply_to_message_id"] = msg.reply_to_message_id
-        if not user:
+        if not users:
             history.append(info)
             continue
-        if info["full_name"].replace(" ", "").lower() == user or str(info["uid"]) == user or info["handle"].lower() == user:
+        if any(info["full_name"].replace(" ", "").lower() == user or str(info["uid"]) == user or info["handle"].lower() == user for user in users):
             history.append(info)
     history.reverse()  # from old to new
     return history
src/others/combine_history.py
@@ -71,7 +71,7 @@ async def combine_history(client: Client, message: Message, **kwargs):
         info["mid"] = int(matched.group(1))
         offset_id = info["mid"] + 1  # include this message
 
-    history = await get_parsed_chat_history(client, info["cid"], offset_id, num_history, user=filter_user)
+    history = await get_parsed_chat_history(client, info["cid"], offset_id, num_history, users=filter_user)
     if not history:
         await send2tg(client, message, texts=f"最近{num_history}条消息中未找到符合条件的消息", **kwargs)
         return