Commit 58d8c91

benny-dou <60535774+benny-dou@users.noreply.github.com>
2025-01-28 14:57:36
feat: add ai summary feature
1 parent 707bc3e
src/llm/gpt.py
@@ -41,20 +41,12 @@ def is_gpt_conversation(message: Message) -> bool:
 
 
 @cache.memoize(ttl=60)
-async def gpt_response(
-    client: Client,
-    message: Message,
-    # contexts: list[dict] | None = None,
-    **kwargs,
-):
+async def gpt_response(client: Client, message: Message, **kwargs):
     """Get GPT response from Various API.
 
-    message and contexts must be provided at least one.
-
     Args:
         client (Client): The Pyrogram client.
         message (Message): The trigger message object.
-        contexts (list[dict]): The conversation contexts in OpenAI format.
     """
     if not ENABLE.GPT:
         return
src/llm/response.py
@@ -9,7 +9,7 @@ from messages.progress import modify_progress
 
 async def get_gpt_response(config: dict, contexts: list[dict], **kwargs) -> str:
     """Get GPT response for text model."""
-    response = f"🤖{config['friendly_name']}对话失败, 请稍后重试."
+    response = f"🤖{config['friendly_name']}请求失败, 请稍后重试."
     logger.trace(contexts)
     try:
         openai = AsyncOpenAI(
@@ -26,7 +26,7 @@ async def get_gpt_response(config: dict, contexts: list[dict], **kwargs) -> str:
         if choices := resp.model_dump().get("choices", []):
             response = choices[0].get("message", {}).get("content")
     except Exception as e:
-        error = f"🤖{config['friendly_name']}对话失败, 请稍后重试.\n{e}"
+        error = f"🤖{config['friendly_name']}请求失败, 请稍后重试.\n{e}"
         logger.error(f"GPT request failed: {e}")
         await modify_progress(text=error, force_update=True, **kwargs)
         return error
src/llm/summary.py
@@ -0,0 +1,182 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+import base64
+import json
+import re
+from pathlib import Path
+from typing import TYPE_CHECKING
+
+from loguru import logger
+from pyrogram.client import Client
+from pyrogram.types import Message
+
+from config import COMBINATION_MAX_HISTORY, ENABLE, GPT, PREFIX, cache
+from llm.response import get_gpt_response
+from llm.utils import fix_doubao
+from messages.chat_history import get_parsed_chat_history
+from messages.parser import parse_msg
+from messages.progress import modify_progress
+from messages.sender import send2tg
+from messages.utils import equal_prefix
+
+if TYPE_CHECKING:
+    from io import BytesIO
+
+HELP = f"""🤖**GPT总结历史消息** (最多{COMBINATION_MAX_HISTORY}条)
+当前模型:
+- 文本模型: **{GPT.TEXT_MODEL_NAME}**
+- 图片模型: **{GPT.IMAGE_MODEL_NAME}**
+
+使用说明:
+1. `{PREFIX.AI_SUMMARY} + #N`
+GPT总结最近的N条历史消息
+
+2. `{PREFIX.AI_SUMMARY} + #N + @User`
+GPT总结最近的N条历史消息中只属于User发送的消息
+
+如果以 `{PREFIX.AI_SUMMARY} + #N` (或附加User) 回复消息M
+则总结消息M之前的N条消息文本 (包含M)
+
+示例:
+1. `{PREFIX.AI_SUMMARY} #10`: 总结最近的10条历史消息
+2. `{PREFIX.AI_SUMMARY} #20 @123456`: 总结最近的20条历史消息中UID为123456的消息
+3. `{PREFIX.AI_SUMMARY} #20 @John`: 总结最近20条消息中用户John(大小写均可)发送的消息
+如果用户名中有空格, 请去除空格。例如: 想指定用户为John Doe请使用 `@JohnDoe`
+"""
+
+
+@cache.memoize(ttl=60)
+async def ai_summary(client: Client, message: Message, **kwargs):
+    """GPT summary of the message history.
+
+    Triggered by `PREFIX.AI_SUMMARY #N [@User]`; fetches and parses the last
+    N chat messages, sends them to the configured GPT model, and posts the
+    summary back to the chat. Memoized for 60s, presumably to dedupe
+    repeated identical triggers — TODO confirm the cache key behaves as
+    intended with Message arguments.
+
+    Args:
+        client (Client): The Pyrogram client.
+        message (Message): The trigger message object.
+        **kwargs: Forwarded to send2tg/get_gpt_response/modify_progress
+            (e.g. show_progress).
+    """
+    if not ENABLE.AI_SUMMARY:
+        return
+    # send docs if message == "/summary"
+    if equal_prefix(message.text, prefix=[PREFIX.AI_SUMMARY]):
+        await send2tg(client, message, texts=HELP, **kwargs)
+        return
+
+    # get the number of messages to combine
+    info = parse_msg(message)
+    num_history = 0
+    # "#N @User" form: summarize N messages filtered to one sender (name or UID)
+    if matched := re.match(r"^" + PREFIX.AI_SUMMARY + r"\s+#(\d+)\s+@(\w+)", info["text"]):
+        num_history = int(matched.group(1))
+        filter_user = str(matched.group(2))
+    # "#N" form: summarize the last N messages from anyone
+    elif matched := re.match(r"^" + PREFIX.AI_SUMMARY + r"\s+#(\d+)", info["text"]):
+        num_history = int(matched.group(1))
+        filter_user = ""
+    else:
+        return  # not a recognized summary command
+
+    # reply a message with /summary
+    offset_id = message.id
+    if message.reply_to_message:
+        # summarize the history ending at the replied-to message instead
+        message = message.reply_to_message
+        offset_id = message.id + 1  # include the reply message
+
+    history = await get_parsed_chat_history(client, message, offset_id, num_history)
+    # filter by user: match full name with spaces stripped (case-insensitive) or exact UID
+    if filter_user:
+        history = [info for info in history if info["full_name"].replace(" ", "").lower() == filter_user.lower() or str(info["uid"]) == filter_user]
+
+    if not history:
+        await send2tg(client, message, texts=f"最近{num_history}条消息中未找到符合条件的消息", **kwargs)
+        return
+
+    model_conf = get_summay_model(history)
+    contexts = await get_contexts(client, history)
+    # Doubao ("豆包") models need a context-format workaround — see fix_doubao
+    if model_conf["friendly_name"].startswith("豆包"):
+        contexts = fix_doubao(contexts)
+    msg = f"🤖{model_conf['friendly_name']}: 总结中..."
+    if kwargs.get("show_progress"):
+        res = await send2tg(client, message, texts=msg, **kwargs)
+        kwargs["progress"] = res[0]
+    response = await get_gpt_response(model_conf, contexts, **kwargs)
+    logger.debug(response)
+    await send2tg(client, message, texts=response, **kwargs)
+    # NOTE(review): runs even when show_progress is falsy and no progress was
+    # created — presumably a no-op then; confirm modify_progress tolerates it.
+    await modify_progress(del_status=True, **kwargs)
+
+
+def get_summay_model(history: list[dict]) -> dict:
+    """Get the model for the summary.
+
+    Chooses the image model when any history entry is a photo, otherwise
+    the plain text model, and assembles the matching connection config.
+
+    Args:
+        history (list[dict]): Parsed chat history entries (reads "mtype").
+
+    Returns:
+        dict: Config consumed by get_gpt_response: model, friendly_name,
+            timeout (whole seconds), base_url, key, temperature.
+    """
+    # TODO(review): function name is misspelled ("summay"); rename to
+    # get_summary_model together with its caller in ai_summary.
+    models = {"text": GPT.TEXT_MODEL, "image": GPT.IMAGE_MODEL}
+    model_names = {"text": GPT.TEXT_MODEL_NAME, "image": GPT.IMAGE_MODEL_NAME}
+    timeouts = {"text": GPT.TEXT_TIMEOUT, "image": GPT.IMAGE_TIMEOUT}
+    apis = {"text": GPT.TEXT_API_KEY, "image": GPT.IMAGE_API_KEY}
+    urls = {"text": GPT.TEXT_BASE_URL, "image": GPT.IMAGE_BASE_URL}
+    # a single photo anywhere in the batch forces the image-capable model
+    model_type = "image" if "photo" in {x["mtype"] for x in history} else "text"
+    model = models[model_type]
+    config = {
+        "model": model,
+        "friendly_name": model_names[model_type],
+        "timeout": round(float(timeouts[model_type])),  # config values may be strings
+        "base_url": urls[model_type],
+        "key": apis[model_type],
+        "temperature": float(GPT.TEMPERATURE),
+    }
+    logger.trace(config)
+    return config
+
+
+async def get_contexts(client: Client, history: list[dict]) -> list[dict]:
+    """Get GPT contexts based on parsed chat history.
+
+    Builds an OpenAI-style message list: one system prompt (summarization
+    instructions, in Chinese), then a single user message whose content
+    parts carry each history entry as a JSON text part, with photos
+    attached as base64 data-URL image parts.
+
+    Args:
+        client (Client): The Pyrogram client (used to download photos).
+        history (list[dict]): Parsed chat history entries.
+
+    Returns:
+        list[dict]: Contexts ready to pass to get_gpt_response.
+    """
+    contexts = [
+        {
+            "role": "system",  # system prompt
+            "content": [
+                {
+                    "type": "text",
+                    "text": """总结在线休闲讨论组的聊天记录, 识别关键主题、争议话题以及重要观点。提供一个简明的总结, 保留原始意图和上下文。如有必要, 引用原始用户名和时间戳, 并使用清晰的语言。
+# 步骤
+1. 阅读聊天记录: 仔细查看对话内容, 了解讨论的流程和上下文。
+2. 识别关键主题: 提取整个聊天中讨论的主要话题。
+3. 突出争议话题: 记录任何分歧或意见不同的地方。
+4. 识别重要观点: 捕捉参与者提出的重要观点或论点。
+5. 保留意图和上下文: 确保总结反映对话的原始意义和上下文。
+6. 引用用户名和时间戳: 在适当情况下, 引用用户名和时间戳以为某些陈述提供上下文。
+7. 撰写总结: 以简洁的语言编写总结, 同时包含必要的引用。
+
+# 输出格式
+- 使用中文撰写总结。
+- 简明扼要地总结聊天记录的内容。
+- 在必要时引用用户名和时间戳。
+- 保持清晰和简洁的表达。
+
+# 示例
+- 输入: [包含用户名和时间戳的聊天记录片段]
+- 输出:
+  [10:23:30] Alice 提出关于气候变化的话题, 重点讨论其影响。
+  [11:00:30] Bob 表示反对, 引用了相反的证据。
+  [11:30:00] Charlie 提出了一个新的项目想法, 引起了大家的兴趣。
+  [12:00:00] 大家讨论了项目的潜在挑战和机会。最终, 决定下次会议继续讨论这个项目。
+""",
+                }
+            ],
+        }
+    ]
+    user_contexts = []
+    for info in history:
+        # plain text message -> one JSON text part (user, time, message)
+        if info["mtype"] == "text" and info["text"]:
+            content = {"user": info["full_name"], "time": f"{info['datetime']:%H:%M:%S}", "message": info["text"]}
+            user_contexts.append({"type": "text", "text": json.dumps(content, ensure_ascii=False)})
+            continue
+        if info["mtype"] == "photo":
+            # photo: a JSON marker part first, then the image itself
+            content = {"user": info["full_name"], "time": f"{info['datetime']:%H:%M:%S}", "message": "[image]如下"}
+            user_contexts.append({"type": "text", "text": json.dumps(content, ensure_ascii=False)})
+            res: BytesIO = await client.download_media(info["file_id"], in_memory=True)  # type: ignore
+            # normalize extension for the MIME type ("jpg" -> "jpeg")
+            ext = Path(res.name).suffix.removeprefix(".").replace("jpg", "jpeg")
+            b64 = base64.b64encode(res.getvalue()).decode("utf-8")
+            user_contexts.append({"type": "image_url", "image_url": {"url": f"data:image/{ext};base64,{b64}"}})
+            # photo caption, if any, follows as its own text part
+            if info["text"]:
+                content = {"user": info["full_name"], "time": f"{info['datetime']:%H:%M:%S}", "message": info["text"]}
+                user_contexts.append({"type": "text", "text": json.dumps(content, ensure_ascii=False)})
+        else:
+            # other media types: describe as "[mtype] caption"
+            content = {"user": info["full_name"], "time": f"{info['datetime']:%H:%M:%S}", "message": f"[{info['mtype']}] {info['text']}".strip()}
+            user_contexts.append({"type": "text", "text": json.dumps(content, ensure_ascii=False)})
+    contexts.append({"role": "user", "content": user_contexts})
+    logger.trace(contexts)
+    return contexts
src/messages/chat_history.py
@@ -4,17 +4,27 @@
 from pyrogram.client import Client
 from pyrogram.types import Message
 
+from config import cache
 from messages.parser import parse_msg
 
 
-async def get_chat_history(client: Client, message: Message, offset_id: int, num_history: int = 0) -> list[dict]:
-    """Get given number of chat history in parserd json format."""
-    if num_history <= 0:
+@cache.memoize(ttl=10)
+async def get_history_messages(client: Client, message: Message, offset_id: int, num: int = 0) -> list[Message]:
+    """Get given number of chat history from old to new.
+
+    Args:
+        client (Client): The Pyrogram client.
+        message (Message): Any message in the target chat (only chat.id is read).
+        offset_id (int): Fetch messages with ids lower than this value.
+        num (int): Maximum number of messages to fetch; <= 0 yields [].
+
+    Returns:
+        list[Message]: Non-empty messages, oldest first (cached for 10s).
+    """
+    if num <= 0:
         return []
     history = []
-    async for msg in client.get_chat_history(chat_id=message.chat.id, offset_id=offset_id, limit=num_history):  # type: ignore
+    async for msg in client.get_chat_history(chat_id=message.chat.id, offset_id=offset_id, limit=num):  # type: ignore
         if msg.empty:
             continue
-        info = parse_msg(msg, silent=True)
-        history.append(info)
+        history.append(msg)
+    # get_chat_history iterates newest-first; reverse to chronological order
     return history[::-1]
+
+
+@cache.memoize(ttl=10)
+async def get_parsed_chat_history(client: Client, message: Message, offset_id: int, num: int = 0) -> list[dict]:
+    """Get given number of chat history in parsed json format.
+
+    Args:
+        client (Client): The Pyrogram client.
+        message (Message): Any message in the target chat (only chat.id is read).
+        offset_id (int): Fetch messages with ids lower than this value.
+        num (int): Maximum number of messages to fetch; <= 0 yields [].
+
+    Returns:
+        list[dict]: parse_msg() dicts, oldest first.
+    """
+    # num <= 0 guard mirrors get_history_messages so the memoized helper
+    # isn't invoked (and cached) for no-op requests
+    if num <= 0:
+        return []
+    history = await get_history_messages(client, message, offset_id, num)
+    return [parse_msg(msg, silent=True) for msg in history]
src/others/combine_history.py
@@ -7,18 +7,19 @@ from pyrogram.client import Client
 from pyrogram.types import Message
 
 from config import COMBINATION_MAX_HISTORY, ENABLE, PREFIX, READING_SPEED
-from messages.chat_history import get_chat_history
+from messages.chat_history import get_parsed_chat_history
+from messages.parser import parse_msg
 from messages.sender import send2tg
 from messages.utils import equal_prefix, get_reply_to, startswith_prefix
 from utils import to_int
 
 HELP = f"""
-💬**合并对话历史**
+💬**合并对话历史** (最多{COMBINATION_MAX_HISTORY}条)
 使用说明:
-1. `{PREFIX.COMBINATION} + #N` (最多{COMBINATION_MAX_HISTORY}条)
+1. `{PREFIX.COMBINATION} + #N`
 将最近的N条消息文本合并为txt文件
 
-2. `{PREFIX.COMBINATION} + #N + @User` (最多{COMBINATION_MAX_HISTORY}条)
+2. `{PREFIX.COMBINATION} + #N + @User`
 将最近的N条消息中只属于User发送的消息合并为txt文件
 
 如果以 `{PREFIX.COMBINATION} + #N` (或附加User) 回复消息M
@@ -26,9 +27,9 @@ HELP = f"""
 
 示例:
 1. `{PREFIX.COMBINATION} #10`: 合并最近10条消息为txt文本
-2. `{PREFIX.COMBINATION} #20 @123456`: 合并最近20条消息中UID为12345678发送的消息为txt文本
-3. `{PREFIX.COMBINATION} #20 @John`: 合并最近20条消息中用户名为John(大小写均可)发送的消息为txt文本
-如果用户名中有空格, 请去除空格。例如: 想指定用户名为John Doe请使用 `@JohnDoe`
+2. `{PREFIX.COMBINATION} #20 @123456`: 合并最近20条消息中UID为123456发送的消息为txt文本
+3. `{PREFIX.COMBINATION} #20 @John`: 合并最近20条消息中用户John(大小写均可)发送的消息为txt文本
+如果用户名中有空格, 请去除空格。例如: 想指定用户为John Doe请使用 `@JohnDoe`
 """
 
 
@@ -44,12 +45,13 @@ async def combine_history(client: Client, message: Message, **kwargs):
         return
 
     # get the number of messages to combine
+    info = parse_msg(message)
     num_history = 0
-    if matched := re.match(r"^" + PREFIX.COMBINATION + r"\s+#(\d+)\s+@(\w+)", message.text):
+    if matched := re.match(r"^" + PREFIX.COMBINATION + r"\s+#(\d+)\s+@(\w+)", info["text"]):
         num_history = int(matched.group(1))
         filter_user = str(matched.group(2))
         file_name = f"最近{num_history}条消息中{filter_user}的发言.txt"
-    elif matched := re.match(r"^" + PREFIX.COMBINATION + r"\s+#(\d+)", message.text):
+    elif matched := re.match(r"^" + PREFIX.COMBINATION + r"\s+#(\d+)", info["text"]):
         num_history = int(matched.group(1))
         filter_user = ""
         file_name = f"最近{num_history}条消息记录.txt"
@@ -62,7 +64,7 @@ async def combine_history(client: Client, message: Message, **kwargs):
     if message.reply_to_message:
         message = message.reply_to_message
         offset_id = message.id + 1  # include the reply message
-    history = await get_chat_history(client, message, offset_id, num_history)
+    history = await get_parsed_chat_history(client, message, offset_id, num_history)
     if filter_user:
         history = [x for x in history if x["full_name"].replace(" ", "").lower() == filter_user.lower() or str(x["uid"]) == filter_user]
     if not history:
src/config.py
@@ -24,6 +24,7 @@ DAILY_MESSAGES = os.getenv("DAILY_MESSAGES", "{}")  # Useful for daily checkin f
 
 
 class ENABLE:
+    AI_SUMMARY = os.getenv("ENABLE_AI_SUMMARY", "1").lower() in ["1", "y", "yes", "t", "true", "on"]
     ASR = os.getenv("ENABLE_ASR", "1").lower() in ["1", "y", "yes", "t", "true", "on"]
     AUDIO = os.getenv("ENABLE_AUDIO", "1").lower() in ["1", "y", "yes", "t", "true", "on"]
     BILIBILI = os.getenv("ENABLE_BILIBILI", "1").lower() in ["1", "y", "yes", "t", "true", "on"]
@@ -50,6 +51,7 @@ class ENABLE:
 
 class PREFIX:
     MAIN: ClassVar[list[str]] = [x.strip().lower() for x in os.getenv("PREFIX_MAIN", "/benny,/dl,!dl,!dl,!下载,!下载").split(",")]
+    AI_SUMMARY = os.getenv("PREFIX_AI_SUMMARY", "/summary").lower()
     ASR = os.getenv("PREFIX_ASR", "/asr").lower()
     AUDIO = os.getenv("PREFIX_AUDIO", "/audio").lower()
     GPT = os.getenv("PREFIX_GPT", "/ai").lower()
src/handler.py
@@ -12,6 +12,7 @@ from bridge.ocr import send_to_ocr_bridge
 from config import ENABLE, PREFIX, PROXY, cache
 from database import del_db
 from llm.gpt import gpt_response
+from llm.summary import ai_summary
 from messages.parser import parse_msg
 from messages.sender import send2tg
 from messages.utils import equal_prefix, startswith_prefix
@@ -44,6 +45,7 @@ async def handle_utilities(
     subtitle: bool = True,
     wget: bool = True,
     ocr: bool = True,
+    summary: bool = True,
     raw_img: bool = False,
     show_progress: bool = True,
     detail_progress: bool = False,
@@ -65,6 +67,7 @@ async def handle_utilities(
         subtitle (bool, optional): Enable YouTube subtitle. Defaults to True.
         wget (bool, optional): Enable WGET. Defaults to True.
         ocr (bool, optional): Enable OCR. Defaults to True.
+        summary (bool, optional): Enable AI summary. Defaults to True.
         raw_img (bool, optional): Enable convert raw image. Defaults to False.
         show_progress (bool, optional): Show a progress message on Telegram. Defaults to True.
         detail_progress (bool, optional): Show detailed progress (Only if show_proress is set to True). Defaults to False.
@@ -84,6 +87,8 @@ async def handle_utilities(
         await download_url_in_message(client, message, **kwargs)  # /wget
     if ocr:
         await send_to_ocr_bridge(client, message)  # /ocr
+    if summary:
+        await ai_summary(client, message)  # /summary
     if raw_img:
         await convert_raw_img_file(client, message, **kwargs)
 
@@ -275,6 +280,8 @@ def get_social_media_help(cmd_prefix: list[str] | None = None, ignore_prefix: li
         msg += f"\n🔤**图片转文字**: `{PREFIX.OCR}` 回复图片消息"
     if ENABLE.COMBINATION:
         msg += f"\n💬**合并历史**: `{PREFIX.COMBINATION} #N` 合并最近N条对话历史"
+    if ENABLE.AI_SUMMARY:
+        msg += f"\n🤖**总结历史**: `{PREFIX.AI_SUMMARY} #N` 总结最近N条对话历史"
     msg += "\n\n单独发送每个命令前缀本身可查看该命令详细使用说明"
     return msg