Commit fe359a0

benny-dou <60535774+benny-dou@users.noreply.github.com>
2025-02-03 18:24:08
feat: change number of messages of target user in `combine` and `summary`
1 parent af05ad8
Changed files (4)
src/llm/summary.py
@@ -7,7 +7,7 @@ from loguru import logger
 from pyrogram.client import Client
 from pyrogram.types import Message
 
-from config import COMBINATION_MAX_HISTORY, ENABLE, GPT, PREFIX, cache
+from config import ENABLE, GPT, MAX_MESSAGE_SUMMARY, PREFIX, cache
 from llm.response import get_gpt_response
 from llm.utils import fix_doubao
 from messages.chat_history import get_parsed_chat_history
@@ -16,25 +16,22 @@ from messages.progress import modify_progress
 from messages.sender import send2tg
 from messages.utils import equal_prefix, to_int
 
-HELP = f"""🤖**GPT总结历史消息** (最多{COMBINATION_MAX_HISTORY}条)
-当前模型:
-- 文本模型: **{GPT.TEXT_MODEL_NAME}**
-- 图片模型: **{GPT.IMAGE_MODEL_NAME}**
-
+HELP = f"""🤖**GPT总结历史消息** (最多{MAX_MESSAGE_SUMMARY}条)
+当前模型: **{GPT.TEXT_MODEL_NAME}**
 使用说明:
 1. `{PREFIX.AI_SUMMARY} + #N`
 GPT总结最近的N条历史消息
 
 2. `{PREFIX.AI_SUMMARY} + #N + @User`
-GPT总结最近的N条历史消息中只属于User发送的消息
+GPT总结最近只属于User的N条消息
 
 如果以 `{PREFIX.AI_SUMMARY} + #N` (或附加User) 回复消息M
 则总结消息M之前的N条消息文本 (包含M)
 
 示例:
 1. `{PREFIX.AI_SUMMARY} #10`: 总结最近的10条历史消息
-2. `{PREFIX.AI_SUMMARY} #20 @123456`: 总结最近的20条历史消息中UID为123456的消息
-3. `{PREFIX.AI_SUMMARY} #20 @John`: 总结最近20条消息中用户John(大小写均可)发送的消息
+2. `{PREFIX.AI_SUMMARY} #20 @123456`: 总结最近UID为123456的20条消息
+3. `{PREFIX.AI_SUMMARY} #20 @John`: 总结最近用户John(大小写均可)的20条消息
 如果用户名中有空格, 请去除空格。例如: 想指定用户为John Doe请使用 `@JohnDoe`
 """
 
@@ -59,9 +56,11 @@ async def ai_summary(client: Client, message: Message, **kwargs):
     num_history = 0
     if matched := re.match(r"^" + PREFIX.AI_SUMMARY + r"\s+#(\d+)\s+@(\w+)", info["text"]):
         num_history = int(matched.group(1))
+        num_history = min(num_history, MAX_MESSAGE_SUMMARY)
         filter_user = str(matched.group(2))
     elif matched := re.match(r"^" + PREFIX.AI_SUMMARY + r"\s+#(\d+)", info["text"]):
         num_history = int(matched.group(1))
+        num_history = min(num_history, MAX_MESSAGE_SUMMARY)
         filter_user = ""
     else:
         return
@@ -78,11 +77,7 @@ async def ai_summary(client: Client, message: Message, **kwargs):
         info["mid"] = int(matched.group(1))
         offset_id = info["mid"] + 1  # include this message
 
-    history = await get_parsed_chat_history(client, info["cid"], offset_id, num_history)
-    # filter by user
-    if filter_user:
-        history = [x for x in history if x["full_name"].replace(" ", "").lower() == filter_user.lower() or str(x["uid"]) == filter_user]
-
+    history = await get_parsed_chat_history(client, info["cid"], offset_id, num_history, filter_user)
     if not history:
         await send2tg(client, message, texts=f"最近{num_history}条消息中未找到符合条件的消息", **kwargs)
         return
@@ -101,14 +96,15 @@ async def ai_summary(client: Client, message: Message, **kwargs):
     await modify_progress(del_status=True, **kwargs)
 
 
-def get_summay_model(history: list[dict]) -> dict:
+def get_summay_model(history: list[dict]) -> dict:  # noqa: ARG001
     """Get the model for the summary."""
     models = {"text": GPT.TEXT_MODEL, "image": GPT.IMAGE_MODEL}
     model_names = {"text": GPT.TEXT_MODEL_NAME, "image": GPT.IMAGE_MODEL_NAME}
     timeouts = {"text": GPT.TEXT_TIMEOUT, "image": GPT.IMAGE_TIMEOUT}
     apis = {"text": GPT.TEXT_API_KEY, "image": GPT.IMAGE_API_KEY}
     urls = {"text": GPT.TEXT_BASE_URL, "image": GPT.IMAGE_BASE_URL}
-    model_type = "image" if "photo" in {x["mtype"] for x in history} else "text"
+    # model_type = "image" if "photo" in {x["mtype"] for x in history} else "text"
+    model_type = "text"  # only text model for now
     model = models[model_type]
     config = {
         "model": model,
src/messages/chat_history.py
@@ -2,29 +2,41 @@
 # -*- coding: utf-8 -*-
 
 from pyrogram.client import Client
-from pyrogram.types import Message
 
-from config import cache
+from config import MAX_MESSAGE_COMBINATION, cache
 from messages.parser import parse_msg
 
 
 @cache.memoize(ttl=10)
-async def get_history_messages(client: Client, chat_id: int | str, offset_id: int, num: int = 0) -> list[Message]:
-    """Get given number of chat history from old to new."""
+async def get_parsed_chat_history(
+    client: Client,
+    chat_id: int | str,
+    offset_id: int,
+    num: int = 0,
+    user: str = "",
+) -> list[dict]:
+    """Get given number of chat history from old to new in parsed json format.
+
+    If user is specified, up to num messages sent by that user are returned.
+    """
     if num <= 0:
         return []
     history = []
-    async for msg in client.get_chat_history(chat_id=chat_id, offset_id=offset_id, limit=num):  # type: ignore
+    retrieved = 0
+    user = user.replace(" ", "").lower()
+    async for msg in client.get_chat_history(chat_id=chat_id, offset_id=offset_id):  # type: ignore
+        retrieved += 1
+        if retrieved > MAX_MESSAGE_COMBINATION:
+            break
+        if len(history) >= num:
+            break
         if msg.empty:
+            break
+        info = parse_msg(msg, silent=True)
+        if not user:
+            history.append(info)
             continue
-        history.append(msg)
-    return history[::-1]
-
-
-@cache.memoize(ttl=10)
-async def get_parsed_chat_history(client: Client, chat_id: int | str, offset_id: int, num: int = 0) -> list[dict]:
-    """Get given number of chat history in parserd json format."""
-    if num <= 0:
-        return []
-    history = await get_history_messages(client, chat_id, offset_id, num)
-    return [parse_msg(msg, silent=True) for msg in history]
+        if info["full_name"].replace(" ", "").lower() == user or str(info["uid"]) == user or info["handle"].lower() == user:
+            history.append(info)
+    history.reverse()  # from old to new
+    return history
src/others/combine_history.py
@@ -6,7 +6,7 @@ import re
 from pyrogram.client import Client
 from pyrogram.types import Message
 
-from config import COMBINATION_MAX_HISTORY, ENABLE, PREFIX, READING_SPEED
+from config import ENABLE, MAX_MESSAGE_COMBINATION, PREFIX, READING_SPEED
 from llm.utils import count_tokens
 from messages.chat_history import get_parsed_chat_history
 from messages.parser import parse_msg
@@ -15,21 +15,21 @@ from messages.utils import equal_prefix, get_reply_to, startswith_prefix
 from utils import to_int
 
 HELP = f"""
-💬**合并对话历史** (最多{COMBINATION_MAX_HISTORY}条)
+💬**合并对话历史** (最多{MAX_MESSAGE_COMBINATION}条)
 使用说明:
 1. `{PREFIX.COMBINATION} + #N`
 将最近的N条消息文本合并为txt文件
 
 2. `{PREFIX.COMBINATION} + #N + @User`
-将最近的N条消息中只属于User发送的消息合并为txt文件
+将最近只属于User的N条消息合并为txt文件
 
 如果以 `{PREFIX.COMBINATION} + #N` (或附加User) 回复消息M
 则合并消息M之前的N条消息文本 (包含M)
 
 示例:
 1. `{PREFIX.COMBINATION} #10`: 合并最近10条消息为txt文本
-2. `{PREFIX.COMBINATION} #20 @123456`: 合并最近20条消息中UID为123456发送的消息为txt文本
-3. `{PREFIX.COMBINATION} #20 @John`: 合并最近20条消息中用户John(大小写均可)发送的消息为txt文本
+2. `{PREFIX.COMBINATION} #20 @123456`: 合并最近UID为123456的20条消息为txt文本
+3. `{PREFIX.COMBINATION} #20 @John`: 合并最近用户John(大小写均可)的20条消息为txt文本
 如果用户名中有空格, 请去除空格。例如: 想指定用户为John Doe请使用 `@JohnDoe`
 """
 
@@ -50,10 +50,12 @@ async def combine_history(client: Client, message: Message, **kwargs):
     num_history = 0
     if matched := re.match(r"^" + PREFIX.COMBINATION + r"\s+#(\d+)\s+@(\w+)", info["text"]):
         num_history = int(matched.group(1))
+        num_history = min(num_history, MAX_MESSAGE_COMBINATION)
         filter_user = str(matched.group(2))
-        file_name = f"最近{num_history}条消息中{filter_user}的发言.txt"
+        file_name = f"最近{num_history}条{filter_user}的消息.txt"
     elif matched := re.match(r"^" + PREFIX.COMBINATION + r"\s+#(\d+)", info["text"]):
         num_history = int(matched.group(1))
+        num_history = min(num_history, MAX_MESSAGE_COMBINATION)
         filter_user = ""
         file_name = f"最近{num_history}条消息记录.txt"
     else:
@@ -66,7 +68,6 @@ async def combine_history(client: Client, message: Message, **kwargs):
         message = message.reply_to_message
         info = parse_msg(message, silent=True)
         offset_id = info["mid"] + 1  # include the reply message
-
     # set custom chat_id and message_id (useful for debug)
     if matched := re.search(r"cid=(-?\w+)", info["text"], re.IGNORECASE):
         info["cid"] = to_int(matched.group(1))
@@ -74,9 +75,7 @@ async def combine_history(client: Client, message: Message, **kwargs):
         info["mid"] = int(matched.group(1))
         offset_id = info["mid"] + 1  # include this message
 
-    history = await get_parsed_chat_history(client, info["cid"], offset_id, num_history)
-    if filter_user:
-        history = [x for x in history if x["full_name"].replace(" ", "").lower() == filter_user.lower() or str(x["uid"]) == filter_user]
+    history = await get_parsed_chat_history(client, info["cid"], offset_id, num_history, filter_user)
     if not history:
         await send2tg(client, message, texts=f"最近{num_history}条消息中未找到符合条件的消息", **kwargs)
         return
src/config.py
@@ -18,7 +18,8 @@ TEXT_LENGTH = int(os.getenv("TEXT_LENGTH", "4096"))  # Maximum length of text me
 CAPTION_LENGTH = int(os.getenv("CAPTION_LENGTH", "1024"))  # 4096 for Premium user
 MAX_FILE_BYTES = int(os.getenv("MAX_FILE_BYTES", "2000")) * 1024 * 1024  # 4000 MB for Premium user
 ASR_MAX_DURATION = int(os.getenv("ASR_MAX_DURATION", "600"))
-COMBINATION_MAX_HISTORY = int(os.getenv("COMBINATION_MAX_HISTORY", "500"))  # Maximum number of messages to combine
+MAX_MESSAGE_COMBINATION = int(os.getenv("MAX_MESSAGE_COMBINATION", "5000"))  # Maximum number of messages to combine
+MAX_MESSAGE_SUMMARY = int(os.getenv("MAX_MESSAGE_SUMMARY", "1000"))  # Maximum number of messages to summarize
 READING_SPEED = int(os.getenv("READING_SPEED", "300"))  # words per minute
 DAILY_MESSAGES = os.getenv("DAILY_MESSAGES", "{}")  # Useful for daily checkin for some services. Should be a json string: '{"chat-1": "msg-1", "chat-2": "msg-2"}'