Commit 2b1f5f0
Changed files (5)
src
src/ai/texts/contexts.py
@@ -4,12 +4,13 @@ import asyncio
import base64
import contextlib
import hashlib
+import mimetypes
import time
from pathlib import Path
from typing import TYPE_CHECKING, Literal
from anthropic import AsyncAnthropic
-from glom import glom
+from glom import Coalesce, glom
from google import genai
from google.genai.types import FileState, Part, UploadFileConfig
from loguru import logger
@@ -87,7 +88,8 @@ async def single_openai_chat_context(client: Client, message: Message) -> dict:
res = await base64_media(client, msg)
contexts.append({"type": "image_url", "image_url": {"url": f"data:image/{res['ext']};base64,{res['base64']}"}})
elif info["mtype"] == "document":
- if info["mime_type"].startswith("text/") or Path(info["file_name"]).suffix in extra_txt_extensions:
+ guessed_mime, _ = mimetypes.guess_type(info["file_name"])
+ if info["mime_type"].startswith("text/") or str(guessed_mime).startswith("text/") or Path(info["file_name"]).suffix in extra_txt_extensions:
fpath: str = await client.download_media(msg, media_path) # type: ignore
contexts.append(
{
@@ -197,14 +199,15 @@ async def single_openai_response_context(client: Client, message: Message, opena
contexts.append({"type": "input_video", "image_url": f"data:video/{res['ext']};base64,{res['base64']}"})
elif info["mtype"] == "document":
- if info["mime_type"] == "application/pdf":
+ guessed_mime, _ = mimetypes.guess_type(info["file_name"])
+ if info["mime_type"] == "application/pdf" or guessed_mime == "application/pdf":
if media_send_as == "file_id" and (file_id := await get_openai_file_id(client, msg, openai_params, info["mtype"])):
contexts.append({"type": "input_file", "file_id": file_id})
if not file_id:
res = await base64_media(client, msg)
contexts.append({"type": "input_file", "file_data": f"data:application/pdf;base64,{res['base64']}", "filename": info["file_name"]})
- elif info["mime_type"].startswith("text/") or Path(info["file_name"]).suffix in extra_txt_extensions:
+ elif info["mime_type"].startswith("text/") or str(guessed_mime).startswith("text/") or Path(info["file_name"]).suffix in extra_txt_extensions:
fpath: str = await client.download_media(msg, media_path) # type: ignore
contexts.append(
{
@@ -352,7 +355,8 @@ async def get_gemini_contexts(client: Client, message: Message, gemini: genai.Cl
parts.append(Part.from_uri(file_uri=upload.uri, mime_type=upload.mime_type))
Path(fpath).unlink(missing_ok=True)
elif info["mtype"] == "document":
- if info["mime_type"].startswith("text/") or Path(info["file_name"]).suffix in txt_extensions:
+ guessed_mime, _ = mimetypes.guess_type(info["file_name"])
+ if info["mime_type"].startswith("text/") or str(guessed_mime).startswith("text/") or Path(info["file_name"]).suffix in txt_extensions:
fpath: str = await client.download_media(msg, media_path) # type: ignore
parts.append(Part.from_text(text=f"[filename]: {info['file_name']}\n[file content]:\n{read_text(fpath).strip()}"))
if Path(info["file_name"]).suffix in extra_markdown_extensions:
@@ -360,7 +364,7 @@ async def get_gemini_contexts(client: Client, message: Message, gemini: genai.Cl
text = convert_md(fpath)
Path(fpath).unlink(missing_ok=True)
parts.append(Part.from_text(text=f"[filename]: {info['file_name']}\n[file content]:\n{text.strip()}"))
- clean_texts = clean_context(info["text"])
+ clean_texts = clean_context(info["html"] or info["text"])
if not clean_texts:
continue
if role == "user" and sender: # noqa: SIM108
@@ -427,14 +431,15 @@ async def single_anthropic_context(
contexts.append({"type": "image", "source": {"type": "base64", "media_type": f"image/{res['ext']}", "data": res["base64"]}})
elif info["mtype"] == "document":
- if info["mime_type"] == "application/pdf":
+ guessed_mime, _ = mimetypes.guess_type(info["file_name"])
+ if info["mime_type"] == "application/pdf" or guessed_mime == "application/pdf":
if media_send_as == "file_id" and (file_id := await get_anthropic_file_id(client, msg, anthropic, cache_hour)):
contexts.append({"type": "document", "source": {"type": "file", "file_id": file_id}})
if not file_id:
res = await base64_media(client, msg)
contexts.append({"type": "document", "source": {"type": "base64", "media_type": "application/pdf", "data": res["base64"]}})
- elif info["mime_type"].startswith("text/") or Path(info["file_name"]).suffix in extra_txt_extensions:
+ elif info["mime_type"].startswith("text/") or str(guessed_mime).startswith("text/") or Path(info["file_name"]).suffix in extra_txt_extensions:
fpath: str = await client.download_media(msg, media_path) # type: ignore
contexts.append({"type": "text", "text": f"[filename]: {info['file_name']}\n[file content]:\n{read_text(fpath).strip()}"})
@@ -477,3 +482,18 @@ async def get_anthropic_file_id(client: Client, message: Message, anthropic: Asy
except Exception as e:
logger.error(f"Upload media to Anthropic failed: {e}")
return ""
+
+
+async def context_bytes(client: Client, message: Message) -> int:
+ chains = [message]
+ while message.reply_to_message:
+ message = message.reply_to_message
+ chains.append(message)
+ messages: list[Message] = []
+ for msg in chains:
+ groups = await client.get_media_group(msg.chat.id, msg.id) if msg.media_group_id else [msg]
+ messages.extend(groups)
+ size_bytes = 0
+ for m in messages:
+ size_bytes += glom(m, Coalesce("photo.sizes.-1.file_size", "video.file_size", "document.file_size"), default=0)
+ return size_bytes
src/ai/summary.py
@@ -7,23 +7,37 @@ import re
from pathlib import Path
from loguru import logger
+from pyrogram.client import Client
from pyrogram.types import Chat, Message
from pyrogram.types.messages_and_media.message import Str
from ai.main import ai_text_generation
-from config import DB, DOWNLOAD_DIR, PREFIX
+from ai.texts.contexts import context_bytes
+from ai.texts.gemini import gemini_chat_completion
+from ai.texts.models import get_config_by_model_alias
+from ai.texts.openai_response import openai_responses_api
+from ai.utils import deep_merge
+from config import AI, DB, DOWNLOAD_DIR, PREFIX
from database.r2 import set_cf_r2
+from messages.help import social_media_help
+from messages.parser import parse_msg
+from messages.sender import send2tg
+from messages.utils import equal_prefix, set_reaction, startswith_prefix
from networking import download_file
from utils import count_subtitles, rand_number
JSON_SCHEMA = {
- "title": "Article Summary",
- "description": "提炼出文章的核心内容,生成符合指定JSON格式的全文总结、分片内容和思维导图",
+ "title": "Content Summary",
+ "description": "提炼出资料的核心内容,生成符合指定JSON格式的全文总结、分片内容和思维导图",
"type": "object",
"properties": {
- "abstract": {"title": "全文总结", "description": "需涵盖文章核心主题、关键观点和主要结论,用连贯的一段话概括文章的主要内容,避免过于简略。如果内容过长,也可考虑分段总结。", "type": "string"},
+ "abstract": {
+ "title": "全文总结",
+ "description": "需涵盖资料核心主题、关键观点和主要结论,用连贯的一段话概括资料的主要内容,避免过于简略。如果内容过长,也可考虑分段总结。",
+ "type": "string",
+ },
"sections": {
- "description": "将文章划分为不同的片段,每个片段需拟定简洁准确的标题,匹配1个相关emoji,并总结该片段的核心内容",
+ "description": "将资料划分为不同的片段,每个片段需拟定简洁准确的标题,匹配1个相关emoji,并总结该片段的核心内容",
"title": "分片内容",
"type": "array",
"items": {
@@ -34,7 +48,7 @@ JSON_SCHEMA = {
"summary": {"type": "string", "description": "概括该片段的核心内容"},
"start": {
"type": ["string", "null"],
- "description": "如果文章内容为包含时间戳的文字稿(如播客、视频、音频的转录稿),设置此字段为该片段的开始时间, 格式为(HH:MM:SS或MM:SS)。如果没有时间戳,则无需输出此字段。",
+ "description": "如果资料内容为包含时间戳的文字稿(如播客、视频、音频的转录稿),设置此字段为该片段的开始时间, 格式为(HH:MM:SS或MM:SS)。如果没有时间戳,则无需输出此字段。",
},
},
},
@@ -51,6 +65,41 @@ JSON_SCHEMA = {
}
+async def ai_summary(client: Client, message: Message, summary_model_id: str = AI.AI_SUMMARY_MODEL_ALIAS, **kwargs):
+ if not startswith_prefix(message.content, prefix=PREFIX.AI_SUMMARY):
+ return
+ this_msg = message
+ if equal_prefix(message.content, PREFIX.AI_SUMMARY):
+ if not message.reply_to_message:
+ info = parse_msg(message, use_cache=False)
+ await send2tg(client, message, texts=social_media_help(info["cid"], info["ctype"]), **kwargs)
+ return
+ message = message.reply_to_message
+ models = await get_config_by_model_alias(summary_model_id, fallback_to_default=False)
+ models = [x for x in models if x.get("api_type") in ["gemini", "openai_responses"]] # only support gemini & openai_responses models
+ if not models:
+ return
+ mbytes = await context_bytes(client, message)
+ if mbytes > 40 * 1024 * 1024: # prefer gemini for large files (40MB)
+ models = sorted(models, key=lambda x: x.get("api_type") == "gemini", reverse=True)
+ await set_reaction(client, this_msg, "👌")
+ for model_config in models:
+ res = {}
+ params = deep_merge(model_config, kwargs, summary_params())
+ if params["api_type"] == "gemini":
+ res = await gemini_chat_completion(client, message, **params)
+ elif params["api_type"] == "openai_responses":
+ res = await openai_responses_api(client, message, **params)
+ if not res.get("texts"):
+ continue
+ texts, mermaid_path = await parse_summary(res["texts"])
+ media = [{"photo": mermaid_path}] if Path(mermaid_path).is_file() else []
+ await send2tg(client, message, texts=texts, media=media, **kwargs)
+ await set_reaction(client, this_msg, "")
+ return
+
+
+
async def summarize(article: str, reference: str | None = None, model: str = "gemini") -> dict:
if count_subtitles(article) < 200: # skip short article
return {}
@@ -61,65 +110,52 @@ async def summarize(article: str, reference: str | None = None, model: str = "ge
chat=Chat(id=rand_number()),
text=Str(f"{PREFIX.AI_TEXT_GENERATION} @{model} {article.strip()}"),
),
- gemini_generate_content_config={
- "system_instruction": system_prompt(reference),
- "responseMimeType": "application/json",
- "responseJsonSchema": JSON_SCHEMA,
- },
- openai_responses_config={
- "instructions": system_prompt(reference),
- "text": {
- "format": {
- "type": "json_schema",
- "name": "ArticleSummary",
- "strict": True,
- "description": "提炼出文章的核心内容,生成符合指定JSON格式的全文总结、分片内容和思维导图",
- "schema": JSON_SCHEMA,
- }
- },
- },
- gemini_append_grounding=False,
- openai_enable_tool_call=False,
- openai_append_tool_results=False,
- silent=True,
+ **summary_params(reference),
)
if not res.get("texts", ""):
return {}
- res["texts"] = await parse_summary(res["texts"]) or res["texts"]
+ texts, _ = await parse_summary(res["texts"])
+ res["texts"] = texts
return res
-async def parse_summary(texts: str) -> str:
+async def parse_summary(texts: str) -> tuple[str, str]:
+ """Parse the summary JSON string.
+
+ Returns:
+ (summary_texts, mermaid_img_path)
+ """
try:
summary = json.loads(texts)
mermaid = beautify_mermaid(summary["mermaid"])
- mermaid_img = await save_mermaid_jpg_to_r2(mermaid)
+ mermaid_url, mermaid_path = await save_mermaid_jpg_to_r2(mermaid)
parsed = f"{summary['abstract'].strip()}"
- if mermaid_img:
- parsed += f"\n🧠**[思维导图]({mermaid_img})**\n"
+ if mermaid_url:
+ logger.success(f"Mermaid: {mermaid_url}")
+ parsed += f"\n🧠**[思维导图]({mermaid_url})**\n"
parsed += "\n⚡️**章节速览**"
for section in summary["sections"]:
parsed += f"\n{section['emoji']}**{section['title']}**"
if section.get("start"):
parsed += f" [{section['start']}]"
parsed += f"\n{section['summary']}"
+ logger.success(parsed)
except Exception as e:
logger.error(f"Error parsing summary: {e}")
- return ""
- return parsed
+ return texts, ""
+ return parsed, mermaid_path
def system_prompt(reference: str | None = None) -> str:
- prompt = "你是一位专业的文章总结大师,任务是基于用户提供的文本,提炼出文章的核心内容,生成符合指定JSON格式的全文总结、分片内容和思维导图。"
+ prompt = "你是一位专业的内容总结大师,任务是基于用户提供的资料提炼出核心内容,生成符合指定JSON格式的全文总结、分片内容和思维导图。"
if reference:
prompt += f"\n{reference}"
- return prompt.strip()
+ return prompt.strip() + mermaid_syntax()
def beautify_mermaid(mermaid: str) -> str:
def replace(s: str) -> str:
s = s.replace("\n", "<br/>")
- s = s.replace(" ", "<br/>")
s = s.replace('"', """)
s = s.replace("'", "'")
s = s.replace("[", "[")
@@ -137,13 +173,257 @@ def beautify_mermaid(mermaid: str) -> str:
return f"---\nconfig:\n theme: neo\n look: neo\n---\n{mermaid.strip()}"
-async def save_mermaid_jpg_to_r2(mermaid: str) -> str:
+async def save_mermaid_jpg_to_r2(mermaid: str) -> tuple[str, str]:
+ """Save Mermaid image to R2.
+
+ Returns:
+ (image_url, local_path)
+ """
b64_str = base64.urlsafe_b64encode(mermaid.encode("utf-8")).decode("ascii")
save_path = Path(DOWNLOAD_DIR) / f"{hashlib.sha256(mermaid.encode()).hexdigest()}.jpg"
await download_file(f"https://mermaid.ink/img/{b64_str}?type=jpeg&theme=forest&width=2160", path=save_path, suffix=".jpg")
if save_path.is_file():
r2_key = f"TTL/365d/{save_path.name}"
await set_cf_r2(r2_key, data=save_path.read_bytes(), mime_type="image/jpeg", silent=True)
- save_path.unlink(missing_ok=True)
- return f"{DB.CF_R2_PUBLIC_URL}/{r2_key}"
- return ""
+ return f"{DB.CF_R2_PUBLIC_URL}/{r2_key}", save_path.as_posix()
+ return "", ""
+
+
+def summary_params(reference: str | None = None) -> dict:
+ return {
+ "gemini_generate_content_config": {
+ "system_instruction": system_prompt(reference),
+ "responseMimeType": "application/json",
+ "responseJsonSchema": JSON_SCHEMA,
+ },
+ "openai_responses_config": {
+ "instructions": system_prompt(reference),
+ "text": {
+ "format": {
+ "type": "json_schema",
+ "name": "ContentSummary",
+ "strict": True,
+ "description": "提炼出资料的核心内容,生成符合指定JSON格式的全文总结、分片内容和思维导图",
+ "schema": JSON_SCHEMA,
+ }
+ },
+ },
+ "gemini_append_grounding": False,
+ "openai_enable_tool_call": False,
+ "openai_append_tool_results": False,
+ "silent": True,
+ }
+
+
+def mermaid_syntax() -> str:
+ return """
+# Mermaid Flowcharts - Basic Syntax
+
+Flowcharts are composed of **nodes** (geometric shapes) and **edges** (arrows or lines). The Mermaid code defines how nodes and edges are made and accommodates different arrow types, multi-directional arrows, and any linking to and from subgraphs.
+
+### A node (default)
+
+```mermaid
+flowchart LR
+ id
+```
+
+```note
+The id is what is displayed in the box.
+```
+
+### A node with text
+
+It is also possible to set text in the box that differs from the id. If this is done several times, it is the last text
+found for the node that will be used. Also if you define edges for the node later on, you can omit text definitions. The
+one previously defined will be used when rendering the box.
+
+```mermaid
+flowchart LR
+ id1[This is the text in the box]
+```
+
+## Node shapes
+
+### A node with round edges
+
+```mermaid
+flowchart LR
+ id1(This is the text in the box)
+```
+
+### A stadium-shaped node
+
+```mermaid
+flowchart LR
+ id1([This is the text in the box])
+```
+
+### A node in a subroutine shape
+
+```mermaid
+flowchart LR
+ id1[[This is the text in the box]]
+```
+
+### A node in a cylindrical shape
+
+```mermaid
+flowchart LR
+ id1[(Database)]
+```
+
+### A node in the form of a circle
+
+```mermaid
+flowchart LR
+ id1((This is the text in the circle))
+```
+
+### A node in an asymmetric shape
+
+```mermaid
+flowchart LR
+ id1>This is the text in the box]
+```
+
+## Links between nodes
+
+Nodes can be connected with links/edges. It is possible to have different types of links or attach a text string to a link.
+
+### A link with arrow head
+
+```mermaid
+flowchart LR
+ A-->B
+```
+
+### An open link
+
+```mermaid
+flowchart LR
+ A --- B
+```
+
+### Text on links
+
+```mermaid
+flowchart LR
+ A-- This is the text! ---B
+```
+
+or
+
+```mermaid
+flowchart LR
+ A---|This is the text|B
+```
+
+### A link with arrow head and text
+
+```mermaid
+flowchart LR
+ A-->|text|B
+```
+
+or
+
+```mermaid
+flowchart LR
+ A-- text -->B
+```
+
+### Dotted link
+
+```mermaid
+flowchart LR
+ A-.->B;
+```
+
+### Dotted link with text
+
+```mermaid
+flowchart LR
+ A-. text .-> B
+```
+
+### Thick link
+
+```mermaid
+flowchart LR
+ A ==> B
+```
+
+### Thick link with text
+
+```mermaid
+flowchart LR
+ A == text ==> B
+```
+
+### An invisible link
+
+This can be a useful tool in some instances where you want to alter the default positioning of a node.
+
+```mermaid
+flowchart LR
+ A ~~~ B
+```
+
+### Chaining of links
+
+It is possible declare many links in the same line as per below:
+
+```mermaid
+flowchart LR
+ A -- text --> B -- text2 --> C
+```
+
+It is also possible to declare multiple nodes links in the same line as per below:
+
+```mermaid
+flowchart LR
+ a --> b & c--> d
+```
+
+You can then describe dependencies in a very expressive way. Like the one-liner below:
+
+```mermaid
+flowchart TB
+ A & B--> C & D
+```
+
+If you describe the same diagram using the basic syntax, it will take four lines. A
+word of warning, one could go overboard with this making the flowchart harder to read in
+markdown form. The Swedish word `lagom` comes to mind. It means, not too much and not too little.
+This goes for expressive syntaxes as well.
+
+```mermaid
+flowchart TB
+ A --> C
+ A --> D
+ B --> C
+ B --> D
+```
+
+## New arrow types
+
+There are new types of arrows supported:
+
+- circle edge
+- cross edge
+
+### Circle edge example
+
+```mermaid
+flowchart LR
+ A --o B
+```
+
+### Cross edge example
+
+```mermaid
+flowchart LR
+ A --x B
+```
+"""
src/messages/help.py
@@ -4,10 +4,10 @@ from config import PREFIX
from permission import check_service
-def social_media_help(chat_id: int | str, ctype: str, prefix: str):
+def social_media_help(chat_id: int | str, ctype: str):
"""Get the help message for social media preview."""
permission = check_service(cid=chat_id, ctype=ctype)
- msg = f"🔗**链接解析**: {prefix}\n🔄使用 `/retry` 回复消息强制重试"
+ msg = f"🔗**链接解析**: {PREFIX.SOCIAL_MEDIA}\n🔄使用 `/retry` 回复消息强制重试"
if permission["twitter"]:
msg += "\n🕊推特"
if permission["weibo"]:
src/messages/main.py
@@ -8,6 +8,7 @@ from pyrogram.types import Message
from ai.chat_summary import ai_chat_summary
from ai.main import ai_image_generation, ai_text_generation, ai_video_generation
+from ai.summary import ai_summary
from asr.voice_recognition import voice_to_text
from bridge.ocr import send_to_ocr_bridge
from config import FAVORITE, PREFIX, PROXY
@@ -104,6 +105,7 @@ async def process_message(
await ai_text_generation(client, message, **kwargs) # /ai
await ai_image_generation(client, message, **kwargs) # /gen
await ai_video_generation(client, message, **kwargs) # /gvid
+ await ai_summary(client, message, **kwargs) # /summary
if asr:
await voice_to_text(client, message, **kwargs) # /asr
if audio_extract:
@@ -123,7 +125,7 @@ async def process_message(
if history:
await query_chat_history(client, message, **kwargs) # /history
if summary:
- await ai_chat_summary(client, message, **kwargs) # /summary
+ await ai_chat_summary(client, message, **kwargs) # /chatsum
if danmu:
await query_danmu(client, message, **kwargs) # /danmu
if favorite:
@@ -233,7 +235,7 @@ async def preview_social_media(
# without reply, send docs if message only contains prefix command
if not message.reply_to_message:
await delete_message(message)
- docs = social_media_help(info["cid"], info["ctype"], PREFIX.SOCIAL_MEDIA)
+ docs = social_media_help(info["cid"], info["ctype"])
helps = await send2tg(client, message, texts=docs, **kwargs)
await asyncio.sleep(30)
return await delete_message(helps)
src/config.py
@@ -405,6 +405,7 @@ class AI:
OPENAI_API_KEYS = os.getenv("AI_OPENAI_API_KEYS", "") # comma separated keys for load balance. e.g. "key1,key2,key3"
OPENAI_BASE_URL = os.getenv("AI_OPENAI_BASE_URL", "https://api.openai.com/v1")
TOOL_CALL_MODEL_ALIAS = os.getenv("AI_TOOL_CALL_MODEL_ALIAS", "tool-call")
+ AI_SUMMARY_MODEL_ALIAS = os.getenv("AI_SUMMARY_MODEL_ALIAS", "gemini")
PODCAST_SUMMARY_MODEL_ALIAS = os.getenv("PODCAST_SUMMARY_MODEL_ALIAS", "podcast-summary")
SUBTITLE_SUMMARY_MODEL_ALIAS = os.getenv("SUBTITLE_SUMMARY_MODEL_ALIAS", "subtitle-summary")
CHAT_SUMMARY_MODEL_ALIAS = os.getenv("CHAT_SUMMARY_MODEL_ALIAS", "chat-summary")