Commit 2b1f5f0

benny-dou <60535774+benny-dou@users.noreply.github.com>
2026-05-18 06:35:48
feat(summary): add ai summary feature
1 parent 1c3bd7b
Changed files (5)
src/ai/texts/contexts.py
@@ -4,12 +4,13 @@ import asyncio
 import base64
 import contextlib
 import hashlib
+import mimetypes
 import time
 from pathlib import Path
 from typing import TYPE_CHECKING, Literal
 
 from anthropic import AsyncAnthropic
-from glom import glom
+from glom import Coalesce, glom
 from google import genai
 from google.genai.types import FileState, Part, UploadFileConfig
 from loguru import logger
@@ -87,7 +88,8 @@ async def single_openai_chat_context(client: Client, message: Message) -> dict:
                 res = await base64_media(client, msg)
                 contexts.append({"type": "image_url", "image_url": {"url": f"data:image/{res['ext']};base64,{res['base64']}"}})
             elif info["mtype"] == "document":
-                if info["mime_type"].startswith("text/") or Path(info["file_name"]).suffix in extra_txt_extensions:
+                guessed_mime, _ = mimetypes.guess_type(info["file_name"])
+                if info["mime_type"].startswith("text/") or str(guessed_mime).startswith("text/") or Path(info["file_name"]).suffix in extra_txt_extensions:
                     fpath: str = await client.download_media(msg, media_path)  # type: ignore
                     contexts.append(
                         {
@@ -197,14 +199,15 @@ async def single_openai_response_context(client: Client, message: Message, opena
                     contexts.append({"type": "input_video", "image_url": f"data:video/{res['ext']};base64,{res['base64']}"})
 
             elif info["mtype"] == "document":
-                if info["mime_type"] == "application/pdf":
+                guessed_mime, _ = mimetypes.guess_type(info["file_name"])
+                if info["mime_type"] == "application/pdf" or guessed_mime == "application/pdf":
                     if media_send_as == "file_id" and (file_id := await get_openai_file_id(client, msg, openai_params, info["mtype"])):
                         contexts.append({"type": "input_file", "file_id": file_id})
                     if not file_id:
                         res = await base64_media(client, msg)
                         contexts.append({"type": "input_file", "file_data": f"data:application/pdf;base64,{res['base64']}", "filename": info["file_name"]})
 
-                elif info["mime_type"].startswith("text/") or Path(info["file_name"]).suffix in extra_txt_extensions:
+                elif info["mime_type"].startswith("text/") or str(guessed_mime).startswith("text/") or Path(info["file_name"]).suffix in extra_txt_extensions:
                     fpath: str = await client.download_media(msg, media_path)  # type: ignore
                     contexts.append(
                         {
@@ -352,7 +355,8 @@ async def get_gemini_contexts(client: Client, message: Message, gemini: genai.Cl
                         parts.append(Part.from_uri(file_uri=upload.uri, mime_type=upload.mime_type))
                     Path(fpath).unlink(missing_ok=True)
                 elif info["mtype"] == "document":
-                    if info["mime_type"].startswith("text/") or Path(info["file_name"]).suffix in txt_extensions:
+                    guessed_mime, _ = mimetypes.guess_type(info["file_name"])
+                    if info["mime_type"].startswith("text/") or str(guessed_mime).startswith("text/") or Path(info["file_name"]).suffix in txt_extensions:
                         fpath: str = await client.download_media(msg, media_path)  # type: ignore
                         parts.append(Part.from_text(text=f"[filename]: {info['file_name']}\n[file content]:\n{read_text(fpath).strip()}"))
                     if Path(info["file_name"]).suffix in extra_markdown_extensions:
@@ -360,7 +364,7 @@ async def get_gemini_contexts(client: Client, message: Message, gemini: genai.Cl
                         text = convert_md(fpath)
                         Path(fpath).unlink(missing_ok=True)
                         parts.append(Part.from_text(text=f"[filename]: {info['file_name']}\n[file content]:\n{text.strip()}"))
-                clean_texts = clean_context(info["text"])
+                clean_texts = clean_context(info["html"] or info["text"])
                 if not clean_texts:
                     continue
                 if role == "user" and sender:  # noqa: SIM108
@@ -427,14 +431,15 @@ async def single_anthropic_context(
                     contexts.append({"type": "image", "source": {"type": "base64", "media_type": f"image/{res['ext']}", "data": res["base64"]}})
 
             elif info["mtype"] == "document":
-                if info["mime_type"] == "application/pdf":
+                guessed_mime, _ = mimetypes.guess_type(info["file_name"])
+                if info["mime_type"] == "application/pdf" or guessed_mime == "application/pdf":
                     if media_send_as == "file_id" and (file_id := await get_anthropic_file_id(client, msg, anthropic, cache_hour)):
                         contexts.append({"type": "document", "source": {"type": "file", "file_id": file_id}})
                     if not file_id:
                         res = await base64_media(client, msg)
                         contexts.append({"type": "document", "source": {"type": "base64", "media_type": "application/pdf", "data": res["base64"]}})
 
-                elif info["mime_type"].startswith("text/") or Path(info["file_name"]).suffix in extra_txt_extensions:
+                elif info["mime_type"].startswith("text/") or str(guessed_mime).startswith("text/") or Path(info["file_name"]).suffix in extra_txt_extensions:
                     fpath: str = await client.download_media(msg, media_path)  # type: ignore
                     contexts.append({"type": "text", "text": f"[filename]: {info['file_name']}\n[file content]:\n{read_text(fpath).strip()}"})
 
@@ -477,3 +482,18 @@ async def get_anthropic_file_id(client: Client, message: Message, anthropic: Asy
     except Exception as e:
         logger.error(f"Upload media to Anthropic failed: {e}")
     return ""
+
+
+async def context_bytes(client: Client, message: Message) -> int:
+    """Sum the media sizes attached to a message and its reply chain.
+
+    Follows `reply_to_message` links starting at `message`, expands every
+    message that carries a `media_group_id` into its whole media group, and
+    adds up the size of the photo (last entry of `photo.sizes`), video or
+    document found on each resulting message.
+
+    Returns:
+        Total size in bytes; a message without a resolvable size counts as 0.
+    """
+    chains = [message]
+    while message.reply_to_message:
+        message = message.reply_to_message
+        chains.append(message)
+    messages: list[Message] = []
+    for msg in chains:
+        groups = await client.get_media_group(msg.chat.id, msg.id) if msg.media_group_id else [msg]
+        messages.extend(groups)
+    size_spec = Coalesce("photo.sizes.-1.file_size", "video.file_size", "document.file_size")
+    # Coalesce yields the first path that resolves, even when the value is None;
+    # `or 0` keeps the total an int instead of raising TypeError on `int + None`.
+    return sum(glom(m, size_spec, default=0) or 0 for m in messages)
src/ai/summary.py
@@ -7,23 +7,37 @@ import re
 from pathlib import Path
 
 from loguru import logger
+from pyrogram.client import Client
 from pyrogram.types import Chat, Message
 from pyrogram.types.messages_and_media.message import Str
 
 from ai.main import ai_text_generation
-from config import DB, DOWNLOAD_DIR, PREFIX
+from ai.texts.contexts import context_bytes
+from ai.texts.gemini import gemini_chat_completion
+from ai.texts.models import get_config_by_model_alias
+from ai.texts.openai_response import openai_responses_api
+from ai.utils import deep_merge
+from config import AI, DB, DOWNLOAD_DIR, PREFIX
 from database.r2 import set_cf_r2
+from messages.help import social_media_help
+from messages.parser import parse_msg
+from messages.sender import send2tg
+from messages.utils import equal_prefix, set_reaction, startswith_prefix
 from networking import download_file
 from utils import count_subtitles, rand_number
 
 JSON_SCHEMA = {
-    "title": "Article Summary",
-    "description": "提炼出文章的核心内容,生成符合指定JSON格式的全文总结、分片内容和思维导图",
+    "title": "Content Summary",
+    "description": "提炼出资料的核心内容,生成符合指定JSON格式的全文总结、分片内容和思维导图",
     "type": "object",
     "properties": {
-        "abstract": {"title": "全文总结", "description": "需涵盖文章核心主题、关键观点和主要结论,用连贯的一段话概括文章的主要内容,避免过于简略。如果内容过长,也可考虑分段总结。", "type": "string"},
+        "abstract": {
+            "title": "全文总结",
+            "description": "需涵盖资料核心主题、关键观点和主要结论,用连贯的一段话概括资料的主要内容,避免过于简略。如果内容过长,也可考虑分段总结。",
+            "type": "string",
+        },
         "sections": {
-            "description": "将文章划分为不同的片段,每个片段需拟定简洁准确的标题,匹配1个相关emoji,并总结该片段的核心内容",
+            "description": "将资料划分为不同的片段,每个片段需拟定简洁准确的标题,匹配1个相关emoji,并总结该片段的核心内容",
             "title": "分片内容",
             "type": "array",
             "items": {
@@ -34,7 +48,7 @@ JSON_SCHEMA = {
                     "summary": {"type": "string", "description": "概括该片段的核心内容"},
                     "start": {
                         "type": ["string", "null"],
-                        "description": "如果文章内容为包含时间戳的文字稿(如播客、视频、音频的转录稿),设置此字段为该片段的开始时间, 格式为(HH:MM:SS或MM:SS)。如果没有时间戳,则无需输出此字段。",
+                        "description": "如果资料内容为包含时间戳的文字稿(如播客、视频、音频的转录稿),设置此字段为该片段的开始时间, 格式为(HH:MM:SS或MM:SS)。如果没有时间戳,则无需输出此字段。",
                     },
                 },
             },
@@ -51,6 +65,41 @@ JSON_SCHEMA = {
 }
 
 
+async def ai_summary(client: Client, message: Message, summary_model_id: str = AI.AI_SUMMARY_MODEL_ALIAS, **kwargs):
+    """Summarize the content of a message with the configured summary models.
+
+    Only messages whose content starts with `PREFIX.AI_SUMMARY` are handled.
+    When the content is the bare prefix, the replied-to message is summarized
+    instead; without a reply the text from `social_media_help` is sent and the
+    function returns. Models resolved from `summary_model_id` are tried in
+    order until one returns text.
+
+    Args:
+        client: Client used for reactions, media-group lookups and sending.
+        message: The incoming command message.
+        summary_model_id: Model alias passed to `get_config_by_model_alias`.
+        **kwargs: Merged into the model parameters and forwarded to `send2tg`.
+    """
+    if not startswith_prefix(message.content, prefix=PREFIX.AI_SUMMARY):
+        return
+    this_msg = message
+    if equal_prefix(message.content, PREFIX.AI_SUMMARY):
+        if not message.reply_to_message:
+            info = parse_msg(message, use_cache=False)
+            await send2tg(client, message, texts=social_media_help(info["cid"], info["ctype"]), **kwargs)
+            return
+        message = message.reply_to_message
+    models = await get_config_by_model_alias(summary_model_id, fallback_to_default=False)
+    models = [x for x in models if x.get("api_type") in ["gemini", "openai_responses"]]  # only support gemini & openai_responses models
+    if not models:
+        return
+    mbytes = await context_bytes(client, message)
+    if mbytes > 40 * 1024 * 1024:  # prefer gemini for large files (40MB)
+        models = sorted(models, key=lambda x: x.get("api_type") == "gemini", reverse=True)
+    await set_reaction(client, this_msg, "👌")
+    try:
+        for model_config in models:
+            res = {}
+            params = deep_merge(model_config, kwargs, summary_params())
+            if params["api_type"] == "gemini":
+                res = await gemini_chat_completion(client, message, **params)
+            elif params["api_type"] == "openai_responses":
+                res = await openai_responses_api(client, message, **params)
+            if not res.get("texts"):
+                continue  # this model produced nothing; try the next one
+            texts, mermaid_path = await parse_summary(res["texts"])
+            media = [{"photo": mermaid_path}] if Path(mermaid_path).is_file() else []
+            try:
+                await send2tg(client, message, texts=texts, media=media, **kwargs)
+            finally:
+                # The rendered image is only needed for this send; remove the local copy
+                # so it does not pile up in the download directory.
+                if media:
+                    Path(mermaid_path).unlink(missing_ok=True)
+            return
+    finally:
+        # Clear the progress reaction on every exit path, including the case where
+        # no model returned text or a call raised.
+        await set_reaction(client, this_msg, "")
+
+
+
 async def summarize(article: str, reference: str | None = None, model: str = "gemini") -> dict:
     if count_subtitles(article) < 200:  # skip short article
         return {}
@@ -61,65 +110,52 @@ async def summarize(article: str, reference: str | None = None, model: str = "ge
             chat=Chat(id=rand_number()),
             text=Str(f"{PREFIX.AI_TEXT_GENERATION} @{model} {article.strip()}"),
         ),
-        gemini_generate_content_config={
-            "system_instruction": system_prompt(reference),
-            "responseMimeType": "application/json",
-            "responseJsonSchema": JSON_SCHEMA,
-        },
-        openai_responses_config={
-            "instructions": system_prompt(reference),
-            "text": {
-                "format": {
-                    "type": "json_schema",
-                    "name": "ArticleSummary",
-                    "strict": True,
-                    "description": "提炼出文章的核心内容,生成符合指定JSON格式的全文总结、分片内容和思维导图",
-                    "schema": JSON_SCHEMA,
-                }
-            },
-        },
-        gemini_append_grounding=False,
-        openai_enable_tool_call=False,
-        openai_append_tool_results=False,
-        silent=True,
+        **summary_params(reference),
     )
     if not res.get("texts", ""):
         return {}
-    res["texts"] = await parse_summary(res["texts"]) or res["texts"]
+    texts, _ = await parse_summary(res["texts"])
+    res["texts"] = texts
     return res
 
 
-async def parse_summary(texts: str) -> str:
+async def parse_summary(texts: str) -> tuple[str, str]:
+    """Parse the summary JSON string into formatted summary text.
+
+    Reads the `abstract`, `mermaid` and `sections` keys of the decoded JSON,
+    renders the Mermaid diagram through `save_mermaid_jpg_to_r2` and builds the
+    text from the abstract, the diagram link (when a URL was returned) and one
+    entry per section. Any exception is logged and the input is handed back
+    unchanged together with an empty path.
+
+    Returns:
+        (summary_texts, mermaid_img_path)
+    """
     try:
         summary = json.loads(texts)
         mermaid = beautify_mermaid(summary["mermaid"])
-        mermaid_img = await save_mermaid_jpg_to_r2(mermaid)
+        mermaid_url, mermaid_path = await save_mermaid_jpg_to_r2(mermaid)
         parsed = f"{summary['abstract'].strip()}"
-        if mermaid_img:
-            parsed += f"\n🧠**[思维导图]({mermaid_img})**\n![Mermaid]({mermaid_img})"
+        if mermaid_url:
+            logger.success(f"Mermaid: {mermaid_url}")
+            parsed += f"\n🧠**[思维导图]({mermaid_url})**\n![Mermaid]({mermaid_url})"
         parsed += "\n⚡️**章节速览**"
         for section in summary["sections"]:
             parsed += f"\n{section['emoji']}**{section['title']}**"
             if section.get("start"):
                 parsed += f" [{section['start']}]"
             parsed += f"\n{section['summary']}"
+        logger.success(parsed)
     except Exception as e:
         logger.error(f"Error parsing summary: {e}")
-        return ""
-    return parsed
+        return texts, ""
+    return parsed, mermaid_path
 
 
 def system_prompt(reference: str | None = None) -> str:
+    """Build the system prompt: base instruction, optional `reference` on a new line, then the text from `mermaid_syntax()`."""
-    prompt = "你是一位专业的文章总结大师,任务是基于用户提供的文本,提炼出文章的核心内容,生成符合指定JSON格式的全文总结、分片内容和思维导图。"
+    prompt = "你是一位专业的内容总结大师,任务是基于用户提供的资料提炼出核心内容,生成符合指定JSON格式的全文总结、分片内容和思维导图。"
     if reference:
         prompt += f"\n{reference}"
-    return prompt.strip()
+    return prompt.strip() + mermaid_syntax()
 
 
 def beautify_mermaid(mermaid: str) -> str:
     def replace(s: str) -> str:
         s = s.replace("\n", "<br/>")
-        s = s.replace(" ", "<br/>")
         s = s.replace('"', "&quot;")
         s = s.replace("'", "&apos;")
         s = s.replace("[", "&lsqb;")
@@ -137,13 +173,257 @@ def beautify_mermaid(mermaid: str) -> str:
     return f"---\nconfig:\n  theme: neo\n  look: neo\n---\n{mermaid.strip()}"
 
 
-async def save_mermaid_jpg_to_r2(mermaid: str) -> str:
+async def save_mermaid_jpg_to_r2(mermaid: str) -> tuple[str, str]:
+    """Render a Mermaid diagram to JPEG via mermaid.ink and save it to R2.
+
+    The image is downloaded into `DOWNLOAD_DIR` under a name derived from the
+    SHA-256 of the diagram source and uploaded under the `TTL/365d/` key
+    prefix. The local file is not removed here; its path is returned.
+
+    Returns:
+        (image_url, local_path), or ("", "") when no file was downloaded.
+    """
     b64_str = base64.urlsafe_b64encode(mermaid.encode("utf-8")).decode("ascii")
     save_path = Path(DOWNLOAD_DIR) / f"{hashlib.sha256(mermaid.encode()).hexdigest()}.jpg"
     await download_file(f"https://mermaid.ink/img/{b64_str}?type=jpeg&theme=forest&width=2160", path=save_path, suffix=".jpg")
     if save_path.is_file():
         r2_key = f"TTL/365d/{save_path.name}"
         await set_cf_r2(r2_key, data=save_path.read_bytes(), mime_type="image/jpeg", silent=True)
-        save_path.unlink(missing_ok=True)
-        return f"{DB.CF_R2_PUBLIC_URL}/{r2_key}"
-    return ""
+        return f"{DB.CF_R2_PUBLIC_URL}/{r2_key}", save_path.as_posix()
+    return "", ""
+
+
+def summary_params(reference: str | None = None) -> dict:
+    """Build the keyword arguments shared by summary requests.
+
+    The prompt from `system_prompt(reference)` and `JSON_SCHEMA` are supplied to
+    both the Gemini and the OpenAI Responses configuration.
+
+    Returns:
+        A dict with both provider configs, the grounding / tool-call /
+        tool-result flags set to False and `silent` set to True.
+    """
+    instructions = system_prompt(reference)
+    gemini_config = {
+        "system_instruction": instructions,
+        "responseMimeType": "application/json",
+        "responseJsonSchema": JSON_SCHEMA,
+    }
+    json_format = {
+        "type": "json_schema",
+        "name": "ContentSummary",
+        "strict": True,
+        "description": "提炼出资料的核心内容,生成符合指定JSON格式的全文总结、分片内容和思维导图",
+        "schema": JSON_SCHEMA,
+    }
+    return {
+        "gemini_generate_content_config": gemini_config,
+        "openai_responses_config": {"instructions": instructions, "text": {"format": json_format}},
+        "gemini_append_grounding": False,
+        "openai_enable_tool_call": False,
+        "openai_append_tool_results": False,
+        "silent": True,
+    }
+
+
+def mermaid_syntax() -> str:
+    """Return the Mermaid flowchart syntax reference text that `system_prompt` appends to the prompt."""
+    return """
+# Mermaid Flowcharts - Basic Syntax
+
+Flowcharts are composed of **nodes** (geometric shapes) and **edges** (arrows or lines). The Mermaid code defines how nodes and edges are made and accommodates different arrow types, multi-directional arrows, and any linking to and from subgraphs.
+
+### A node (default)
+
+```mermaid
+flowchart LR
+    id
+```
+
+```note
+The id is what is displayed in the box.
+```
+
+### A node with text
+
+It is also possible to set text in the box that differs from the id. If this is done several times, it is the last text
+found for the node that will be used. Also if you define edges for the node later on, you can omit text definitions. The
+one previously defined will be used when rendering the box.
+
+```mermaid
+flowchart LR
+    id1[This is the text in the box]
+```
+
+## Node shapes
+
+### A node with round edges
+
+```mermaid
+flowchart LR
+    id1(This is the text in the box)
+```
+
+### A stadium-shaped node
+
+```mermaid
+flowchart LR
+    id1([This is the text in the box])
+```
+
+### A node in a subroutine shape
+
+```mermaid
+flowchart LR
+    id1[[This is the text in the box]]
+```
+
+### A node in a cylindrical shape
+
+```mermaid
+flowchart LR
+    id1[(Database)]
+```
+
+### A node in the form of a circle
+
+```mermaid
+flowchart LR
+    id1((This is the text in the circle))
+```
+
+### A node in an asymmetric shape
+
+```mermaid
+flowchart LR
+    id1>This is the text in the box]
+```
+
+## Links between nodes
+
+Nodes can be connected with links/edges. It is possible to have different types of links or attach a text string to a link.
+
+### A link with arrow head
+
+```mermaid
+flowchart LR
+    A-->B
+```
+
+### An open link
+
+```mermaid
+flowchart LR
+    A --- B
+```
+
+### Text on links
+
+```mermaid
+flowchart LR
+    A-- This is the text! ---B
+```
+
+or
+
+```mermaid
+flowchart LR
+    A---|This is the text|B
+```
+
+### A link with arrow head and text
+
+```mermaid
+flowchart LR
+    A-->|text|B
+```
+
+or
+
+```mermaid
+flowchart LR
+    A-- text -->B
+```
+
+### Dotted link
+
+```mermaid
+flowchart LR
+   A-.->B;
+```
+
+### Dotted link with text
+
+```mermaid
+flowchart LR
+   A-. text .-> B
+```
+
+### Thick link
+
+```mermaid
+flowchart LR
+   A ==> B
+```
+
+### Thick link with text
+
+```mermaid
+flowchart LR
+   A == text ==> B
+```
+
+### An invisible link
+
+This can be a useful tool in some instances where you want to alter the default positioning of a node.
+
+```mermaid
+flowchart LR
+    A ~~~ B
+```
+
+### Chaining of links
+
+It is possible declare many links in the same line as per below:
+
+```mermaid
+flowchart LR
+   A -- text --> B -- text2 --> C
+```
+
+It is also possible to declare multiple nodes links in the same line as per below:
+
+```mermaid
+flowchart LR
+   a --> b & c--> d
+```
+
+You can then describe dependencies in a very expressive way. Like the one-liner below:
+
+```mermaid
+flowchart TB
+    A & B--> C & D
+```
+
+If you describe the same diagram using the basic syntax, it will take four lines. A
+word of warning, one could go overboard with this making the flowchart harder to read in
+markdown form. The Swedish word `lagom` comes to mind. It means, not too much and not too little.
+This goes for expressive syntaxes as well.
+
+```mermaid
+flowchart TB
+    A --> C
+    A --> D
+    B --> C
+    B --> D
+```
+
+## New arrow types
+
+There are new types of arrows supported:
+
+- circle edge
+- cross edge
+
+### Circle edge example
+
+```mermaid
+flowchart LR
+    A --o B
+```
+
+### Cross edge example
+
+```mermaid
+flowchart LR
+    A --x B
+```
+"""
src/messages/help.py
@@ -4,10 +4,10 @@ from config import PREFIX
 from permission import check_service
 
 
-def social_media_help(chat_id: int | str, ctype: str, prefix: str):
+def social_media_help(chat_id: int | str, ctype: str):
     """Get the help message for social media preview."""
     permission = check_service(cid=chat_id, ctype=ctype)
-    msg = f"🔗**链接解析**: {prefix}\n🔄使用 `/retry` 回复消息强制重试"
+    msg = f"🔗**链接解析**: {PREFIX.SOCIAL_MEDIA}\n🔄使用 `/retry` 回复消息强制重试"
     if permission["twitter"]:
         msg += "\n🕊推特"
     if permission["weibo"]:
src/messages/main.py
@@ -8,6 +8,7 @@ from pyrogram.types import Message
 
 from ai.chat_summary import ai_chat_summary
 from ai.main import ai_image_generation, ai_text_generation, ai_video_generation
+from ai.summary import ai_summary
 from asr.voice_recognition import voice_to_text
 from bridge.ocr import send_to_ocr_bridge
 from config import FAVORITE, PREFIX, PROXY
@@ -104,6 +105,7 @@ async def process_message(
         await ai_text_generation(client, message, **kwargs)  # /ai
         await ai_image_generation(client, message, **kwargs)  # /gen
         await ai_video_generation(client, message, **kwargs)  # /gvid
+        await ai_summary(client, message, **kwargs)  # /summary
     if asr:
         await voice_to_text(client, message, **kwargs)  # /asr
     if audio_extract:
@@ -123,7 +125,7 @@ async def process_message(
     if history:
         await query_chat_history(client, message, **kwargs)  # /history
     if summary:
-        await ai_chat_summary(client, message, **kwargs)  # /summary
+        await ai_chat_summary(client, message, **kwargs)  # /chatsum
     if danmu:
         await query_danmu(client, message, **kwargs)  # /danmu
     if favorite:
@@ -233,7 +235,7 @@ async def preview_social_media(
         # without reply, send docs if message only contains prefix command
         if not message.reply_to_message:
             await delete_message(message)
-            docs = social_media_help(info["cid"], info["ctype"], PREFIX.SOCIAL_MEDIA)
+            docs = social_media_help(info["cid"], info["ctype"])
             helps = await send2tg(client, message, texts=docs, **kwargs)
             await asyncio.sleep(30)
             return await delete_message(helps)
src/config.py
@@ -405,6 +405,7 @@ class AI:
     OPENAI_API_KEYS = os.getenv("AI_OPENAI_API_KEYS", "")  # comma separated keys for load balance. e.g. "key1,key2,key3"
     OPENAI_BASE_URL = os.getenv("AI_OPENAI_BASE_URL", "https://api.openai.com/v1")
     TOOL_CALL_MODEL_ALIAS = os.getenv("AI_TOOL_CALL_MODEL_ALIAS", "tool-call")
+    AI_SUMMARY_MODEL_ALIAS = os.getenv("AI_SUMMARY_MODEL_ALIAS", "gemini")
     PODCAST_SUMMARY_MODEL_ALIAS = os.getenv("PODCAST_SUMMARY_MODEL_ALIAS", "podcast-summary")
     SUBTITLE_SUMMARY_MODEL_ALIAS = os.getenv("SUBTITLE_SUMMARY_MODEL_ALIAS", "subtitle-summary")
     CHAT_SUMMARY_MODEL_ALIAS = os.getenv("CHAT_SUMMARY_MODEL_ALIAS", "chat-summary")