Commit f97976a
Changed files (37)
src
ai
danmu
podcast
subtitles
summarize
src/ai/texts/contexts.py
@@ -9,27 +9,29 @@ import mimetypes
from pathlib import Path
from typing import TYPE_CHECKING, Literal
+import anyio
from anthropic import AsyncAnthropic
from glom import Coalesce, glom
from google import genai
-from google.genai.types import FileState, HttpOptions, Part, UploadFileConfig
+from google.genai.types import FileData, FileState, HttpOptions, Part, UploadFileConfig
from loguru import logger
from openai import AsyncOpenAI, DefaultAsyncHttpxClient
from pyrogram.client import Client
-from pyrogram.types import Message
+from pyrogram.types import Chat, Document, Message
from ai.utils import BOT_TIPS, clean_context
from asr.utils import GEMINI_AUDIO_EXT, downsampe_audio
from config import AI, DOWNLOAD_DIR, PROXY, TID
from database.r2 import head_cf_r2, set_cf_r2
from messages.parser import parse_msg
-from utils import convert2md, read_text
+from others.download_external import AUDIO_FORMAT, VIDEO_FORMAT
+from utils import convert2md, digest, guess_mime, read_text
if TYPE_CHECKING:
from io import BytesIO
TXT_EXT = [".sh", ".json", ".xml", ".tex"] # treat these as txt file
-MARKDOWN_EXT = [".pdf", ".html", ".docx", ".pptx", ".xls", ".xlsx", ".epub"] # convert to markdown
+MARKDOWN_EXT = [".html", ".docx", ".pptx", ".xls", ".xlsx", ".epub"] # convert to markdown
# gemini has built-in support for these extensions
GEMINI_EXT = [".pdf", ".html", ".css", ".csv", ".xml", ".rtf", ".mp3", ".wav", ".ogg", ".aac", ".flac", ".jpg", ".jpeg", ".webp", ".png", ".heic", ".heif"]
# gemini has built-in support for these mime types
@@ -56,13 +58,14 @@ async def base64_media(client: Client, message: Message) -> dict:
}
-async def get_openai_completion_contexts(client: Client, message: Message, *, add_sender: bool | None = None) -> list[dict]:
+async def get_openai_completion_contexts(client: Client, message: Message, params: dict) -> list[dict]:
"""Generate OpenAI chat completion contexts."""
chains = await full_chain_contexts(client, message, order="asc") # old to new
+ add_sender = params.get("add_sender")
if add_sender is None:
add_sender = is_multi_user_chat(chains)
messages = chains[-int(AI.MAX_CONTEXTS_NUM) :]
- contexts = []
+ contexts: list = []
for msg in messages:
info = parse_msg(msg, silent=True)
@@ -114,9 +117,59 @@ async def get_openai_completion_contexts(client: Client, message: Message, *, ad
logger.warning(f"Download media from message failed: {e}")
contexts.append(context)
+ additional_contexts = await parse_openai_chat_additional_contexts(params)
+ if contexts and contexts[-1]["role"] == "user":
+ contexts[-1]["content"].extend(additional_contexts)
+ else:
+ contexts.extend([{"role": "user", "content": additional_contexts}])
+
return [ctx for ctx in contexts if ctx.get("content")]
+async def parse_openai_chat_additional_contexts(params: dict) -> list[dict]:
+    """Build chat-completion content parts from ``params["additional_contexts"]``.
+
+    Each item of ``additional_contexts`` is a dict such as:
+        {"type": "image", "path": "/path/to/image.jpg", "mime_type": "image/jpeg"}
+        {"type": "video", "path": "/path/to/video.mp4", "mime_type": "video/mp4"}
+        {"type": "audio", "path": "/path/to/audio.mp3", "mime_type": "audio/mpeg"}
+        {"type": "file", "path": "/path/to/file.pdf", "mime_type": "application/pdf"}
+
+    An item is used only when the matching ``allow_<type>`` flag in ``params`` is
+    truthy and ``path`` points to an existing file. Files are converted to markdown
+    and sent as text parts; image/video/audio are sent as base64 data URIs.
+
+    Returns:
+        The list of content parts (empty when nothing is provided or allowed).
+    """
+    if not params.get("additional_contexts"):
+        return []
+
+    # For OpenAI, OpenRouter and Volcengine, please use the Responses API instead.
+    # Currently, this is for xiaomi mimo only.
+    perms = {
+        "image": bool(params.get("allow_image")),
+        "video": bool(params.get("allow_video")),
+        "audio": bool(params.get("allow_audio")),
+        "file": bool(params.get("allow_file")),
+    }
+    # Content-part type per media kind. "file" has no entry on purpose:
+    # files are only ever sent as markdown text.
+    types = {
+        "image": "image_url",
+        "video": "video_url",
+        "audio": "input_audio",
+    }
+
+    contexts = []
+
+    for item in params["additional_contexts"]:
+        item_type = item.get("type")
+        path = Path(item.get("path", ""))
+        if not (perms.get(item_type) and path.is_file()):
+            continue
+
+        if item_type == "file":
+            if md := convert2md(path=path):
+                contexts.append({"type": "text", "text": md})
+            else:
+                # There is no binary "file" content part here, so skip instead of crashing.
+                logger.warning(f"Convert additional file to markdown failed, skipped: {path}")
+            continue
+
+        mime = item.get("mime_type") or guess_mime(path)
+        data_uri = f"data:{mime};base64,{await encode_file(path)}"
+        part_type = types[item_type]
+        # e.g. {"type": "image_url", "image_url": "data:image/jpeg;base64,..."}
+        # NOTE(review): the value is a bare data URI string, not {"url": ...} — confirm the target provider accepts this shape.
+        contexts.append({"type": part_type, part_type: data_uri})
+
+    return contexts
+
+
async def get_openai_response_contexts(client: Client, message: Message, params: dict) -> tuple[str, list[dict]]:
"""Generate OpenAI response contexts.
@@ -150,7 +203,13 @@ async def get_openai_response_contexts(client: Client, message: Message, params:
messages.reverse() # old to new
if params.get("add_sender") is None:
params["add_sender"] = is_multi_user_chat(messages)
- return previous_response_id, [ctx for msg in messages if (ctx := await single_openai_response_context(client, msg, params))]
+ contexts = [ctx for msg in messages if (ctx := await single_openai_response_context(client, msg, params))]
+ additional_contexts = await parse_openai_response_additional_contexts(client, params)
+ if contexts and contexts[-1]["role"] == "user":
+ contexts[-1]["content"].extend(additional_contexts)
+ else:
+ contexts.extend([{"role": "user", "type": "message", "content": additional_contexts}])
+ return previous_response_id, contexts
async def single_openai_response_context(client: Client, message: Message, params: dict) -> dict:
@@ -242,7 +301,63 @@ async def single_openai_response_context(client: Client, message: Message, param
return context if context["content"] else {}
-async def get_openai_file_id(client: Client, message: Message, params: dict) -> str:
+async def parse_openai_response_additional_contexts(client: Client, params: dict) -> list[dict]:
+    """Turn ``params["additional_contexts"]`` into Responses API input parts.
+
+    Expected item shapes (list[dict]):
+        {"type": "image", "path": "/path/to/image.jpg", "mime_type": "image/jpeg"}
+        {"type": "video", "path": "/path/to/video.mp4", "mime_type": "video/mp4"}
+        {"type": "audio", "path": "/path/to/audio.mp3", "mime_type": "audio/mpeg"}
+        {"type": "file", "path": "/path/to/file.pdf", "mime_type": "application/pdf"}
+
+    Items are skipped unless their ``allow_<type>`` flag is truthy and the path is an
+    existing file. ``params["openai_media_send_as"]`` selects the transport:
+    ``"base64"`` (default) inlines a data URI, ``"file_id"`` uploads the file through
+    ``get_openai_file_id`` and references the returned id.
+    """
+    extra_items = params.get("additional_contexts")
+    if not extra_items:
+        return []
+
+    send_mode = params.get("openai_media_send_as", "base64")
+    allowed = {kind: bool(params.get(f"allow_{kind}")) for kind in ("image", "video", "audio", "file")}
+
+    parts = []
+
+    for entry in extra_items:
+        kind = entry.get("type")
+        local_path = Path(entry.get("path", ""))
+        if not allowed.get(kind) or not local_path.is_file():
+            continue
+
+        if send_mode == "base64":
+            mime_type = entry.get("mime_type") or guess_mime(local_path)
+            uri = f"data:{mime_type};base64,{await encode_file(local_path)}"
+            if kind == "file":
+                part = {"type": f"input_{kind}", "filename": local_path.name, "file_data": uri}
+            else:
+                # image_url, video_url, audio_url
+                part = {"type": f"input_{kind}", f"{kind}_url": uri}
+            parts.append(part)
+            continue
+
+        if send_mode == "file_id":
+            # A synthetic message derived from the file digest stands in for a real Telegram message.
+            fake_id = int(digest(local_path, length=11, to_int=True))
+            stub_doc = Document(file_id="", file_unique_id="", file_size=local_path.stat().st_size)
+            stub_msg = Message(id=fake_id, chat=Chat(id=fake_id), document=stub_doc)
+            # Audio is a special case: ask the uploader to convert it to AAC.
+            upload_opts = {"force_audio_to_aac": True} if kind == "audio" else {}
+            uploaded_id = await get_openai_file_id(client, stub_msg, params, fpath=local_path, keep_file=True, **upload_opts)
+            if uploaded_id:
+                parts.append({"type": f"input_{kind}", "file_id": uploaded_id})
+
+    return parts
+
+
+async def get_openai_file_id(
+ client: Client,
+ message: Message,
+ params: dict,
+ fpath: str | Path | None = None,
+ *,
+ force_audio_to_aac: bool = False,
+ keep_file: bool = False,
+) -> str:
def get_real_baseurl() -> str:
base_url = str(params["base_url"]) or ""
default_headers = params.get("default_headers", {})
@@ -270,19 +385,21 @@ async def get_openai_file_id(client: Client, message: Message, params: dict) ->
api_key=api_key,
http_client=DefaultAsyncHttpxClient(proxy=params["proxy"]) if params.get("proxy") else None,
)
- fpath: str | Path = await client.download_media(message) # ty:ignore[invalid-assignment]
+ if fpath is None:
+ fpath: str = await client.download_media(message) # ty:ignore[invalid-assignment]
try:
# hotfix: convert audio to aac
- if message.audio and not str(fpath).endswith(".aac"):
- fpath = await downsampe_audio(fpath, ext="aac", codec="aac")
+ if force_audio_to_aac or (message.audio and not str(fpath).endswith(".aac")):
+ fpath: Path = await downsampe_audio(fpath, ext="aac", codec="aac")
resp = await openai.files.create(file=Path(fpath), purpose="user_data")
while resp.status in ["processing", "uploaded"]:
logger.trace(f"Upload media to OpenAI processing: {resp.model_dump()}")
- await asyncio.sleep(3)
+ await asyncio.sleep(1)
resp = await openai.files.retrieve(file_id=resp.id)
if resp.status in ["active", "processed"]:
- Path(fpath).unlink(missing_ok=True)
await set_cf_r2(r2_key, data=resp.model_dump(), metadata={"file_id": resp.id})
+ if not keep_file:
+ Path(fpath).unlink(missing_ok=True)
return resp.id
logger.error(f"Upload media to OpenAI failed: {resp.model_dump()}")
except Exception as e:
@@ -290,7 +407,14 @@ async def get_openai_file_id(client: Client, message: Message, params: dict) ->
return ""
-async def get_gemini_contexts(client: Client, message: Message, gemini: genai.Client, *, add_sender: bool | None = None) -> list[dict]:
+async def get_gemini_contexts(
+ client: Client,
+ message: Message,
+ gemini: genai.Client,
+ *,
+ add_sender: bool | None = None,
+ additional_contexts: list[dict] | None = None,
+) -> list[dict]:
"""Generate Gemini contexts from old to new.
Returns:
@@ -300,7 +424,7 @@ async def get_gemini_contexts(client: Client, message: Message, gemini: genai.Cl
if add_sender is None:
add_sender = is_multi_user_chat(chains)
messages = chains[-int(AI.MAX_CONTEXTS_NUM) :]
- contexts = []
+ contexts: list = []
for msg in messages:
info = parse_msg(msg, silent=True)
role = "model" if BOT_TIPS in info["text"] else "user"
@@ -337,17 +461,63 @@ async def get_gemini_contexts(client: Client, message: Message, gemini: genai.Cl
logger.warning(f"Download media from message failed: {e}")
if parts:
contexts.append({"role": role, "parts": parts})
+
+ additional_parts = await parse_gemini_additional_contexts(client, gemini, additional_contexts)
+ if contexts and contexts[-1]["role"] == "user":
+ contexts[-1]["parts"].extend(additional_parts)
+ else:
+ contexts.extend([{"role": "user", "parts": additional_parts}])
return [ctx for ctx in contexts if len(ctx.get("parts"))]
-async def get_gemini_file_id(client: Client, message: Message, gemini: genai.Client, fname: str, mtype: str) -> dict:
+async def parse_gemini_additional_contexts(client: Client, gemini: genai.Client, contexts: list[dict] | None = None) -> list:
+    """Convert additional contexts into Gemini ``Part`` objects.
+
+    Expected item shapes (list[dict]):
+        {"type": "image", "path": "/path/to/image.jpg", "mime_type": "image/jpeg"}
+        {"type": "video", "path": "/path/to/video.mp4", "mime_type": "video/mp4"}
+        {"type": "audio", "path": "/path/to/audio.mp3", "mime_type": "audio/mpeg"}
+        {"type": "file", "path": "/path/to/file.pdf", "mime_type": "application/pdf"}
+        {"type": "youtube", "url": "https://www.youtube.com/watch?v=videoid"}
+
+    YouTube items are referenced by URL. Every other item must point to an existing
+    local file, which is uploaded via ``get_gemini_file_id`` (the local file is kept).
+    Items that cannot be used are skipped.
+
+    Returns:
+        The list of parts (empty when ``contexts`` is empty or None).
+    """
+    if not contexts:
+        return []
+
+    parts = []
+
+    for item in contexts:
+        item_type = item.get("type", "")
+        if item_type == "youtube":
+            # A malformed item must not abort the whole context build.
+            if url := item.get("url"):
+                parts.append(Part(file_data=FileData(file_uri=url)))
+            else:
+                logger.warning(f"Skip youtube context without url: {item}")
+            continue
+
+        path = Path(item.get("path", ""))
+        if not path.is_file():
+            continue
+        # A synthetic message derived from the file digest stands in for a real Telegram message.
+        hash_id = int(digest(path, length=11, to_int=True))
+        message = Message(id=hash_id, chat=Chat(id=hash_id))
+        if uploaded := await get_gemini_file_id(client, message, gemini, path, item_type, keep_file=True):
+            parts.append(Part.from_uri(file_uri=uploaded["file_id"], mime_type=uploaded["mime_type"]))
+    return parts
+
+
+async def get_gemini_file_id(
+ client: Client,
+ message: Message,
+ gemini: genai.Client,
+ fpath: str | Path,
+ mtype: str,
+ *,
+ keep_file: bool = False,
+) -> dict:
"""Get Gemini file id from message.
Returns:
file_id: str
mime_type: str
"""
- if mtype not in ["video", "photo", "audio", "voice"] and mtype not in GEMINI_MIME and not any(fname.endswith(ext) for ext in GEMINI_EXT):
+ fname = Path(fpath).name
+ if mtype not in ["video", "photo", "image", "audio", "voice"] and not any(fname.endswith(ext) for ext in GEMINI_EXT):
return {}
cache_hour = AI.GEMINI_FILES_TTL // 3600
@@ -364,7 +534,10 @@ async def get_gemini_file_id(client: Client, message: Message, gemini: genai.Cli
except Exception as e:
logger.warning(f"Get file id from Gemini failed: {e}")
try:
- fpath: str = await client.download_media(message) # type: ignore
+ if isinstance(fpath, Path) and fpath.is_file():
+ fpath = fpath.as_posix()
+ else:
+ fpath = await client.download_media(message) # type: ignore
if mtype in ["audio", "voice"] and Path(fpath).suffix not in GEMINI_AUDIO_EXT:
audio_path = await downsampe_audio(fpath)
fpath = audio_path.as_posix()
@@ -375,7 +548,8 @@ async def get_gemini_file_id(client: Client, message: Message, gemini: genai.Cli
upload = await app.aio.files.get(name=upload.name) # type: ignore
if upload.state == FileState.ACTIVE and upload.uri:
await set_cf_r2(r2_key, data=json.loads(upload.model_dump_json()), metadata={"name": upload.name})
- Path(fpath).unlink(missing_ok=True)
+ if not keep_file:
+ Path(fpath).unlink(missing_ok=True)
return {"file_id": upload.uri, "mime_type": upload.mime_type}
except Exception as e:
logger.error(f"Upload media to Gemini failed: {e}")
@@ -489,10 +663,50 @@ async def full_chain_contexts(client: Client, message: Message, order: Literal["
return sorted(messages, key=lambda x: x.id, reverse=order == "desc")
-async def context_bytes(client: Client, message: Message) -> int:
- """Count bytes of all messages in the reply chain."""
- chains = await full_chain_contexts(client, message)
- return sum(message_bytes(m) for m in chains)
+async def context_types(client: Client, message: Message, additional_contexts: list[dict] | None) -> dict:
+    """Detect which kinds of content appear in a message's reply chain and extras.
+
+    Args:
+        client: The Pyrogram client used to resolve the reply chain.
+        message: The trigger message.
+        additional_contexts: Extra context items (dicts with a ``"type"`` key);
+            None or an empty list means there are none.
+
+    Returns:
+        A dict of booleans keyed by "text", "video", "audio", "image", "file", "youtube".
+    """
+    found = dict.fromkeys(("text", "video", "audio", "image", "file", "youtube"), False)
+
+    for msg in await full_chain_contexts(client, message):
+        if msg.audio:
+            found["audio"] = True
+        if msg.photo:
+            found["image"] = True
+        if msg.video:
+            found["video"] = True
+        if msg.document:
+            # Classify documents by mime type first, then by file extension.
+            mime = glom(msg, "document.mime_type", default="") or ""
+            fname = glom(msg, "document.file_name", default="") or ""
+            suffix = Path(fname).suffix
+            if mime.startswith("image/"):
+                found["image"] = True
+            elif mime.startswith("audio/") or suffix in AUDIO_FORMAT:
+                found["audio"] = True
+            elif mime.startswith("video/") or suffix in VIDEO_FORMAT:
+                found["video"] = True
+            elif mime.startswith("text/") or suffix in TXT_EXT or suffix in MARKDOWN_EXT:
+                found["text"] = True
+            else:
+                found["file"] = True
+        if msg.text or msg.caption:
+            found["text"] = True
+
+    # additional_contexts are parsed from function `parse_summary_sources` in `src/summarize/utils.py`
+    # Callers default this argument to None, and an item may lack "type"; neither should crash.
+    for ctx in additional_contexts or []:
+        ctx_type = ctx.get("type")
+        if ctx_type in ("image", "video", "audio", "file", "youtube"):
+            found[ctx_type] = True
+
+    return found
def message_bytes(message: Message) -> int:
@@ -510,3 +724,12 @@ def is_multi_user_chat(messages: list[Message]) -> bool:
uids.discard(TID.ME)
uids.discard(0)
return len(uids) > 1
+
+
+async def encode_file(path: str | Path) -> str:
+    """Encode a file to base64.
+
+    The file is read in fixed-size chunks; each chunk is base64-encoded on its own
+    and the encoded pieces are joined into a single string.
+    """
+    chunks = []
+    async with await anyio.Path(path).open("rb") as f:
+        # The chunk size MUST stay a multiple of 3: 65535 == 3 * 21845, so a full chunk
+        # encodes without "=" padding and the joined pieces form one valid base64 string.
+        # Do not "round" this to 65536 (a true 64 KiB) — padding would land mid-stream.
+        # NOTE(review): assumes every read except the last returns the full 65535 bytes — confirm for anyio's file wrapper.
+        while chunk := await f.read(65535): # just under 64 KiB, divisible by 3
+            chunks.append(base64.b64encode(chunk).decode("utf-8"))
+    return "".join(chunks)
src/ai/texts/gemini.py
@@ -36,6 +36,7 @@ async def gemini_chat_completion(
gemini_generate_content_config: str | dict = "",
gemini_proxy: str | None = PROXY.GOOGLE,
gemini_append_grounding: bool = True,
+ additional_contexts: list[dict] | None = None, # additional contexts to append to the contexts
skills: str = "",
hide_thinking: bool = False,
add_sender: bool | None = None,
@@ -61,7 +62,16 @@ async def gemini_chat_completion(
try:
http_options = types.HttpOptions(base_url=gemini_base_url, async_client_args={"proxy": gemini_proxy}, headers=literal_eval(gemini_default_headers))
gemini = genai.Client(api_key=api_key, http_options=http_options)
- params: dict = {"model": model_id, "contents": await get_gemini_contexts(client, message, gemini, add_sender=add_sender)}
+ params: dict = {
+ "model": model_id,
+ "contents": await get_gemini_contexts(
+ client,
+ message,
+ gemini,
+ add_sender=add_sender,
+ additional_contexts=additional_contexts,
+ ),
+ }
if skills:
gemini_generate_content_config = literal_eval(gemini_generate_content_config) | {"system_instruction": await load_skills(skills)}
if conf := literal_eval(gemini_generate_content_config):
src/ai/texts/models.py
@@ -3,8 +3,10 @@
import re
from loguru import logger
+from pyrogram.client import Client
from pyrogram.types import Message
+from ai.texts.contexts import context_types
from ai.utils import BOT_TIPS, EMOJI_TEXT_BOT, deep_merge, text_generation_docs
from config import AI, PREFIX
from database.kv import get_cf_kv
@@ -217,3 +219,54 @@ async def get_config_by_model_name(model_name: str) -> list[dict]:
model_config["model_id"] = model_id
model_configs.append(model_config.copy())
return model_configs
+
+
+async def reorder_model_configs(client: Client, message: Message, configs: list[dict], params: dict) -> list[dict]:
+    """Reorder model configs so models that can handle the input media come first.
+
+    Strategy, decided by the highest-priority media kind present
+    (youtube > audio > video > image > file):
+        - Gemini configs are always preferred when any media is present.
+        - With YouTube input, only Gemini is preferred.
+        - Otherwise an ``openai*`` config is preferred when its
+          ``openai_allow_<kind>`` flag is truthy for that kind.
+    Text-only input leaves the order untouched. Relative order inside the
+    preferred and remaining groups is preserved.
+
+    Args:
+        client: The Pyrogram client.
+        message: The trigger message.
+        configs: Candidate model configs, in their original priority order.
+        params: Call parameters; ``additional_contexts`` (list or None) is read.
+
+    Returns:
+        model_configs
+    """
+    # `additional_contexts` may be present with an explicit None (the API functions default it
+    # to None), so `.get(key, [])` is not enough to guarantee a list.
+    types = await context_types(client, message, params.get("additional_contexts") or [])
+    if not any(types.get(kind) for kind in ("youtube", "audio", "video", "image", "file")):
+        return configs  # text only
+
+    def is_preferred(config: dict) -> bool:
+        api_type = config.get("api_type") or ""
+
+        if api_type == "gemini":
+            return True
+        if types.get("youtube"):
+            return False  # only Gemini can handle YouTube
+        if not api_type.startswith("openai"):
+            return False
+
+        # audio > video > image > file: the first kind present decides.
+        # NOTE(review): a missing flag counts as False here, while openai_chat_completions
+        # defaults openai_allow_image to True — confirm which default is intended.
+        for kind in ("audio", "video", "image", "file"):
+            if types.get(kind):
+                return bool(config.get(f"openai_allow_{kind}", False))
+
+        return False
+
+    preferred_configs = []
+    remaining_configs = []
+
+    for config in configs:
+        if is_preferred(config):
+            preferred_configs.append(config)
+        else:
+            remaining_configs.append(config)
+
+    return preferred_configs + remaining_configs
src/ai/texts/openai_chat.py
@@ -1,8 +1,12 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import contextlib
+import json
+from json import JSONDecodeError
+from typing import Literal
from glom import glom
+from jsonschema import ValidationError, validate
from loguru import logger
from openai import AsyncOpenAI, DefaultAsyncHttpxClient
from pyrogram.client import Client
@@ -34,6 +38,12 @@ async def openai_chat_completions(
openai_contexts: list[dict] | None = None,
openai_tools: list[dict] | None = None,
skills: str = "",
+ openai_allow_image: bool = True, # whether to allow image in input modalities
+ openai_allow_video: bool = False, # whether to allow video in input modalities
+ openai_allow_audio: bool = False, # whether to allow audio in input modalities
+ openai_allow_file: bool = False, # whether to allow file in input modalities
+ openai_media_send_as: Literal["base64", "file_id"] = "base64",
+ additional_contexts: list[dict] | None = None, # additional contexts to append to the contexts
hide_thinking: bool = False,
add_sender: bool | None = None,
silent: bool = False,
@@ -62,7 +72,20 @@ async def openai_chat_completions(
openai_client |= {"default_headers": literal_eval(openai_default_headers)}
if openai_proxy:
openai_client |= {"http_client": DefaultAsyncHttpxClient(proxy=openai_proxy)}
- contexts = openai_contexts or await get_openai_completion_contexts(client, message, add_sender=add_sender)
+
+ contexts = openai_contexts or await get_openai_completion_contexts(
+ client,
+ message,
+ params={
+ "add_sender": add_sender,
+ "allow_image": openai_allow_image,
+ "allow_video": openai_allow_video,
+ "allow_audio": openai_allow_audio,
+ "allow_file": openai_allow_file,
+ "openai_media_send_as": openai_media_send_as,
+ "additional_contexts": additional_contexts,
+ },
+ )
if openai_system_prompt and glom(contexts, "0.role", default="") != "system":
contexts.insert(0, {"role": "system", "content": openai_system_prompt})
if skills:
@@ -94,6 +117,8 @@ async def openai_chat_completions(
max_retries=max_retries,
**kwargs,
)
+ if not is_valid_response(resp, glom(params, "response_format.json_schema.schema", default={})):
+ continue
if resp.get("texts") or resp.get("tool_name"):
resp |= {
"success": True,
@@ -250,3 +275,18 @@ def inject_skills(contexts: list[dict], skills: str) -> list[dict]:
system_prompt.append({"type": "text", "text": skills})
contexts[0] = {"role": "system", "content": system_prompt}
return contexts
+
+
+def is_valid_response(resp: dict, schema: dict) -> bool:
+    """Check whether a model response is usable.
+
+    Args:
+        resp: Parsed response; ``texts`` and ``tool_name`` are read.
+        schema: JSON Schema the text must satisfy; empty means no constraint.
+
+    Returns:
+        True when the response has text (matching ``schema`` if one is given), or
+        when it has no text but is a tool call. False otherwise.
+    """
+    texts = resp.get("texts")
+    if not texts:
+        # A tool-call response has no text to validate; reporting it invalid would make
+        # the caller skip its `resp.get("tool_name")` handling.
+        return bool(resp.get("tool_name"))
+    if not schema:
+        return True
+    try:
+        data = json.loads(texts)
+        validate(instance=data, schema=schema)
+    except (JSONDecodeError, ValidationError) as e:
+        logger.error(f"Invalid JSONSchema response: {e}")
+        return False
+    return True
src/ai/texts/openai_response.py
@@ -42,6 +42,7 @@ async def openai_responses_api(
openai_allow_audio: bool = False, # whether to allow audio in input modalities
openai_allow_file: bool = False, # whether to allow file in input modalities
openai_media_send_as: Literal["base64", "file_id"] = "base64",
+ additional_contexts: list[dict] | None = None, # additional contexts to append to the contexts
skills: str = "",
openai_append_tool_results: bool = True,
hide_thinking: bool = False,
@@ -96,6 +97,7 @@ async def openai_responses_api(
"allow_file": openai_allow_file,
"openai_media_send_as": openai_media_send_as,
"add_sender": add_sender,
+ "additional_contexts": additional_contexts,
},
)
params = {}
@@ -182,6 +184,7 @@ async def single_api_response(
is_reasoning = False
async for chunk in await openai.responses.create(**params):
resp = trim_none(chunk.model_dump())
+ resp.pop("item_id", None) # too noisy
logger.trace(resp)
error = await parse_error(resp, retry, max_retries, status_msg)
if error["retry"]:
src/ai/texts/tool_call.py
@@ -73,7 +73,7 @@ async def get_tool_call_results(client: Client, message: Message, **kwargs) -> d
Returns:
dict: {"texts": str, "thoughts": str, "prefix": str, "sent_messages": list[Message]}
"""
- contexts = await get_openai_completion_contexts(client, message)
+ contexts = await get_openai_completion_contexts(client, message, kwargs)
if not contexts:
return {}
src/ai/chat_summary.py
@@ -108,7 +108,7 @@ async def ai_chat_summary(
client: Client,
message: Message,
summary_prefix: str | None = None,
- summary_model_id: str = AI.CHAT_SUMMARY_MODEL_ALIAS,
+ summary_chat_model: str = AI.CHAT_SUMMARY_MODEL_ALIAS,
**kwargs,
):
"""GPT summary of the message history.
@@ -117,7 +117,7 @@ async def ai_chat_summary(
client (Client): The Pyrogram client.
message (Message): The trigger message object.
summary_prefix (str | None): Prefix string of the response message.
- summary_model_id (str, optional): The model id to use for AI summary.
+ summary_chat_model (str, optional): The model id to use for AI summary.
"""
# send docs if message == "/summary"
if equal_prefix(message.text, prefix=[PREFIX.CHAT_SUMMARY, PREFIX.COMBINATION]) and not message.reply_to_message:
@@ -191,7 +191,7 @@ async def ai_chat_summary(
ai_msg = Message( # Construct a message for AI
id=rand_number(),
chat=message.chat,
- text=Str(f"{PREFIX.AI_TEXT_GENERATION} @{summary_model_id} {SYSTEM_PROMPT} {parsed['history']}"),
+ text=Str(f"{PREFIX.AI_TEXT_GENERATION} @{summary_chat_model} {SYSTEM_PROMPT} {parsed['history']}"),
)
ai_res = await ai_text_generation(client, ai_msg, silent=True)
if texts := ai_res.get("texts"):
src/ai/main.py
@@ -11,7 +11,7 @@ from ai.images.openai_img import openai_image_generation
from ai.images.post import http_post_image_generation
from ai.texts.claude import anthropic_responses
from ai.texts.gemini import gemini_chat_completion
-from ai.texts.models import get_config_by_model_alias, get_text_model_configs
+from ai.texts.models import get_config_by_model_alias, get_text_model_configs, reorder_model_configs
from ai.texts.openai_chat import openai_chat_completions
from ai.texts.openai_response import openai_responses_api
from ai.texts.tool_call import get_tool_call_results
@@ -37,6 +37,8 @@ async def ai_text_generation(client: Client, message: Message, **kwargs) -> dict
if not model_configs:
return {}
+ model_configs = await reorder_model_configs(client, message, model_configs, kwargs)
+
def handle_response(resp: dict, current_kwargs: dict) -> dict | None:
"""Handle API response.
src/ai/summary.py
@@ -1,213 +0,0 @@
-#!/venv/bin/python
-# -*- coding: utf-8 -*-
-import base64
-import json
-import re
-import zlib
-from datetime import datetime
-from pathlib import Path
-
-from loguru import logger
-from pyrogram.client import Client
-from pyrogram.types import Chat, Message
-from pyrogram.types.messages_and_media.message import Str
-
-from ai.main import ai_text_generation
-from ai.texts.contexts import context_bytes
-from ai.texts.gemini import gemini_chat_completion
-from ai.texts.models import get_config_by_model_alias
-from ai.texts.openai_response import openai_responses_api
-from ai.utils import deep_merge
-from config import AI, DB, DOWNLOAD_DIR, PREFIX
-from database.r2 import set_cf_r2
-from messages.help import social_media_help
-from messages.sender import send2tg
-from messages.utils import equal_prefix, set_reaction, startswith_prefix
-from networking import download_file, shorten_url
-from publish import telegraph_aipage
-from schema import AIPage, ContentExtraction, get_schema
-from utils import count_subtitles, digest, rand_number, to_dt
-
-
-async def ai_summary(client: Client, message: Message, summary_model_id: str = AI.AI_SUMMARY_MODEL_ALIAS, **kwargs):
- if not startswith_prefix(message.content, prefix=PREFIX.AI_SUMMARY):
- return
- this_msg = message
- if equal_prefix(message.content, PREFIX.AI_SUMMARY):
- if not message.reply_to_message:
- await send2tg(client, message, texts=social_media_help(message), **kwargs)
- return
- message = message.reply_to_message
- models = await get_config_by_model_alias(summary_model_id, fallback_to_default=False)
- models = [x for x in models if x.get("api_type") in ["gemini", "openai_responses"]] # only support gemini & openai_responses models
- if not models:
- return
- mbytes = await context_bytes(client, message)
- if mbytes > 40 * 1024 * 1024: # prefer gemini for large files (40MB)
- models = sorted(models, key=lambda x: x.get("api_type") == "gemini", reverse=True)
- await set_reaction(client, this_msg, "👌")
- for model_config in models:
- res = {}
- params = deep_merge(model_config, kwargs, summary_params())
- if params["api_type"] == "gemini":
- res = await gemini_chat_completion(client, message, **params)
- elif params["api_type"] == "openai_responses":
- res = await openai_responses_api(client, message, **params)
- if not res.get("texts"):
- continue
- texts, _, _ = await parse_summary(res["texts"])
- await send2tg(client, message, texts="**🤖AI导读**\n" + texts, **kwargs)
- await set_reaction(client, this_msg, "")
- return
-
-
-async def summarize(
- article: str | None = None,
- transcripts: str | None = None,
- reference: str | None = None,
- model: str = "gemini",
- title: str | None = None,
- author: str | None = None,
- url: str | None = None,
- date: str | datetime | None = None,
- description: str | None = None,
- ttl: str | None = None,
-) -> dict:
- title = title or "AI导读"
- if article is None and transcripts is None:
- raise ValueError("必须传入 article 或 transcripts 其中一个参数")
- if article is not None and transcripts is not None:
- raise ValueError("不能同时传入 article 和 transcripts 参数")
- source = article or transcripts or ""
- if count_subtitles(source) < 200: # skip short article
- return {}
- res = await ai_text_generation(
- "fake-client", # type: ignore
- message=Message(
- id=rand_number(),
- chat=Chat(id=rand_number()),
- text=Str(f"{PREFIX.AI_TEXT_GENERATION} @{model} {source.strip()}"),
- ),
- **summary_params(reference),
- )
- if not res.get("texts", ""):
- return {}
- texts, mermaid_img_url, mermaid_pako_url = await parse_summary(res["texts"])
- summary = ContentExtraction.model_validate_json(res["texts"])
- page = AIPage(
- title=title,
- author=author,
- url=url,
- date=to_dt(date),
- description=description,
- summary=summary,
- transcripts=transcripts,
- mermaid_img=mermaid_img_url,
- mermaid_url=mermaid_pako_url,
- )
- if telegraph_url := await telegraph_aipage(page, ttl=ttl):
- res["telegraph_url"] = telegraph_url
- res["texts"] = f"**🤖[AI导读]({telegraph_url})**\n" + texts
- else:
- res["telegraph_url"] = None
- res["texts"] = "**🤖AI导读**\n" + texts
- return res
-
-
-async def parse_summary(texts: str) -> tuple[str, str, str]:
- """Parse the summary JSON string.
-
- Returns:
- (summary_texts, mermaid_img_url, mermaid_pako_url)
- """
- try:
- summary = ContentExtraction.model_validate_json(texts)
- mermaid = beautify_mermaid(summary.mermaid)
- img_url, pako_url = await publish_mermaid(mermaid)
- parsed = f"{summary.overview}\n⚡️**章节速览**"
- for section in summary.sections:
- parsed += f"\n{section.emoji}**{section.title}**"
- if section.start:
- start = section.start.removeprefix("00:") if len(section.start) > 5 else section.start
- parsed += f" [{start}]"
- parsed += f"\n{section.content}"
- logger.success(parsed)
- except Exception as e:
- logger.error(f"Error parsing summary: {e}")
- return texts, "", ""
- return parsed, img_url, pako_url
-
-
-def system_prompt(reference: str | None = None) -> str:
- prompt = "你是一位专业的内容提炼大师,任务是基于用户提供的资料,生成用户无需阅读完整原文档就能清晰理解主要事件、观点、结论的内容,生成符合指定JSON格式的全文总结、分片内容和思维导图。"
- if reference:
- prompt += f"\n{reference}"
- return prompt.strip()
-
-
-def beautify_mermaid(mermaid: str) -> str:
- def replace(s: str) -> str:
- s = s.replace("\n", "<br/>")
- s = s.replace('"', """)
- s = s.replace("'", "'")
- s = s.replace("[", "[")
- s = s.replace("]", "]")
- s = s.replace("(", "#40;")
- s = s.replace(")", "#41;")
- return s.replace("@", "#64;")
-
- def callback(match: re.Match[str]):
- original = match.group(1)
- new_text = replace(original)
- return f"[{new_text}]"
-
- mermaid = re.sub(r"\[(.*?)\]", callback, mermaid.strip())
- return f"---\nconfig:\n theme: neo\n look: neo\n---\n{mermaid.strip()}"
-
-
-async def publish_mermaid(mermaid: str) -> tuple[str, str]:
- """Save Mermaid image to R2.
-
- Returns:
- (image_url, pako_url)
- """
- b64_str = base64.urlsafe_b64encode(mermaid.encode("utf-8")).decode("ascii")
- save_path = Path(DOWNLOAD_DIR) / f"{digest(mermaid)}.jpg" # noqa: S324
- r2_key = f"TTL/365d/{save_path.name}"
- img_url = f"{DB.CF_R2_PUBLIC_URL}/{r2_key}"
- if await download_file(f"https://mermaid.ink/img/{b64_str}?type=jpeg&theme=forest&width=2160", path=save_path, suffix=".jpg"):
- img_url = await shorten_url(img_url, alias=digest(mermaid, 16))
- mermaid = mermaid.replace("\ngraph LR", f"\n%% {img_url}\ngraph LR")
- # generate pako url for mermaid image
- json_str = json.dumps({"code": mermaid.strip()}, separators=(",", ":"))
- compressed_bytes = zlib.compress(json_str.encode("utf-8"), level=9)
- pako_b64_str = base64.urlsafe_b64encode(compressed_bytes).decode("utf-8").rstrip("=")
- pako_url = await shorten_url(f"https://mermaid.live/view#pako:{pako_b64_str}", alias=digest(pako_b64_str, 16))
-
- if save_path.is_file():
- await set_cf_r2(r2_key, data=save_path.read_bytes(), mime_type="image/jpeg", silent=True)
- save_path.unlink(missing_ok=True)
- return img_url, pako_url
- return "", ""
-
-
-def summary_params(reference: str | None = None) -> dict:
- return {
- "gemini_generate_content_config": {"system_instruction": system_prompt(reference), "responseMimeType": "application/json", "responseJsonSchema": get_schema("content_extraction")},
- "openai_responses_config": {
- "instructions": system_prompt(reference),
- "text": {
- "format": {
- "type": "json_schema",
- "name": "ContentExtraction",
- "strict": True,
- "description": "精准提炼资料的核心主题、关键观点、主要结论及各片段核心内容,确保输出内容全面覆盖资料的关键信息,用户仅通过总结即可掌握信息全貌。",
- "schema": get_schema("content_extraction"),
- }
- },
- },
- "gemini_append_grounding": False,
- "openai_enable_tool_call": False,
- "openai_append_tool_results": False,
- "silent": True,
- }
src/danmu/entrypoint.py
@@ -127,7 +127,7 @@ async def query_danmu(client: Client, message: Message, **kwargs):
html = f'<!DOCTYPE html><html><head><meta charset="UTF-8"><meta name="viewport" content="width=device-width, initial-scale=1.0"><title>{qtype}查询结果</title><link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/water.css@2/out/water.css"></head><body><article>{html}</article></body></html>'
try:
with BytesIO(html.encode("utf-8")) as f:
- await status_msg.edit_media(file_name=f"{qtype}查询结果.html", media=InputMediaDocument(f, file_name=f"{qtype}查询结果.html", caption=caption))
+ await status_msg.edit_media(media=InputMediaDocument(f, file_name=f"{qtype}查询结果.html", caption=caption))
except MediaCaptionTooLong:
save_path = Path(DOWNLOAD_DIR).joinpath(f"{qtype}查询结果.html")
save_path.write_text(html)
src/messages/database.py
@@ -15,7 +15,7 @@ from messages.utils import sender_markdown_to_html
from utils import to_int, true
-async def save_messages(messages: list[Message | None], key: str, metadata: dict | None = None) -> bool:
+async def save_messages(messages: list[Message], key: str, metadata: dict | None = None) -> bool:
"""Save the messages to DB.
data format:
src/messages/main.py
@@ -8,7 +8,6 @@ from pyrogram.types import Message
from ai.chat_summary import ai_chat_summary
from ai.main import ai_image_generation, ai_text_generation, ai_video_generation
-from ai.summary import ai_summary
from asr.voice_recognition import voice_to_text
from bridge.ocr import send_to_ocr_bridge
from config import FAVORITE, PREFIX, PROXY
@@ -49,6 +48,7 @@ from preview.xiaohongshu import preview_xhs
from price.entrypoint import get_asset_price
from quotly.quotly import quote_message
from subtitles.subtitle import get_subtitle
+from summarize.main import ai_summary
from tts.tts import text_to_speech
from utils import to_int, true
from ytdlp.main import preview_ytdlp
@@ -167,6 +167,7 @@ async def preview_social_media(
*,
need_prefix: bool = True, # Need prefix to parse social media. (/dl)
prepend_sender_user: bool = False,
+ social: bool = True, # Parse social media link
douyin: bool = True, # Parse Douyin
tiktok: bool = True, # Parse TikTok
instagram: bool = True, # Parse Instagram
@@ -199,6 +200,8 @@ async def preview_social_media(
show_progress (bool, optional): Show a progress message on Telegram. Defaults to True.
detail_progress (bool, optional): Show detailed progress (Only if show_proress is set to True). Defaults to False.
"""
+ if not social:
+ return None
# these commands are handled in `process_message`
ignore_prefix = [
PREFIX.AI_IMG_GENERATION,
src/messages/sender.py
@@ -14,7 +14,7 @@ from config import CAPTION_LENGTH, TID
from messages.parser import get_thread_id
from messages.preprocess import preprocess_media, warp_media_group
from messages.progress import modify_progress, telegram_uploading
-from messages.utils import better_blockquote, delete_message, get_reply_to, remove_img_tag, smart_split, summay_media
+from messages.utils import better_blockquote, blockquote, delete_message, get_reply_to, remove_img_tag, smart_split, summay_media
from utils import to_int
@@ -31,8 +31,9 @@ async def send2tg(
send_from_user: str | None = None,
cooldown: float = 0,
caption_above: bool = False,
+ keep_file: bool = False,
**kwargs,
-) -> list[Message | None]:
+) -> list[Message]:
"""Send unlimited number of texts and media to Telegram.
Telegram Message Limitation:
@@ -54,6 +55,7 @@ async def send2tg(
send_from_user (str, optional): The user name to prefix the texts.
cooldown (float, optional): The interval between each media message. Defaults to 0.
caption_above (bool, optional): Show caption above the message media.
+ keep_file (bool, optional): Keep the media files after sending. Defaults to False.
kwargs: Other keyword arguments. In this function, we use:
show_progress (bool, optional): Show a progress message on Telegram. Defaults to True.
detail_progress (bool, optional): Show detailed progress (Only if show_proress is set to True). Defaults to False.
@@ -84,12 +86,23 @@ async def send2tg(
texts, _ = remove_img_tag(texts)
if kwargs.get("progress") and len(media) > 0:
await modify_progress(text=f"⏫正在上传:\n{summay_media(media)}", force_update=True, **kwargs)
- sent_messages: list[Message | None] = [] # return sent messages
+ sent_messages: list[Message] = [] # return sent messages
logger.trace(f"Sending {len(media)} media with {len(texts)} texts")
if len(media) == 0:
return await send_texts(client, target_chat, reply_parameters, tid, texts=texts, cooldown=cooldown)
if len(media) == 1:
- return await send_single_media(client, target_chat, reply_parameters, tid, media=media[0], texts=texts, cooldown=cooldown, caption_above=caption_above, **kwargs)
+ return await send_single_media(
+ client,
+ target_chat,
+ reply_parameters,
+ tid,
+ media=media[0],
+ texts=texts,
+ cooldown=cooldown,
+ caption_above=caption_above,
+ keep_file=keep_file,
+ **kwargs,
+ )
caption = (await smart_split(texts, CAPTION_LENGTH))[0]
remaining_texts = texts.removeprefix(caption)
@@ -109,14 +122,18 @@ async def send2tg(
group = await warp_media_group(batch)
sent_messages.extend(await send_media_group(client, target_chat, group, None, tid))
else: # last chunk: media <= 10, add caption here
- sent_messages.extend(await send2tg(client, message, target_chat, reply_msg_id=-1, thread_id=tid, texts=caption, media=batch, caption_above=caption_above, cooldown=cooldown, **kwargs))
+ sent_messages.extend(
+ await send2tg(
+ client, message, target_chat, reply_msg_id=-1, thread_id=tid, texts=caption, media=batch, caption_above=caption_above, cooldown=cooldown, keep_file=keep_file, **kwargs
+ )
+ )
await asyncio.sleep(cooldown)
if remaining_texts:
sent_messages.extend(await send_texts(client, target_chat, None, tid, texts=remaining_texts, cooldown=cooldown))
# clean up
for x in media:
for key in ["path", "media", "thumb", "audio", "photo", "video"]:
- if x.get(key) and Path(x[key]).is_file():
+ if x.get(key) and Path(x[key]).is_file() and not keep_file:
logger.trace(f"Deleting: {x[key]}")
Path(x[key]).unlink(missing_ok=True)
return sent_messages
@@ -130,8 +147,8 @@ async def send_texts(
*,
texts: str = "",
cooldown: float = 0,
-) -> list[Message | None]:
- sent_messages: list[Message | None] = []
+) -> list[Message]:
+ sent_messages: list[Message] = []
logger.trace(f"Sending {len(texts)} texts only")
for idx, msg in enumerate(await smart_split(texts.strip())):
if not msg:
@@ -142,14 +159,16 @@ async def send_texts(
if idx != 0:
reply_parameters = None
try:
- sent_messages.append(await client.send_message(target_chat, better_blockquote(msg), message_thread_id=thread_id, reply_parameters=reply_parameters))
+ message = await client.send_message(target_chat, better_blockquote(msg), message_thread_id=thread_id, reply_parameters=reply_parameters)
except FloodWait as e:
logger.warning(e)
await asyncio.sleep(e.value)
- sent_messages.append(await client.send_message(target_chat, better_blockquote(msg), message_thread_id=thread_id, reply_parameters=reply_parameters))
+ message = await client.send_message(target_chat, better_blockquote(msg), message_thread_id=thread_id, reply_parameters=reply_parameters)
except Exception as e:
logger.warning(f"send_texts: {e}")
await asyncio.sleep(cooldown)
+ if isinstance(message, Message):
+ sent_messages.append(message)
return sent_messages
@@ -163,9 +182,10 @@ async def send_single_media(
texts: str = "",
cooldown: float = 0,
caption_above: bool = False,
+ keep_file: bool = False,
**kwargs,
-) -> list[Message | None]:
- sent_messages: list[Message | None] = []
+) -> list[Message]:
+ sent_messages: list[Message] = []
logger.trace(f"Sending single media with {len(texts)} texts")
caption = (await smart_split(texts, CAPTION_LENGTH))[0]
remaining_texts = texts.removeprefix(caption)
@@ -210,15 +230,26 @@ async def send_single_media(
except FloodWait as e:
logger.warning(e)
await asyncio.sleep(e.value)
- return await send_single_media(client, target_chat, reply_parameters, thread_id, media=media, texts=texts, cooldown=cooldown, **kwargs)
+ return await send_single_media(
+ client,
+ target_chat,
+ reply_parameters,
+ thread_id,
+ media=media,
+ texts=texts,
+ cooldown=cooldown,
+ keep_file=keep_file,
+ **kwargs,
+ )
except Exception as e:
logger.warning(f"send_single_media: {e}")
- sent_messages.append(message)
+ if isinstance(message, Message):
+ sent_messages.append(message)
if remaining_texts:
sent_messages.extend(await send_texts(client, target_chat, None, thread_id, texts=remaining_texts, cooldown=cooldown))
for key in ["path", "thumb", "audio", "photo", "video"]:
- if media.get(key) and Path(media[key]).is_file():
+ if media.get(key) and Path(media[key]).is_file() and not keep_file:
logger.trace(f"Deleting: {media[key]}")
Path(media[key]).unlink(missing_ok=True)
return sent_messages
@@ -243,7 +274,8 @@ async def send_media_group(
if retry > 2:
return []
try:
- return await client.send_media_group(chat_id, media=media_group, reply_parameters=reply_parameters, message_thread_id=thread_id)
+ messages = await client.send_media_group(chat_id, media=media_group, reply_parameters=reply_parameters, message_thread_id=thread_id)
+ return [m for m in messages if isinstance(m, Message)]
except FloodWait as e:
logger.warning(e)
await asyncio.sleep(e.value)
@@ -255,7 +287,8 @@ async def send_media_group(
if retry > 2:
return []
try:
- return await client.copy_media_group(to_int(target_chat), from_chat_id=to_int(TID.TEMP), message_id=message_id, reply_parameters=reply_parameters, message_thread_id=thread_id)
+ messages = await client.copy_media_group(to_int(target_chat), from_chat_id=to_int(TID.TEMP), message_id=message_id, reply_parameters=reply_parameters, message_thread_id=thread_id)
+ return [m for m in messages if isinstance(m, Message)]
except FloodWait as e:
logger.warning(e)
await asyncio.sleep(e.value)
@@ -275,3 +308,47 @@ async def send_media_group(
# send directly
return await send(to_int(target_chat), media_group, reply_parameters)
+
+
+async def send_blockquote_texts(
+    client: Client,
+    message: Message,
+    target_chat: int | str = "",
+    reply_msg_id: int = 0,
+    thread_id: int = 0,
+    *,
+    texts: str = "",
+    cooldown: float = 1,
+    **kwargs,
+) -> list[Message]:
+    """Send long texts as a chain of blockquote messages, each one replying to the previous.
+
+    Args:
+        client: Client used to send the first chunk.
+        message: Message being answered; supplies the default chat, reply target and thread.
+        target_chat: Chat to send to. Falls back to ``message.chat.id`` when empty.
+        reply_msg_id: Forwarded to ``get_reply_to`` together with ``message.id``.
+        thread_id: Message thread id. Falls back to ``get_thread_id(message)`` when 0.
+        texts: Texts to send; image tags are removed and the rest is split with ``smart_split``.
+        cooldown: Seconds to sleep after every chunk.
+        kwargs: Accepted so callers can forward their own keyword arguments; not read here.
+
+    Returns:
+        The messages that were actually sent, in order. Chunks that failed to send are skipped.
+    """
+    # ``target_chat`` is a named parameter, so it can never show up inside ``kwargs``.
+    target_chat = to_int(target_chat or message.chat.id)
+    reply_parameters = get_reply_to(message.id, reply_msg_id)
+    tid = thread_id or get_thread_id(message)
+    if not texts.strip():
+        return []
+    texts, _ = remove_img_tag(texts)
+    sent_messages: list[Message] = []
+    cur_msg = None
+    logger.trace(f"Sending {len(texts)} blockquote texts:")
+
+    async def send_chunk(text: str):
+        # The first chunk goes to the chat; every later chunk replies to the previously sent one.
+        if isinstance(cur_msg, Message):
+            return await cur_msg.reply_text(text, quote=True)
+        return await client.send_message(target_chat, text, message_thread_id=tid, reply_parameters=reply_parameters)
+
+    for msg in await smart_split(texts.strip()):
+        if not msg:
+            continue
+        # Track this chunk's own result so a failed chunk never re-appends the previous message.
+        sent = None
+        try:
+            text = better_blockquote(blockquote(msg))
+            try:
+                sent = await send_chunk(text)
+            except FloodWait as e:
+                logger.warning(e)
+                await asyncio.sleep(e.value)
+                # Retry once with the very same blockquote text and the resolved thread id.
+                sent = await send_chunk(text)
+        except Exception as e:
+            logger.warning(f"send_blockquote_texts: {e}")
+        await asyncio.sleep(cooldown)
+        if isinstance(sent, Message):
+            cur_msg = sent
+            sent_messages.append(sent)
+    return sent_messages
src/messages/utils.py
@@ -227,7 +227,7 @@ async def set_reaction(client: Client, message: Message, reaction: str | list[st
await client.set_reaction(message.chat.id, message.id)
-async def delete_message(message: Message | list[Message | None] | None):
+async def delete_message(message: Message | list | None):
if not message:
return
if not isinstance(message, list):
src/podcast/main.py
@@ -30,7 +30,6 @@ from loguru import logger
from pyrogram.client import Client
from pyrogram.types import Chat, Message
-from ai.summary import summarize
from config import AI, PODCAST, PROXY
from database.github import gh_clean_assets
from database.r2 import get_cf_r2, set_cf_r2
@@ -41,7 +40,8 @@ from podcast.utils import HEADERS, clean_feed_url, feed_saved_target, get_pubdat
from podcast.xml import get_feed_title, parse_feed, save_xml, update_xml_desc
from preview.bilibili import get_bilibili_vinfo
from preview.youtube import get_youtube_vinfo
-from utils import bare_url, count_subtitles, https_url, nowdt, rand_number, seconds_to_hms, strings_list
+from summarize.summarize import summarize
+from utils import bare_url, convert2html, count_subtitles, https_url, nowdt, rand_number, seconds_to_hms, strings_list
from ytdlp.download import ytdlp_download
@@ -75,15 +75,16 @@ async def summary_pods(client: Client):
pubdate = f"{dt:%Y-%m-%d %H:%M:%S}"
caption = f"🎧[{feed_title}]({homepage})\n📝[{entry['title']}]({entry['link']})\n🕒{pubdate}\n⏳{duration} #️⃣字数: {count_subtitles(transcripts)}"
desc = glom(entry, Coalesce("content.0.value", "summary"), default="")
+ desc_html = desc if desc.startswith("<") else convert2html(desc)
+ prompt = f"该转录稿对应于播客栏目《{feed_title}》的一期节目,节目详情:\n标题: {entry['title']}\n日期: {pubdate}\n时长: {duration}\n节目简介: {desc}"
summary = await summarize(
- transcripts=transcripts,
- reference=f"该转录稿对应于播客栏目《{feed_title}》的一期节目,节目详情:\n标题: {entry['title']}\n日期: {pubdate}\n时长: {duration}\n节目简介: {desc}",
+ sources=[{"type": "system_prompt", "text": prompt}, {"type": "transcripts", "text": transcripts}],
model=AI.PODCAST_SUMMARY_MODEL_ALIAS,
title=entry["title"],
author=feed_title,
url=entry["link"],
date=dt,
- description=desc,
+ description={"emoji": "🎧", "name": "播客详情", "html": desc_html},
ttl="forever",
)
if telegraph_url := summary.get("telegraph_url"):
src/preview/arxiv.py
@@ -1,18 +1,25 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
-import contextlib
+import json
+import re
+import shutil
+from email.utils import parsedate_to_datetime
from pathlib import Path
import feedparser
+from bs4 import BeautifulSoup
from glom import Coalesce, glom
+from loguru import logger
from pyrogram.client import Client
from pyrogram.types import InputMediaDocument, Message
-from config import CAPTION_LENGTH, PROXY, TEXT_LENGTH
+from config import AI, CAPTION_LENGTH, DOWNLOAD_DIR, PROXY
from messages.progress import modify_progress
-from messages.sender import send2tg
-from messages.utils import blockquote, smart_split
+from messages.sender import send2tg, send_blockquote_texts
+from messages.utils import smart_split
from networking import download_file, hx_req
+from summarize.summarize import summarize
+from utils import nowdt
HEADERS = {
"User-Agent": "feedparser/6.0.11 +https://github.com/kurtmckee/feedparser/",
@@ -20,53 +27,155 @@ HEADERS = {
}
-async def preview_arxiv(client: Client, message: Message, url: str, arxiv_id: str, **kwargs):
+async def preview_arxiv(
+    client: Client,
+    message: Message,
+    url: str,
+    arxiv_id: str,
+    *,
+    summary_arxiv: bool = True,
+    summary_arxiv_model: str = AI.AI_SUMMARY_MODEL_ALIAS,
+    **kwargs,
+) -> None:
     """Preview arxiv in the message."""
- status_msg = None
- if kwargs.get("show_progress") and "progress" not in kwargs:
- res = await send2tg(client, message, texts=f"🔗正在解析arXiv链接\n{url}", **kwargs)
- kwargs["progress"] = res[0]
- status_msg = res[0]
+    # Flow: send a status message, swap it for the PDF, then reply with either the abstract
+    # (summary_arxiv=False) or an AI summary built from metadata + TeX sources + the PDF.
+    sent = await send2tg(client, message, texts=f"🔗正在解析arXiv链接\n{url}", **kwargs)
+    # send2tg yields an empty list when the status text could not be sent; indexing it blindly raises IndexError.
+    status_msg = sent[0] if sent else None
     kwargs["send_from_user"] = ""  # disable @send_user
+    if not isinstance(status_msg, Message):
+        return
+
     # First, get the PDF and send it.
- pdf = await download_file(f"https://arxiv.org/pdf/{arxiv_id}", suffix=".pdf", proxy=PROXY.ARXIV, stream=True)
+    pdf_url = f"https://arxiv.org/pdf/{arxiv_id}"
+    pdf = await download_file(pdf_url, suffix=".pdf", proxy=PROXY.ARXIV, stream=True)
     if not pdf:
         await modify_progress(status_msg, text="❌下载PDF失败", force_update=True)
         return
- file_id = pdf
- if isinstance(status_msg, Message):
- status_msg = await status_msg.edit_media(file_name=f"{arxiv_id}.pdf", media=InputMediaDocument(file_id, caption=f"arXiv: [{arxiv_id}]({url})"))
- file_id = glom(status_msg, "document.file_id", default=pdf)
- api = f"https://export.arxiv.org/api/query?id_list={arxiv_id}"
- resp = await hx_req(api, headers=HEADERS, proxy=PROXY.ARXIV, rformat="text")
- if "hx_error" in resp:
+    status_msg = await status_msg.edit_media(media=InputMediaDocument(pdf, caption=f"arXiv: [{arxiv_id}]({url})"))
+    # Reuse the uploaded document's file_id for later edits; fall back to the local path.
+    file_id = glom(status_msg, "document.file_id", default=pdf)
+    arxiv_info = await get_arxiv_meta(arxiv_id)
+    if not arxiv_info:
+        # No metadata available: still offer a summary built from the TeX sources and the PDF.
+        logger.error("❌arXiv API调用失败")
+        if summary_arxiv:
+            sources = await extract_arxiv_tex(arxiv_id) + [{"type": "file", "path": pdf, "mime_type": "application/pdf"}]
+            summary = await summarize(sources=sources, title=f"arXiv-{arxiv_id}", model=summary_arxiv_model, url=url, force_r2_page=True)
+            # Same tolerant lookup as the main path below; a missing "texts" key must not raise.
+            await send_blockquote_texts(client, status_msg, texts=summary.get("texts", ""), **kwargs)
+        Path(pdf).unlink(missing_ok=True)
         return
- arxiv = feedparser.parse(resp["text"])
- entry = glom(arxiv, "entries.0", default={})
- title = glom(entry, "title", default="")
- updated = glom(entry, Coalesce("updated", "published"), default="")
- updated = updated.replace("T", " ").rstrip("Z")
- abstract = glom(entry, "summary", default="")
- comment = glom(entry, "arxiv_comment", default="")
- authors = ""
- for author in glom(arxiv, "entries.0.authors", default=[]):
- if name := author.get("name"):
- authors += f"{name}, "
- authors = authors.rstrip(", ")
+
+    title = arxiv_info["title"]
+    authors = arxiv_info["authors"]
+    updated = arxiv_info["updated"]
+    comment = arxiv_info["comment"]
+    abstract = arxiv_info["abstract"]
+
     texts = f"📄**[{title}]({url})**\n👥{authors}\n🕒{updated}\n"
     if comment:
         texts += f"📝{comment}"
     caption = (await smart_split(texts, CAPTION_LENGTH))[0]
- if isinstance(status_msg, Message):
- status_msg = await status_msg.edit_media(file_name=f"{arxiv_id}.pdf", media=InputMediaDocument(file_id, caption=caption))
- Path(pdf).unlink(missing_ok=True)
- # await modify_progress(status, del_status=True)
- if not isinstance(status_msg, Message):
- return
+    status_msg = await status_msg.edit_media(media=InputMediaDocument(file_id, caption=caption))
+    if not summary_arxiv:
+        await send_blockquote_texts(client, status_msg, texts=f"**Abstract**\n{abstract}", **kwargs)
+    else:
+        iframe = f'<iframe src="{pdf_url}" width="100%" height="800px" style="border: none; border-radius: 8px;"></iframe>'
+        ptag = f'<p style="text-align: center;"><a href="{pdf_url}" target="_blank">在新标签页中打开论文</a></p>'
+        sources = [{"type": "text", "text": json.dumps(arxiv_info)}] + await extract_arxiv_tex(arxiv_id) + [{"type": "file", "path": pdf, "mime_type": "application/pdf"}]
+        summary = await summarize(
+            sources=sources,
+            title=title,
+            model=summary_arxiv_model,
+            author=authors,
+            date=updated,
+            url=url,
+            description={"emoji": "📄", "name": "原始论文", "html": iframe + ptag},
+            force_r2_page=True,
+        )
+        await send_blockquote_texts(client, status_msg, texts=summary.get("texts", ""), **kwargs)
+
+    Path(pdf).unlink(missing_ok=True)
+
+
+async def get_arxiv_meta(arxiv_id: str) -> dict:
+    """Fetch metadata for an arXiv paper.
+
+    The standard arXiv export API is queried first. When it does not answer with
+    HTTP 200, the OAI-PMH endpoint (arXivRaw format) is used as a fallback.
+
+    Returns:
+        A dict with the keys ``title``, ``authors``, ``published``, ``updated``,
+        ``comment`` and ``abstract``, or an empty dict when both endpoints fail.
+    """
+    request_opts = {"headers": HEADERS, "proxy": PROXY.ARXIV, "rformat": "text", "timeout": 3, "max_retry": 1}
+
+    # 1) standard arXiv API (Atom feed)
+    resp = await hx_req(f"https://export.arxiv.org/api/query?id_list={arxiv_id}", **request_opts)
+    if resp.get("status_code") == 200:
+        entry = glom(feedparser.parse(resp["text"]), "entries.0", default={})
+        names = [name for author in glom(entry, "authors", default=[]) if (name := author.get("name"))]
+        return {
+            "title": glom(entry, "title", default=""),
+            "authors": ", ".join(names).rstrip(", "),
+            # ISO timestamps such as 2024-01-01T00:00:00Z become "2024-01-01 00:00:00"
+            "published": glom(entry, "published", default="").replace("T", " ").rstrip("Z"),
+            "updated": glom(entry, Coalesce("updated", "published"), default="").replace("T", " ").rstrip("Z"),
+            "comment": glom(entry, "arxiv_comment", default=""),
+            "abstract": glom(entry, "summary", default=""),
+        }
+
+    # 2) fallback: Open Archives Initiative, which is addressed without the version suffix
+    logger.warning("❌arXiv standard API调用失败,回退到Open Archives Initiative")
+    clean_id = re.sub(r"v\d+$", "", arxiv_id)
+    oai_api = f"https://oaipmh.arxiv.org/oai?verb=GetRecord&identifier=oai:arXiv.org:{clean_id}&metadataPrefix=arXivRaw"
+    resp = await hx_req(oai_api, **request_opts)
+    if resp.get("status_code") != 200:
+        return {}
+
+    soup = BeautifulSoup(resp["text"], "xml")
+    versions = soup.find_all("version")
+    # Dates default to "now" when the record carries no version information.
+    pub_dt = nowdt("UTC")
+    upd_dt = nowdt("UTC")
+    if first_date := glom(versions, "0.date.text", default=""):
+        pub_dt = parsedate_to_datetime(first_date)
+    if last_date := glom(versions, "-1.date.text", default=""):
+        upd_dt = parsedate_to_datetime(last_date)
+    stamp = "%Y-%m-%d %H:%M:%S"
+    return {
+        "title": glom(soup, "title.text", default="").strip().replace("\n", " "),
+        "authors": glom(soup, "authors.text", default="").strip(),
+        "published": pub_dt.strftime(stamp),
+        "updated": upd_dt.strftime(stamp),
+        "comment": glom(soup, "comments.text", default="").strip().replace("\n", " "),
+        "abstract": glom(soup, "abstract.text", default="").strip(),
+    }
+
+
+async def extract_arxiv_tex(arxiv_id: str) -> list[dict]:
+    """Download the arXiv source archive and return its TeX/BibTeX files as text sources.
+
+    ``main.tex`` files come first, then the remaining ``.tex`` files sorted by file name,
+    then the ``.bib`` files. Each source is ``{"type": "text", "text": <json>}`` where the
+    JSON holds ``filename`` and ``content``. Comments are removed from ``.tex`` content.
+
+    Returns:
+        The list of sources, or an empty list when download, unpacking or reading fails.
+        The downloaded archive and the extraction directory are always removed.
+    """
+
+    def remove_comments(content: str) -> str:
+        # Drop everything from an unescaped % to the end of the line, then collapse blank-line runs.
+        content = re.sub(r"(?<!\\)%.*$", "", content, flags=re.MULTILINE)
+        return re.sub(r"\n\s*\n", "\n\n", content)
+
+    arxiv_dir = Path(DOWNLOAD_DIR) / arxiv_id
+    save_path = Path(DOWNLOAD_DIR) / f"{arxiv_id}.tar.gz"
+    sources: list[dict] = []
+    try:
+        await download_file(f"https://arxiv.org/src/{arxiv_id}", save_path, proxy=PROXY.ARXIV, stream=True)
+        shutil.rmtree(arxiv_dir, ignore_errors=True)
+        arxiv_dir.mkdir(parents=True, exist_ok=True)
+        shutil.unpack_archive(save_path, arxiv_dir)
+
+        # Path.is_file is a method; without the call the bound method is always truthy and filters nothing.
+        all_files = [f for f in arxiv_dir.rglob("*") if f.is_file()]
+        main_tex = [f for f in all_files if f.name == "main.tex"]
+        # main.tex is already placed first, so keep it out of this list to avoid sending it twice.
+        tex_files = sorted((f for f in all_files if f.suffix == ".tex" and f.name != "main.tex"), key=lambda x: x.name)
+        bib_files = [f for f in all_files if f.suffix == ".bib"]
- abstract = f"**Abstract**\n{abstract}"
- with contextlib.suppress(Exception):
- for txt in await smart_split(abstract, TEXT_LENGTH):
- status_msg = await status_msg.reply_text(blockquote(txt), quote=True)
+        for f in main_tex + tex_files + bib_files:
+            name = f.relative_to(arxiv_dir).name
+            # Replace undecodable bytes so one badly encoded file cannot discard every source.
+            content = f.read_text(encoding="utf-8", errors="replace").strip()
+            if f.suffix == ".tex":
+                content = remove_comments(content)
+            sources.append({"type": "text", "text": json.dumps({"filename": name, "content": content})})
+    except Exception as e:
+        logger.error(f"❌arXiv {arxiv_id} 提取 tex 失败: {e}")
+        sources = []
+    finally:
+        save_path.unlink(missing_ok=True)
+        shutil.rmtree(arxiv_dir, ignore_errors=True)
+    return sources
src/preview/douyin.py
@@ -13,8 +13,9 @@ from loguru import logger
from pyrogram.client import Client
from pyrogram.types import Message
+from ai.utils import trim_none
from bridge.social import send_to_social_media_bridge
-from config import API, DOWNLOAD_DIR, PROVIDER, PROXY, TOKEN, TZ
+from config import AI, API, DOWNLOAD_DIR, PROVIDER, PROXY, TOKEN, TZ
from database.r2 import get_cf_r2
from messages.database import copy_messages_from_db, save_messages
from messages.progress import modify_progress
@@ -22,7 +23,9 @@ from messages.sender import send2tg
from messages.utils import blockquote, summay_media
from networking import download_file, download_first_success_urls, download_media, hx_req
from others.emoji import emojify
-from utils import rand_number, readable_count, true
+from preview.utils import add_summary_url
+from summarize.summarize import summarize
+from utils import nowstr, rand_number, readable_count, true
async def preview_douyin(
@@ -34,6 +37,8 @@ async def preview_douyin(
douyin_provider: str = PROVIDER.DOUYIN,
douyin_comments_provider: str = PROVIDER.DOUYIN_COMMENTS,
*,
+ summary_douyin: bool = False,
+ summary_douyin_model: str = AI.AI_SUMMARY_MODEL_ALIAS,
show_author: bool = True,
show_pubdate: bool = True,
show_statistics: bool = True,
@@ -90,9 +95,23 @@ async def preview_douyin(
texts += f"\n{data['desc']}"
comments = await get_comments(data["aweme_id"], platform, douyin_comments_provider)
- sent_messages = await send2tg(client, message, texts=emojify(texts) + comments, media=data.get("media", []), **kwargs)
+ sent_messages = await send2tg(client, message, texts=emojify(texts) + comments, media=data.get("media", []), keep_file=True, **kwargs)
await modify_progress(del_status=True, **kwargs)
+ # Summary douyin
+ # find the first message that has a caption
+ caption_msg = None
+ index = -1
+ for idx, m in enumerate(sent_messages):
+ if isinstance(m, Message) and (m.caption or m.text):
+ caption_msg = m
+ index = idx
+ break
+ if summary_douyin and caption_msg:
+ edited_msg = await summarize_douyin(caption_msg, data, data.get("media", []), summary_douyin_model, url)
+ sent_messages[index] = edited_msg
await save_messages(messages=sent_messages, key=db_key)
+ # Clean up
+ [Path(glom(x, Coalesce("photo", "video", "audio"))).unlink(missing_ok=True) for x in data.get("media", [])]
async def parse_via_direct(url: str = "", platform: str = "douyin", proxy: str | None = None, **kwargs) -> tuple[bool, dict]:
@@ -298,3 +317,46 @@ async def get_comments(aweme_id: str = "", platform: str = "douyin", douyin_comm
comments_str += "💬**点此展开评论区**:"
comments_str += f"\n💬**{cmt['name']}**{cmt['region']}: {cmt['text']}"
return blockquote(comments_str)
+
+
+async def summarize_douyin(message: Message, douyin: dict, media_list: list[dict], model: str, url: str) -> Message:
+    """Summarize a Douyin/TikTok post and attach the summary link to the sent message.
+
+    Args:
+        message: The already sent message that should receive the summary link.
+        douyin: Parsed post data; ``author``, ``desc``, ``create_time`` and ``aweme_id`` are read.
+        media_list: Media dicts; ``photo`` and ``video`` paths become image/video sources.
+        model: Model alias forwarded to ``summarize``.
+        url: The post URL.
+
+    Returns:
+        The result of ``add_summary_url`` when a telegraph URL was produced and that call
+        returned something truthy, otherwise the unchanged ``message``.
+    """
+    meta = {
+        "platform": "Tiktok" if "tiktok.com" in url else "抖音",
+        "author_name": douyin.get("author"),
+        "url": url,
+        "description": douyin.get("desc"),
+    }
+    if douyin.get("create_time"):
+        created = datetime.fromtimestamp(douyin["create_time"]).astimezone(ZoneInfo(TZ))
+        meta["created_at"] = f"{created:%Y-%m-%d %H:%M:%S}"
+    meta = trim_none(meta)
+
+    # Media first, metadata last.
+    sources = []
+    has_video = False
+    for item in media_list:
+        if item.get("photo"):
+            sources.append({"type": "image", "path": item["photo"]})
+        if item.get("video"):
+            has_video = True
+            sources.append({"type": "video", "path": item["video"]})
+    sources.append({"type": "text", "text": json.dumps(meta, ensure_ascii=False)})
+
+    author_name = douyin.get("author", "Anonymous")
+    pid = douyin["aweme_id"]
+    summary = await summarize(
+        sources=sources,
+        model=model,
+        title=f"🎶{author_name} - {pid}",
+        author=author_name,
+        url=url,
+        date=meta.get("created_at") or nowstr(TZ),
+        # Posts with a video drop the text-length threshold and use a 120 duration threshold instead.
+        min_text_length=None if has_video else 1000,
+        min_video_duration=120 if has_video else None,
+        max_video_duration=3600,  # skip long videos more than 1 hour
+    )
+    if telegraph_url := summary.get("telegraph_url"):
+        return await add_summary_url(telegraph_url, message) or message
+    return message
src/preview/instagram.py
@@ -1,15 +1,18 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
+import json
+from pathlib import Path
from typing import Literal
from bs4 import BeautifulSoup
-from glom import flatten, glom
+from glom import Coalesce, flatten, glom
from loguru import logger
from pyrogram.client import Client
from pyrogram.types import Message
+from ai.utils import trim_none
from bridge.social import send_to_social_media_bridge
-from config import API, DOWNLOAD_DIR, PROVIDER, PROXY, TELEGRAM_UA, TOKEN
+from config import AI, API, DOWNLOAD_DIR, PROVIDER, PROXY, TELEGRAM_UA, TOKEN, TZ
from database.r2 import get_cf_r2
from messages.database import copy_messages_from_db, save_messages
from messages.progress import modify_progress
@@ -17,7 +20,9 @@ from messages.sender import send2tg
from messages.utils import blockquote, summay_media
from multimedia import is_valid_video_or_audio, validate_img
from networking import download_file, download_media, hx_req
-from utils import readable_count, true, ts_to_dt
+from preview.utils import add_summary_url
+from summarize.summarize import summarize
+from utils import nowstr, readable_count, true, ts_to_dt
async def preview_instagram(
@@ -31,6 +36,8 @@ async def preview_instagram(
username: str = "",
instagram_provider: str = PROVIDER.INSTAGRAM,
instagram_comments: bool = True,
+ summary_instagram: bool = False,
+ summary_instagram_model: str = AI.AI_SUMMARY_MODEL_ALIAS,
show_author: bool = True,
show_pubdate: bool = True,
show_statistics: bool = True,
@@ -120,9 +127,23 @@ async def preview_instagram(
await modify_progress(text=f"⏬正在下载:\n{summay_media(media)}", force_update=True, **kwargs)
media = await download_media(media, **kwargs)
- sent_messages = await send2tg(client, message, texts=texts.strip() + comments, media=media, **kwargs)
+ sent_messages = await send2tg(client, message, texts=texts.strip() + comments, media=media, keep_file=True, **kwargs)
await modify_progress(del_status=True, **kwargs)
+ # Summary instagram
+ # find the first message that has a caption
+ caption_msg = None
+ index = -1
+ for idx, m in enumerate(sent_messages):
+ if isinstance(m, Message) and (m.caption or m.text):
+ caption_msg = m
+ index = idx
+ break
+ if summary_instagram and caption_msg:
+ edited_msg = await summarize_instagram(caption_msg, data, media, summary_instagram_model, url)
+ sent_messages[index] = edited_msg
await save_messages(messages=sent_messages, key=db_key)
+ # Clean up
+ [Path(glom(x, Coalesce("photo", "video", "audio"))).unlink(missing_ok=True) for x in data.get("media", [])]
async def preview_story(client: Client, message: Message, data: dict, username: str, post_id: str, db_key: str, **kwargs):
@@ -209,3 +230,43 @@ async def preview_ddinstagram(client: Client, message: Message, url: str, post_t
await send2tg(client, message, texts=texts, media=[media], **kwargs)
await modify_progress(del_status=True, **kwargs)
+
+
+async def summarize_instagram(message: Message, info: dict, media_list: list[dict], model: str, url: str) -> Message:
+    """Summarize an Instagram post and attach the summary link to the sent message.
+
+    Args:
+        message: The already sent message that should receive the summary link.
+        info: Parsed post data; the owner name and the caption node are read from it.
+        media_list: Media dicts; ``photo`` and ``video`` paths become image/video sources.
+        model: Model alias forwarded to ``summarize``.
+        url: The post URL.
+
+    Returns:
+        The result of ``add_summary_url`` when a telegraph URL was produced and that call
+        returned something truthy, otherwise the unchanged ``message``.
+    """
+    data = {
+        "platform": "Instagram",
+        "author_name": glom(info, "owner.full_name", default=None),
+        "url": url,
+        "description": glom(info, "edge_media_to_caption.edges.0.node.text", default=None),
+    }
+    # The timestamp sits on the caption node of the post (`info`); `data` is the dict built
+    # just above and never contains that path, so looking there always yielded the default.
+    if ts := glom(info, "edge_media_to_caption.edges.0.node.created_at", default=0):
+        data["created_at"] = f"{ts_to_dt(ts):%Y-%m-%d %H:%M:%S}"
+    data = trim_none(data)
+    sources = []
+    min_text_length = 1000  # skip short posts
+    min_video_duration = None
+    for media in media_list:
+        if media.get("photo"):
+            sources.append({"type": "image", "path": media["photo"]})
+        if media.get("video"):
+            min_text_length = None  # always summarize video
+            min_video_duration = 120
+            sources.append({"type": "video", "path": media["video"]})
+    sources.append({"type": "text", "text": json.dumps(data, ensure_ascii=False)})
+    # `or` also covers an author_name that is present but None/empty.
+    author_name = data.get("author_name") or "Anonymous"
+    summary = await summarize(
+        sources=sources,
+        model=model,
+        title=f"🏞{author_name} - Instagram",
+        author=author_name,
+        url=url,
+        date=data.get("created_at") or nowstr(TZ),
+        min_text_length=min_text_length,
+        min_video_duration=min_video_duration,
+        max_video_duration=3600,  # skip long videos more than 1 hour
+    )
+    telegraph_url = summary.get("telegraph_url")
+    if not telegraph_url:
+        return message
+    return await add_summary_url(telegraph_url, message) or message
src/preview/twitter.py
@@ -1,24 +1,30 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
+import asyncio
import copy
+import json
import re
from datetime import UTC, datetime
+from pathlib import Path
from zoneinfo import ZoneInfo
from glom import Coalesce, glom
from loguru import logger
from pyrogram.client import Client
-from pyrogram.types import Message
+from pyrogram.types import LinkPreviewOptions, Message
from bridge.social import send_to_social_media_bridge
-from config import API, PROVIDER, PROXY, TELEGRAM_UA, TOKEN, TZ, cache
+from config import AI, API, PROVIDER, PROXY, TELEGRAM_UA, TOKEN, TZ, cache
from database.r2 import get_cf_r2
from messages.database import copy_messages_from_db, save_messages
from messages.progress import modify_progress
from messages.sender import send2tg
-from messages.utils import blockquote, remove_img_tag, summay_media
+from messages.utils import blockquote, smart_split, summay_media
from networking import download_file, download_media, flatten_rediercts, hx_req
-from utils import convert2html, readable_count, remove_consecutive_newlines, remove_none_values, split_parts, true
+from preview.utils import add_summary_url
+from publish import publish_telegraph
+from summarize.summarize import summarize
+from utils import nowstr, readable_count, remove_consecutive_newlines, remove_none_values, split_parts, true
class APIError(Exception):
@@ -38,6 +44,8 @@ async def preview_twitter(
show_pubdate: bool = True,
show_device: bool = False,
show_statistics: bool = True,
+ summary_twitter: bool = False,
+ summary_twitter_model: str = AI.AI_SUMMARY_MODEL_ALIAS,
**kwargs,
):
"""Preview twitter link in the message.
@@ -147,7 +155,7 @@ async def preview_twitter(
# 被回复主推
if master_info:
if true(show_author) and master_info.get("author"):
- msg += f"\n🕊**[{master_info['author']}](https://x.com/{master_info['handle']}/status/{master_info['post_id']})**"
+ msg += f'\n🕊<a href="https://x.com/{master_info["handle"]}/status/{master_info["post_id"]}"><b>{master_info["author"]}</b></a>'
if true(show_pubdate) and master_info.get("time"):
msg += f"\n🕒{master_info['time']}"
if part_strs["first"]:
@@ -172,7 +180,7 @@ async def preview_twitter(
if master_info:
msg += "\n⤴️"
if true(show_author) and this_info.get("author"):
- msg += f"\n🕊**[{this_info['author']}]({url})**"
+ msg += f'\n🕊<a href="{url}"><b>{this_info["author"]}</b></a>'
msg = msg.replace("\n⤴️\n🕊", "\n⤴️")
if true(show_pubdate) and this_info.get("time"):
msg += f"\n🕒{this_info['time']}"
@@ -200,7 +208,7 @@ async def preview_twitter(
msg = remove_twitter_suffix(msg, post_id=quote_info["post_id"], same_id_only=True)
msg += "\n//"
if true(show_author) and quote_info.get("author"):
- msg += f"\n🕊**[{quote_info['author']}]({quote_x_url})**"
+ msg += f'\n🕊<a href="{quote_x_url}"><b>{quote_info["author"]}</b></a>'
msg = msg.replace("\n//\n", "\n//")
if true(show_pubdate) and quote_info.get("time"):
msg += f"\n🕒{quote_info['time']}"
@@ -217,9 +225,40 @@ async def preview_twitter(
await modify_progress(text=f"⏬正在下载:\n{summay_media(media)}", force_update=True, **kwargs)
media = await download_media(media, **kwargs)
- sent_messages = await send2tg(client, message, texts=msg.strip(), media=media, **kwargs)
+ sent_messages = []
+ if master_info.get("is_article") or this_info.get("is_article") or quote_info.get("is_article"):
+ msg = msg.replace("<blockquote>", f"\n{'—' * 10}\n").replace("</blockquote>", f"\n{'—' * 10}\n")
+ msg = msg.replace("<pre>", "</blockquote><pre>").replace("</pre>", "</pre><blockquote expandable>")
+ article_url = master_info.get("article_url") or this_info.get("article_url") or quote_info.get("article_url") or url
+ cur_msg = None
+ link_preview = LinkPreviewOptions(is_disabled=False, show_above_text=True, url=article_url)
+ for m in await smart_split(msg):
+ if not isinstance(cur_msg, Message):
+ cur_msg = await message.reply_text(text=f"<blockquote expandable>{m}</blockquote>", quote=True, link_preview_options=link_preview)
+ else:
+ cur_msg = await cur_msg.reply_text(f"<blockquote expandable>{m}</blockquote>", quote=True)
+ if isinstance(cur_msg, Message):
+ sent_messages.append(cur_msg)
+ await asyncio.sleep(1)
+ sent_messages.extend(await send2tg(client, cur_msg or message, media=media, keep_file=True, **kwargs))
+ else:
+ sent_messages = await send2tg(client, message, texts=msg.strip(), media=media, keep_file=True, **kwargs)
await modify_progress(del_status=True, **kwargs)
+ # Summary twitter
+ # find the first message that has a caption
+ caption_msg = None
+ index = -1
+ for idx, m in enumerate(sent_messages):
+ if isinstance(m, Message) and (m.caption or m.text):
+ caption_msg = m
+ index = idx
+ break
+ if summary_twitter and caption_msg:
+ edited_msg = await summarize_twitter(caption_msg, this_info, master_info, quote_info, media, summary_twitter_model)
+ sent_messages[index] = edited_msg
await save_messages(messages=sent_messages, key=db_key)
+ # Clean up
+ [Path(glom(x, Coalesce("photo", "video", "audio"))).unlink(missing_ok=True) for x in media]
@cache.memoize(ttl=30)
@@ -339,7 +378,7 @@ async def get_tweet_info_via_fxtwitter(url: str = "", handle: str = "", post_id:
if not handle or not post_id:
handle = url.split("/")[-3]
post_id = url.rsplit("/", maxsplit=1)[-1]
- api_url = f"{API.FXTWITTER}/{handle}/status/{post_id}"
+ api_url = f"{API.FXTWITTER}/{handle}/status/{post_id}/zh"
logger.info(f"Twitter preview via fxtwitter: {api_url}")
headers = {"user-agent": TELEGRAM_UA}
resp = await hx_req(api_url, headers=headers, proxy=PROXY.TWITTER)
@@ -350,9 +389,13 @@ async def get_tweet_info_via_fxtwitter(url: str = "", handle: str = "", post_id:
if data.get("article"):
data |= parse_article(data["article"])
+ author = glom(data, "author.screen_name", default="Anonymous")
+ url = f"https://x.com/{author}/status/{post_id}"
+ data["article_url"] = await publish_telegraph(title=data["title"], author=author, url=url, html=data["html"])
+ data["text"] = data["text"].replace(f"<h1>{data['title']}</h1>", f'<h1><a href="{data["article_url"]}">{data["title"]}</a></h1>')
info = {"handle": glom(data, "author.screen_name", default=handle), "post_id": data.get("id", post_id)}
- media = glom(data, Coalesce("media.all", "image_urls"), default=[])
+ media = glom(data, "media.all", default=[])
for x in media:
if x.get("type", "") == "video": # this is a m3u8 url, choose mp4 instead
m3u8_url = x.get("url", "")
@@ -379,7 +422,10 @@ async def get_tweet_info_via_fxtwitter(url: str = "", handle: str = "", post_id:
if ts := data.get("created_timestamp", ""):
dt = datetime.fromtimestamp(round(float(ts)), tz=UTC).astimezone(ZoneInfo(TZ))
info["time"] = f"{dt:%Y-%m-%d %H:%M:%S}"
- info["texts"] = data.get("text", "")
+ info["texts"] = glom(data, Coalesce("translation.text", "text"), default="")
+ info["html"] = data.get("html", "")
+ info["is_article"] = data.get("is_article", False)
+ info["article_url"] = data.get("article_url")
info["device"] = data.get("source", "").removeprefix("Twitter for").removeprefix("Twitter").removesuffix("App").strip().removesuffix("Web")
info["replying_to_user"] = data.get("replying_to", "")
info["replying_post_id"] = data.get("replying_to_status", "")
@@ -519,14 +565,16 @@ def parse_article(article: dict) -> dict:
style_ = style["style"].lower()
start = style["offset"]
end = start + style["length"]
- tag = ""
+ tag_start = ""
if style_ == "bold":
- tag = "**"
+ tag_start = "<b>"
+ tag_end = "</b>"
elif style_ == "italic":
- tag = "*"
- if tag:
- prefixes[start].append(tag)
- suffixes[end].insert(0, tag) # 使用 insert(0) 确保闭合标签以正确的嵌套顺序反向闭合
+ tag_start = "<i>"
+ tag_end = "</i>"
+ if tag_start:
+ prefixes[start].append(tag_start)
+ suffixes[end].insert(0, tag_end) # 使用 insert(0) 确保闭合标签以正确的嵌套顺序反向闭合
formatted_text = ""
for i in range(text_len + 1):
@@ -536,6 +584,22 @@ def parse_article(article: dict) -> dict:
formatted_text += text[i]
return formatted_text
+ html = ""
+ if cover_url := glom(article, "cover_media.media_info.original_img_url", default=""):
+ html += f'\n<img src="{cover_url}" alt="Cover" />'
+
+ media_list = []
+ for media in article.get("media_entities", []):
+ if variants := [x for x in glom(media, "media_info.variants", default=[]) if x.get("content_type") == "video/mp4"]: # video
+ variants = sorted(variants, key=lambda x: x.get("bitrate", 0), reverse=True)
+ if video_url := glom(variants, "0.url", default=""):
+ media_list.append({"url": video_url, "type": "video", "media_id": media.get("media_id")})
+ elif img_url := glom(media, "media_info.original_img_url", default=""):
+ media_list.append({"url": img_url, "type": "photo", "media_id": media.get("media_id")})
+
+ entity_map = glom(article, "content.entityMap", default={})
+ entity_dict = {str(x["key"]): x["value"] for x in entity_map} if isinstance(entity_map, list) else {str(k): v for k, v in entity_map.items()}
+
def parse_atomic(entities: list[dict]) -> str:
"""Parse atomic block."""
if not entities:
@@ -546,51 +610,133 @@ def parse_article(article: dict) -> dict:
e_type = entity.get("type", "").upper()
if e_type == "MEDIA":
media_id = glom(entity, "data.mediaItems.0.mediaId", default="")
- if img_url := media_dict.get(str(media_id)):
- texts += f""
+ if img_url := next((x["url"] for x in media_list if x["type"] == "photo" and x["media_id"] == media_id), None):
+ texts += f'\n<img src="{img_url}" alt="IMG-{media_id}" />'
+ elif video_url := next((x["url"] for x in media_list if x["type"] == "video" and x["media_id"] == media_id), None):
+ texts += f'\n<video src="{video_url}" />'
elif e_type == "DIVIDER":
texts += "\n"
elif e_type == "TWEET":
if tweet_id := glom(entity, "data.tweetId", default=""):
- texts += f"[QuoteTweet](https://x.com/i/status/{tweet_id})"
+ texts += f'\n<a href="https://x.com/i/status/{tweet_id}">QuoteTweet</a>'
elif e_type == "MARKDOWN":
- texts += glom(entity, "data.markdown", default="")
- return texts
-
- markdown = ""
- if title := article.get("title"):
- markdown += f"\n\n# {title}"
- if cover_url := glom(article, "cover_media.media_info.original_img_url", default=""):
- markdown += f"\n\n"
-
- media_dict: dict = {} # {media_id: media_url} # currently, articles in X only support images
- for media in article.get("media_entities", []):
- media_dict[str(media.get("media_id"))] = glom(media, "media_info.original_img_url", default="")
-
- entity_map = glom(article, "content.entityMap", default={})
- entity_dict = {str(x["key"]): x["value"] for x in entity_map} if isinstance(entity_map, list) else {str(k): v for k, v in entity_map.items()}
+ markdown = glom(entity, "data.markdown", default="").strip("`")
+ lang, raw = markdown.split("\n", maxsplit=1)
+ if lang:
+ texts += f'\n<pre language="{lang}">{raw}</pre>'
+ else:
+ texts += f"\n<pre>{markdown}</pre>"
+ return texts.strip()
# blocks
for block in glom(article, "content.blocks", default=[]):
text = inline_style(block.get("text"), block.get("inlineStyleRanges"))
entities = block.get("entityRanges", [])
- match block.get("type"):
- case "header-one" | "header-two" | "header-three" | "header-four":
- markdown += f"\n\n**{text}**"
+
+ block_type = block.get("type")
+ match block_type:
+ case "header-one":
+ html += f"\n<h1>{text}</h1>"
+ case "header-two":
+ html += f"\n<h2>{text}</h2>"
+ case "header-three":
+ html += f"\n<h3>{text}</h3>"
+ case "header-four":
+ html += f"\n<h4>{text}</h4>"
case "blockquote":
- markdown += f"\n\n> {text}"
+ html += f"\n<blockquote>{text}</blockquote>"
case "ordered-list-item" | "unordered-list-item":
- markdown += f"\n\n• {text}"
+ html += f"\n・{text}"
case "atomic":
- markdown += f"\n\n{parse_atomic(entities)}"
+ html += f"\n{parse_atomic(entities)}"
case _:
- markdown += f"\n\n{text}" if text else ""
+ html += f"\n<p>{text}</p>" if text else ""
- markdown_no_img, image_urls = remove_img_tag(markdown)
+ # form ordered media list
+ media = []
+ # 匹配img标签的正则表达式(支持单引号和双引号)
+ img_pattern = re.compile(r'<img\s+[^>]*?src\s*=\s*["\'](.*?)["\'][^>]*?>', re.IGNORECASE)
+ # 匹配video标签的正则表达式(支持单引号和双引号)
+ video_pattern = re.compile(r'<video\s+[^>]*?src\s*=\s*["\'](.*?)["\'][^>]*?>', re.IGNORECASE)
+ for line in html.splitlines():
+ if match_img := img_pattern.search(line):
+ media.append({"url": match_img.group(1), "type": "photo"})
+ if match_vid := video_pattern.search(line):
+ media.append({"url": match_vid.group(1), "type": "video"})
+
+ # 移除所有img和video标签
+ clean_html = img_pattern.sub("", html)
+ clean_html = video_pattern.sub("", clean_html)
return {
- "markdown": remove_consecutive_newlines(markdown).strip(),
- "text": remove_consecutive_newlines(markdown_no_img).strip(),
- "image_urls": image_urls,
- "html": convert2html(markdown),
- "media": {"all": [{"url": url, "type": "photo"} for url in image_urls]},
+ "is_article": True,
+ "text": remove_consecutive_newlines(clean_html).strip(),
+ "image_urls": img_pattern.findall(html),
+ "video_urls": video_pattern.findall(html),
+ "html": html,
+ "media": {"all": media},
+ "title": article.get("title", "Twitter Article"),
}
+
+
async def summarize_twitter(message: Message, this_info: dict, master_info: dict, quote_info: dict, media_list: list[dict], model: str) -> Message:
    """Summarize a tweet (with its quoted / replied-to tweets) and attach the summary link.

    Args:
        message: Already-sent preview message that receives the summary link.
        this_info: Parsed info of the tweet itself.
        master_info: Parsed info of the tweet being replied to (may be empty).
        quote_info: Parsed info of the quoted tweet (may be empty).
        media_list: Downloaded media; each item may carry a ``"photo"`` and/or ``"video"`` path.
        model: Model name forwarded to ``summarize``.

    Returns:
        The result of ``add_summary_url`` when a telegraph page was produced,
        otherwise the unchanged ``message``.
    """

    def trim(obj):
        """Recursively drop empty-string and ``None`` entries from dicts and lists."""
        if isinstance(obj, dict):
            return {k: trim(v) for k, v in obj.items() if v not in ["", None]}
        if isinstance(obj, list):
            return [trim(item) for item in obj if item not in ["", None]]
        return obj

    def cleanup(info: dict | None) -> dict:
        """Project a tweet-info dict onto the few fields handed to the summarizer."""
        keep_keys = {"author_name": "author", "created_at": "time", "content": ["markdown", "texts"], "post_id": "post_id", "handle": "handle"}
        # Trim once: the trimmed source does not depend on which key is extracted.
        target = trim(dict(info or {}))
        cleaned = {}
        for k, v in keep_keys.items():
            spec = v if isinstance(v, str) else Coalesce(*v)
            cleaned[k] = glom(target, spec, default=None)
        if cleaned.get("post_id") and cleaned.get("handle"):
            cleaned["url"] = f"https://x.com/{cleaned['handle']}/status/{cleaned['post_id']}"
        # The handle is only needed to build the URL above.
        cleaned.pop("handle", None)
        return trim(cleaned)

    def get_key(cleaned: dict, key: str, *, default: str = "") -> str:
        """Look `key` up on the tweet, then the quoted tweet, then the replied-to tweet."""
        return glom(cleaned, Coalesce(key, f"quote_tweet.{key}", f"replying_to_tweet.{key}"), default=default)

    article = {"platform": "Twitter / X"} | cleanup(this_info)
    if quote_tweet := cleanup(quote_info):
        article |= {"quote_tweet": quote_tweet}
    if replying_to := cleanup(master_info):
        article |= {"replying_to_tweet": replying_to}

    sources = []
    min_text_length = 1000  # skip short tweets
    min_video_duration = None
    for media in media_list:
        if media.get("photo"):
            sources.append({"type": "image", "path": media["photo"]})
        if media.get("video"):
            min_text_length = None
            min_video_duration = 120  # skip short videos less than 2 minutes
            sources.append({"type": "video", "path": media["video"]})

    # Articles are flagged with "is_article" by the parser; the legacy "markdown"
    # key is still honoured in case an older info dict is passed in.
    if any(x and (x.get("is_article") or x.get("markdown")) for x in (this_info, quote_info, master_info)):
        min_text_length = None  # This is twitter article
        min_video_duration = None
    sources.append({"type": "text", "text": json.dumps(article, ensure_ascii=False)})

    author_name = get_key(article, "author_name", default="Anonymous")
    pid = get_key(article, "post_id", default="")
    summary = await summarize(
        sources=sources,
        model=model,
        title=f"🕊{author_name} - {pid}",
        author=author_name,
        url=get_key(article, "url", default="https://x.com"),
        # cleanup() stores the tweet time under "created_at"; the old "time"
        # lookup never matched, so the summary was always dated "now".
        date=get_key(article, "created_at", default=nowstr(TZ)),
        min_text_length=min_text_length,
        min_video_duration=min_video_duration,
        max_video_duration=3600,  # skip long videos more than 1 hour
    )
    telegraph_url = summary.get("telegraph_url")
    if not telegraph_url:
        return message
    return await add_summary_url(telegraph_url, message) or message
src/preview/utils.py
@@ -2,6 +2,12 @@
# -*- coding: utf-8 -*-
import re
+from glom import glom
+from loguru import logger
+from pyrogram.enums import ParseMode
+from pyrogram.errors import MediaCaptionTooLong, MessageTooLong
+from pyrogram.types import LinkPreviewOptions, Message
+
def has_markdown_img(text: str) -> bool:
"""Check if the text contains markdown img format.
@@ -10,3 +16,40 @@ def has_markdown_img(text: str) -> bool:
"""
pattern = r"!\[.*?\]\(.*?\)"
return bool(re.search(pattern, text))
+
+
async def add_summary_url(url: str, message: Message) -> Message:
    """Add the telegraph summary link to a sent message.

    Strategy, in order:
      1. Prepend the link to the ``🕒YYYY-MM-DD HH:MM:SS`` line of the message.
      2. If that is too long, replace the time line with link + date only.
      3. If editing fails (or there is no time line to anchor on), reply to
         the message with the link instead.

    Args:
        url: Telegraph URL of the summary.
        message: Message whose caption/text should receive the link.

    Returns:
        The edited message when an edit succeeded, otherwise the original ``message``.
    """
    link = f'<a href="{url}"><b>🤖AI导读</b></a>'

    def gen_captions() -> tuple[str, str, str]:
        """Return (original html, html with link prepended, shortened html)."""
        html = glom(message, "content.html", default="")
        lines = html.split("\n")
        time = ""
        day = ""
        for i, line in enumerate(lines):
            if matched := re.match(r"^🕒(\d{4}-\d{2}-\d{2}) \d{2}:\d{2}:\d{2}", line):
                lines[i] = link + " " + line
                time = matched.group(0)
                day = matched.group(1)
                break
        prepend = "\n".join(lines)
        added = html.replace(time, f"{link} 🕒{day}") if day else html
        return html, prepend, added

    async def reply_with_link() -> None:
        link_preview = LinkPreviewOptions(is_disabled=False, show_above_text=True, url=url)
        await message.reply_text(link, quote=True, parse_mode=ParseMode.HTML, link_preview_options=link_preview)

    html, prepend, added = gen_captions()
    logger.trace(f"Add summary url {url} to {message.link}")

    if prepend == html:
        # No time line was found, so an edit would not change anything.
        logger.warning(f"Failed to add summary url {url} to caption: no time line found")
        await reply_with_link()
        return message

    # The outer try also covers a failure of the shortened-caption attempt.
    # An exception raised inside an `except` handler is not seen by sibling
    # `except` clauses, which is why the two attempts are nested.
    try:
        try:
            return await message.edit_caption(prepend, parse_mode=ParseMode.HTML)
        except (MediaCaptionTooLong, MessageTooLong):
            logger.warning("Caption is too long, use added caption")
            return await message.edit_caption(added, parse_mode=ParseMode.HTML)
    except Exception as e:
        logger.warning(f"Failed to add summary url {url} to caption: {e}")
        await reply_with_link()
    return message
src/preview/v2ex.py
@@ -1,24 +1,38 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
+import json
import re
from datetime import UTC, datetime
from pathlib import Path
from zoneinfo import ZoneInfo
-from glom import glom
+from glom import Coalesce, glom
from loguru import logger
from pyrogram.client import Client
from pyrogram.types import Message
-from config import PROXY, TELEGRAM_UA, TOKEN, TZ
+from ai.utils import trim_none
+from config import AI, PROXY, TELEGRAM_UA, TOKEN, TZ
+from database.kv import get_cf_kv, set_cf_kv
from messages.progress import modify_progress
from messages.sender import send2tg
from messages.utils import summay_media
from networking import download_file, download_media, hx_req
-from utils import number_to_emoji
+from preview.utils import add_summary_url
+from summarize.summarize import summarize
+from utils import nowstr, number_to_emoji
-async def preview_v2ex(client: Client, message: Message, url: str = "", topic_id: str = "", **kwargs):
+async def preview_v2ex(
+ client: Client,
+ message: Message,
+ url: str = "",
+ topic_id: str = "",
+ *,
+ summary_v2ex: bool = True,
+ summary_v2ex_model: str = AI.AI_SUMMARY_MODEL_ALIAS,
+ **kwargs,
+):
"""Preview v2ex link in the message.
Args:
@@ -31,7 +45,11 @@ async def preview_v2ex(client: Client, message: Message, url: str = "", topic_id
res = await send2tg(client, message, texts=f"🔗正在解析V2ex链接\n{url}", **kwargs)
kwargs["progress"] = res[0]
logger.info(f"v2ex link preview for {url}")
- headers = {"Authorization": f"Bearer {TOKEN.V2EX}"}
+ token = await refresh_v2ex_token()
+ if not token:
+ await modify_progress(text="❌V2EX Token已失效, 请手动创建", force_update=True, **kwargs)
+ return
+ headers = {"Authorization": f"Bearer {token}"}
topic_api = f"https://www.v2ex.com/api/v2/topics/{topic_id}"
resp = await hx_req(topic_api, proxy=PROXY.V2EX, headers=headers, check_kv={"success": True, "result.id": topic_id})
if error := resp.get("error"):
@@ -41,9 +59,9 @@ async def preview_v2ex(client: Client, message: Message, url: str = "", topic_id
author_url = f"https://www.v2ex.com/member/{author}"
title = glom(resp, "result.title", default="Title")
ts = glom(resp, "result.created", default=0)
- texts = f"💻[{author}]({author_url})\n"
+ texts = f"💻**[{author}]({author_url})**\n"
texts += f"🕒{datetime.fromtimestamp(ts, tz=UTC).astimezone(ZoneInfo(TZ)).strftime('%Y-%m-%d %H:%M:%S')}\n"
- texts += f"📝[{title}]({url})\n"
+ texts += f"📝**[{title}]({url})**\n"
content, img_urls = extract_and_remove_images_regex(glom(resp, "result.content", default=""))
texts += content + "\n"
if supplements := glom(resp, "result.supplements", default=[]):
@@ -54,8 +72,19 @@ async def preview_v2ex(client: Client, message: Message, url: str = "", topic_id
if media:
await modify_progress(text=f"⏬正在下载:\n{summay_media(media)}", force_update=True, **kwargs)
media = await download_media(media, **kwargs)
- await send2tg(client, message, texts=texts, media=media, **kwargs)
+ sent_messages = await send2tg(client, message, texts=texts, media=media, keep_file=True, **kwargs)
await modify_progress(del_status=True, **kwargs)
+ # Summary v2ex
+ # find the first message that has a caption
+ caption_msg = None
+ for m in sent_messages:
+ if isinstance(m, Message) and (m.caption or m.text):
+ caption_msg = m
+ break
+ if summary_v2ex and caption_msg:
+ await summarize_v2ex(caption_msg, resp.get("result", {}), media, summary_v2ex_model, url)
+ # Clean up
+ [Path(glom(x, Coalesce("photo", "video", "audio"))).unlink(missing_ok=True) for x in media]
def extract_and_remove_images_regex(markdown_text: str) -> tuple[str, list[str]]:
@@ -86,3 +115,107 @@ async def download_imgs(img_urls: list[str]) -> list[dict]:
else:
media.append({"photo": download_file(img_url, proxy=PROXY.V2EX)})
return media
+
+
async def refresh_v2ex_token() -> str:
    """Return a usable V2EX API token, renewing it when it is about to expire.

    The configured ``TOKEN.V2EX`` is tried first, then the backup stored in KV
    under ``v2ex_token``. Whenever the working token is within three days of
    expiring, a replacement is created and written to KV.

    Returns:
        A valid token, or ``""`` when neither candidate is valid.
    """
    renew_threshold = 86400 * 3  # three days, in seconds

    async def check_token(token: str) -> tuple[bool, int]:
        """Return (is the token valid, `result.expiration` reported by the API)."""
        resp = await hx_req(
            "https://www.v2ex.com/api/v2/token",
            proxy=PROXY.V2EX,
            headers={"Authorization": f"Bearer {token}"},
            check_kv={"success": True},
            max_retry=0,
        )
        return bool(glom(resp, "result.token", default=None)), glom(resp, "result.expiration", default=0)

    async def create_token(token: str) -> str:
        """Create a new 180-day token, authenticated with `token`; "" on failure."""
        resp = await hx_req(
            "https://www.v2ex.com/api/v2/tokens",
            method="POST",
            proxy=PROXY.V2EX,
            json_data={"scope": "everything", "expiration": 15552000},
            headers={"Authorization": f"Bearer {token}"},
            max_retry=0,
            check_kv={"success": True},
        )
        return glom(resp, "result.token", default="")

    async def renew_if_expiring(token: str, ttl: int) -> None:
        """Store a fresh token in KV when `token` is close to expiring."""
        # NOTE(review): assumes `result.expiration` is the remaining lifetime in
        # seconds, not the total lifetime. Confirm against the V2EX API.
        if ttl >= renew_threshold:
            return
        logger.warning("V2EX Token即将失效, 正在重新创建...")
        new_token = await create_token(token)
        if new_token:
            await set_cf_kv("v2ex_token", {"token": new_token})
        else:
            # Never overwrite the KV backup with an empty token.
            logger.warning("V2EX Token重新创建失败")

    if TOKEN.V2EX:
        valid, ttl = await check_token(TOKEN.V2EX)
        if valid:
            await renew_if_expiring(TOKEN.V2EX, ttl)
            return TOKEN.V2EX

    logger.warning("V2EX Token已失效, 从KV获取...")
    token = ((await get_cf_kv("v2ex_token")) or {}).get("token", "")
    if not token:
        # Nothing stored, so there is no point in calling the API.
        return ""
    valid, ttl = await check_token(token)
    if valid:
        await renew_if_expiring(token, ttl)
        return token
    return ""
+
+
async def summarize_v2ex(message: Message, v2ex: dict, media_list: list[dict], model: str, url: str) -> Message:
    """Build the AI-summary sources for a V2EX topic and attach the summary link.

    Args:
        message: Already-sent preview message that receives the summary link.
        v2ex: Topic dict; the keys read here are ``title``, ``member.username``,
            ``content_rendered``/``content``, ``created`` and ``supplements``.
        media_list: Downloaded media; each item may carry a ``"photo"`` and/or ``"video"`` path.
        model: Model name forwarded to ``summarize``.
        url: URL of the topic.

    Returns:
        The result of ``add_summary_url`` when a telegraph page was produced,
        otherwise the unchanged ``message``.
    """

    def date_str(ts: int) -> str | None:
        """Format a unix timestamp in the configured timezone; ``None`` for 0/empty."""
        if not ts:
            return None
        return f"{datetime.fromtimestamp(ts, tz=UTC).astimezone(ZoneInfo(TZ)).strftime('%Y-%m-%d %H:%M:%S')}"

    data = {
        "platform": "V2EX",
        "title": glom(v2ex, "title", default=""),
        "author_name": glom(v2ex, "member.username", default="Anonymous"),
        "url": url,
        "content": glom(v2ex, Coalesce("content_rendered", "content"), default=None),
    }

    if ts := glom(v2ex, "created", default=0):
        data["created_at"] = date_str(ts)
    supplements = [
        {
            "created_at": date_str(supp.get("created", 0)),
            "content": glom(supp, Coalesce("content_rendered", "content"), default=None),
        }
        for supp in glom(v2ex, "supplements", default=[])
    ]
    if supplements:
        data["supplements"] = supplements
    data = trim_none(data)

    sources = []
    min_text_length = 1000  # skip short topics
    min_video_duration = None
    for media in media_list:
        if media.get("photo"):
            sources.append({"type": "image", "path": media["photo"]})
        if media.get("video"):
            min_text_length = None  # always summarize video
            min_video_duration = 120
            sources.append({"type": "video", "path": media["video"]})
    sources.append({"type": "text", "text": json.dumps(data, ensure_ascii=False)})

    summary = await summarize(
        sources=sources,
        model=model,
        # .get(): trim_none may have removed empty entries, so do not index directly.
        title=data.get("title") or url,
        author=data.get("author_name") or "Anonymous",
        url=url,
        date=data.get("created_at") or nowstr(TZ),
        min_text_length=min_text_length,
        min_video_duration=min_video_duration,
        max_video_duration=3600,  # skip long videos more than 1 hour
    )
    telegraph_url = summary.get("telegraph_url")
    if not telegraph_url:
        return message
    return await add_summary_url(telegraph_url, message) or message
src/preview/wechat.py
@@ -1,25 +1,39 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import contextlib
+import json
+import re
from pathlib import Path
from urllib.parse import quote_plus
+from bs4 import BeautifulSoup
+from glom import Coalesce, glom
from loguru import logger
from pyrogram.client import Client
from pyrogram.types import Message
-from config import API, CAPTION_LENGTH, DOWNLOAD_DIR, PROXY, TEXT_LENGTH, TOKEN
+from config import AI, API, DOWNLOAD_DIR, PROXY, TOKEN, TZ
from database.r2 import get_cf_r2
from messages.database import copy_messages_from_db, save_messages
from messages.progress import modify_progress
-from messages.sender import send2tg
-from messages.utils import blockquote, count_without_entities, summay_media
+from messages.sender import send2tg, send_blockquote_texts
+from messages.utils import remove_img_tag, summay_media
from networking import download_file, download_media, hx_req
-from publish import publish_telegraph
-from utils import nowstr, rand_string
+from preview.utils import add_summary_url
+from summarize.summarize import summarize
+from utils import convert2md, nowstr, rand_string, remove_consecutive_newlines
-async def preview_wechat(client: Client, message: Message, url: str = "", db_key: str = "", **kwargs):
+async def preview_wechat(
+ client: Client,
+ message: Message,
+ url: str = "",
+ db_key: str = "",
+ *,
+ summary_wechat: bool = True,
+ summary_wechat_model: str = AI.AI_SUMMARY_MODEL_ALIAS,
+ **kwargs,
+):
"""Preview wechat link in the message.
Args:
@@ -42,35 +56,82 @@ async def preview_wechat(client: Client, message: Message, url: str = "", db_key
if error := post_info.get("error"):
await modify_progress(text=f"❌微信链接解析失败{url}\n{error}", force_update=True, **kwargs)
return
- sent_messages = []
- length = await count_without_entities(post_info["header"] + post_info["markdown"])
- if not post_info.get("media"): # 无图片
- if length < TEXT_LENGTH - 8: # 无图片短文
- texts = f"{post_info['header']}\n{blockquote(post_info['markdown'])}"
- sent_messages.extend(await send2tg(client, message, texts=texts, **kwargs))
- else: # 无图片长文
- texts = f"{post_info['header']}"
- telegraph_url = await publish_telegraph(title=post_info["title"], html=post_info["html"], author=post_info["author"], url=url)
- if telegraph_url:
- texts += f"\n⚡️[即时预览]({telegraph_url})"
- sent_messages.extend(await send2tg(client, message, texts=texts, media=[{"document": post_info["html_path"]}], **kwargs))
- elif length < CAPTION_LENGTH - 8: # 有图片短文
- texts = f"{post_info['header']}\n{blockquote(post_info['markdown'])}"
- sent_messages.extend(await send2tg(client, message, texts=texts, media=post_info["media"], **kwargs))
- else: # 有图片长文
- texts = f"{post_info['header']}"
- telegraph_url = await publish_telegraph(title=post_info["title"], html=post_info["html"], author=post_info["author"], url=url)
- if telegraph_url:
- texts += f"\n⚡️[即时预览]({telegraph_url})"
- sent_messages.extend(await send2tg(client, message, texts=texts, media=[{"document": post_info["path"]}], **kwargs))
- kwargs["reply_msg_id"] = -1 # do not send as reply
- sent_messages.extend(await send2tg(client, message, texts=texts, media=post_info["media"], **kwargs))
+
+ # send texts first
+ text_messages = await send_blockquote_texts(client, message, texts=post_info["caption"], **kwargs)
+ text_messages = [x for x in text_messages if isinstance(x, Message)]
+ media_messages = []
+ if media := post_info.get("media"):
+ reply_to_msg: Message = glom(text_messages, "-1", default=message)
+ media_messages = await send2tg(client, reply_to_msg, media=media, keep_file=True, **kwargs)
await modify_progress(del_status=True, **kwargs)
- await save_messages(messages=sent_messages, key=db_key)
+ # Summary wechat
+ if summary_wechat and text_messages:
+ edited_msg = await summarize_wechat(text_messages[0], post_info, summary_wechat_model, url)
+ text_messages[0] = edited_msg
+ await save_messages(messages=text_messages + media_messages, key=db_key)
+ # Clean up
+ [Path(glom(x, Coalesce("photo", "video", "audio"))).unlink(missing_ok=True) for x in post_info.get("media", [])]
-async def get_wechat_info(url: str, **kwargs) -> dict:
async def get_wechat_info(url: str, *, use_tikhub: bool = True, **kwargs) -> dict:
    """Fetch and parse a WeChat article page directly.

    The page is requested with a WeChat in-app user agent, converted to
    markdown, and its images are downloaded. If any step raises, the TikHub
    API is used instead when ``use_tikhub`` is true.

    Returns:
        On success, a dict with ``markdown``, ``media``, ``title``, ``author``,
        ``caption`` and ``date``. On failure, the TikHub result, or
        ``{"error": ...}`` when ``use_tikhub`` is false.
    """
    user_agent = "Mozilla/5.0 (iPhone; CPU iPhone OS 18_7 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) Mobile/15E148 MicroMessenger/8.0.69(0x18004539) NetType/4G Language/zh_CN"
    resp = await hx_req(url, headers={"User-Agent": user_agent}, mobile=True, proxy=PROXY.WECHAT, rformat="content")
    try:
        html = resp["content"].decode("utf-8")
        soup = BeautifulSoup(html, "html.parser")

        # Title: og:title meta tag, with a generic fallback.
        og_title = soup.find("meta", property="og:title")
        if og_title and og_title.get("content"):
            title = str(og_title["content"])
        else:
            title = "微信公众号文章"

        # Publish time: taken from the inline script; seconds are padded when absent.
        found_date = re.search(r"createTime = '(.*)'", html)
        date = found_date.group(1) if found_date else ""
        if len(date) == 16:  # '2026-06-02 09:02'
            date += ":00"
        if not date:
            date = nowstr(TZ)

        # Author: inline script first, then the author meta tag, then a generic fallback.
        found_author = re.search(r"nick_name: '(.*)'", html)
        author = found_author.group(1) if found_author else ""
        if not author:
            meta_author = soup.find("meta", attrs={"name": "author"})
            if meta_author and meta_author.get("content"):
                author = str(meta_author["content"])
        if not author:
            author = "微信公众号"

        # Remove hidden nodes so they do not end up in the markdown.
        hidden_attrs = ({"style": "display:none"}, {"style": "display: none;"}, {"aria-hidden": "true"})
        for attrs in hidden_attrs:
            for node in soup.find_all(attrs=attrs):
                node.decompose()

        markdown = convert2md(html=str(soup))
        markdown_no_img, image_urls = remove_img_tag(markdown)
        header = f"🟢[{author}]({url})\n🕒{date}\n📝**[{title}]({url})**\n\n"
        caption = header + remove_consecutive_newlines(markdown_no_img)

        # Queue image downloads; only absolute http(s) links are kept.
        media = []
        for img in image_urls:
            if not img.startswith("http"):
                continue
            is_jpeg = "mmbiz_jpg" in img or "wx_fmt=jpeg" in img
            path = Path(DOWNLOAD_DIR) / f"{rand_string(16)}{'.jpg' if is_jpeg else '.png'}"
            media.append({"photo": download_file(img, path=path, proxy=PROXY.WECHAT, **kwargs)})
        await modify_progress(text=f"✅解析成功...\n⏬正在下载:\n{summay_media(media)}", force_update=True, **kwargs)
        media = await download_media(media, **kwargs)
    except Exception as e:
        logger.warning(f"⚠️直接解析微信文章失败: {e}")
        if not use_tikhub:
            return {"error": str(e)}
        return await get_via_tikhub(url, **kwargs)
    return {"markdown": markdown, "media": media, "title": title, "author": author, "caption": caption, "date": date}
+
+
+async def get_via_tikhub(url: str, **kwargs) -> dict:
+ """Get WeChat post info via TikHub."""
api_url = API.TIKHUB_WECHAT + quote_plus(url)
logger.info(f"Preview WeChat TikHub for {api_url}")
headers = {"authorization": f"Bearer {TOKEN.TIKHUB}", "accept": "application/json"}
@@ -86,32 +147,46 @@ async def get_wechat_info(url: str, **kwargs) -> dict:
with contextlib.suppress(Exception):
dt = data["datetime"] # 2025-04-28T06:12:35.833830
dt = dt[:19].replace("T", " ") # 2025-04-28 06:12:35
- header = f"🟢[{author}]({url})\n🕒{dt}\n**📝{title}**"
media = []
- htmls = ""
- texts = ""
- markdowns = ""
+ markdown = f"🟢[{author}]({url})\n🕒{dt}\n**📝{title}**"
for tag in data["content"]["raw_content"]:
- html = ""
if text := tag.get("text", ""):
- html = f"<h3>{text}</h3>" if tag.get("type", "") == "section" else f"<p>{text}</p>"
- markdown = f"\n\n**{text}**" if tag.get("type", "") == "section" else f"\n{text}"
- text = f"\n\n{text}" if tag.get("type", "") == "section" else f"\n{text}"
- htmls += f"<br>{html}"
- markdowns += f"\n{markdown}"
- texts += f"\n{text}"
+ markdown += f"\n\n**{text}**" if tag.get("type", "") == "section" else f"\n{text}"
if images := tag.get("images", []):
for img in images:
src = img.get("src", "")
ext = img.get("type", "png")
media.append({"photo": download_file(src, path=f"{DOWNLOAD_DIR}/{rand_string()}.{ext}", proxy=PROXY.WECHAT, **kwargs)})
- htmls += f"<br><img src='{PROXY.IMG}{src}' alt='微信图片'/>"
await modify_progress(text=f"✅解析成功...\n⏬正在下载:\n{summay_media(media)}", force_update=True, **kwargs)
media = await download_media(media, **kwargs)
- txt_path = Path(DOWNLOAD_DIR) / f"{title}.txt"
- with txt_path.open("w") as f:
- f.write(f"📝{title}\n👤{author}\n🕒{dt}\n🔗{url}\n\n" + texts.strip())
except Exception as e:
logger.error(e)
return {"error": str(e)}
- return {"html": htmls, "path": txt_path.as_posix(), "markdown": markdowns, "media": media, "title": title, "author": author, "header": header}
+ return {"markdown": markdown, "caption": markdown, "media": media, "title": title, "author": author, "date": dt}
+
+
async def summarize_wechat(message: Message, wechat: dict, model: str, url: str) -> Message:
    """Summarize a parsed WeChat article and attach the summary link to the sent message.

    Args:
        message: The already-sent message that should receive the summary link.
        wechat: Parsed article; ``title``, ``author``, ``date`` and ``markdown``
            are required, ``media`` (a list of dicts with a ``photo`` path) is optional.
        model: Model alias forwarded to ``summarize``.
        url: The article URL.

    Returns:
        The result of ``add_summary_url`` when a summary page was produced and
        that call returned something truthy, otherwise ``message`` unchanged.
    """
    payload = {
        "platform": "微信公众号",
        "title": wechat["title"],
        "author_name": wechat["author"],
        "created_at": wechat["date"],
        "url": url,
        "content": wechat["markdown"],
    }

    # Images first, the JSON-encoded article text last.
    sources = []
    for item in wechat.get("media", []):
        sources.append({"type": "image", "path": item["photo"]})
    sources.append({"type": "text", "text": json.dumps(payload, ensure_ascii=False)})

    result = await summarize(
        sources=sources,
        model=model,
        title=payload["title"],
        author=payload["author_name"],
        url=url,
        date=payload["created_at"] or nowstr(TZ),
    )
    if page_url := result.get("telegraph_url"):
        return await add_summary_url(page_url, message) or message
    return message
src/preview/weibo.py
@@ -4,17 +4,18 @@ import contextlib
import json
import re
from datetime import datetime
+from pathlib import Path
from urllib.parse import quote_plus
from zoneinfo import ZoneInfo
from bs4 import BeautifulSoup
-from glom import glom
+from glom import Coalesce, glom
from loguru import logger
from pyrogram.client import Client
from pyrogram.types import Message
from bridge.social import send_to_social_media_bridge
-from config import API, DOWNLOAD_DIR, PROVIDER, PROXY, TELEGRAM_UA, TOKEN, TZ, cache
+from config import AI, API, DOWNLOAD_DIR, PROVIDER, PROXY, TELEGRAM_UA, TOKEN, TZ, cache
from cookies import get_weibo_cookies
from database.r2 import get_cf_r2
from messages.database import copy_messages_from_db, save_messages
@@ -23,7 +24,9 @@ from messages.sender import send2tg
from messages.utils import blockquote, summay_media
from networking import download_file, download_first_success_urls, download_media, hx_req
from others.emoji import emojify
-from utils import rand_string, readable_count, soup_to_text, split_parts, true
+from preview.utils import add_summary_url
+from summarize.summarize import summarize
+from utils import nowstr, rand_string, readable_count, soup_to_text, split_parts, true
async def preview_weibo(
@@ -35,6 +38,8 @@ async def preview_weibo(
*,
weibo_provider: str = PROVIDER.WEIBO,
weibo_comments: bool = True,
+ summary_weibo: bool = False,
+ summary_weibo_model: str = AI.AI_SUMMARY_MODEL_ALIAS,
show_author: bool = True,
show_pubdate: bool = True,
show_ip: bool = True,
@@ -133,9 +138,23 @@ async def preview_weibo(
comments = ""
if true(weibo_comments):
comments = await parse_weibo_comments(post_id)
- sent_messages = await send2tg(client, message, texts=emojify(msg.strip()) + comments, media=media, **kwargs)
+ sent_messages = await send2tg(client, message, texts=emojify(msg.strip()) + comments, media=media, keep_file=True, **kwargs)
await modify_progress(del_status=True, **kwargs)
+ # Summary weibo
+ # find the first message that has a caption
+ caption_msg = None
+ index = -1
+ for idx, m in enumerate(sent_messages):
+ if isinstance(m, Message) and (m.caption or m.text):
+ caption_msg = m
+ index = idx
+ break
+ if summary_weibo and caption_msg:
+ edited_msg = await summarize_weibo(caption_msg, this_info, quote_info, media, summary_weibo_model, url)
+ sent_messages[index] = edited_msg
await save_messages(messages=sent_messages, key=db_key)
+ # Clean up
+ [Path(glom(x, Coalesce("photo", "video", "audio"))).unlink(missing_ok=True) for x in media]
@cache.memoize(ttl=30)
@@ -342,3 +361,46 @@ def real_weibo_post_id(post_id: str) -> str:
value = base62_to_b10(post_id[: length - group * 4])
mid = str(value) + mid
return mid
+
+
async def summarize_weibo(message: Message, this_info: dict, quote_info: dict, media_list: list[dict], model: str, url: str) -> Message:
    """Summarize a Weibo post (and its quoted post) and attach the summary link.

    Args:
        message: The already-sent message that should receive the summary link.
        this_info: Parsed fields of the post itself.
        quote_info: Parsed fields of the quoted post; stored under ``quote_post``.
        media_list: Downloaded media dicts carrying a ``photo`` and/or ``video`` path.
        model: Model alias forwarded to ``summarize``.
        url: The post URL.

    Returns:
        The result of ``add_summary_url`` when a summary page was produced and
        that call returned something truthy, otherwise ``message`` unchanged.
    """
    blanks = ["", None, {}]

    def prune(node: dict) -> dict:
        # Recursively drop empty-string / None / empty-dict entries.
        if isinstance(node, dict):
            return {key: prune(value) for key, value in node.items() if value not in blanks}
        if isinstance(node, list):
            return [prune(entry) for entry in node if entry not in blanks]  # ty:ignore[invalid-return-type]
        return node

    data = prune({"platform": "微博"} | this_info | {"quote_post": quote_info})

    sources = []
    has_video = False
    for item in media_list:
        if item.get("photo"):
            sources.append({"type": "image", "path": item["photo"]})
        if item.get("video"):
            has_video = True
            sources.append({"type": "video", "path": item["video"]})
    sources.append({"type": "text", "text": json.dumps(data, ensure_ascii=False)})

    # Text-only posts must be long enough to be worth summarizing (skip short tweets);
    # posts with a video are always considered, subject to a minimum duration.
    min_text_length = None if has_video else 1000
    min_video_duration = 120 if has_video else None

    author_name = glom(data, Coalesce("author", "quote_post.author"), default="Anonymous")
    title = glom(data, Coalesce("texts", "quote_post.texts"), default="微博")
    created_at = glom(data, Coalesce("dt", "quote_post.dt"), default=None)
    result = await summarize(
        sources=sources,
        model=model,
        title=f"🧣{title}",
        author=author_name,
        url=url,
        date=created_at or nowstr(TZ),
        min_text_length=min_text_length,
        min_video_duration=min_video_duration,
        max_video_duration=3600,  # skip long videos more than 1 hour
    )
    if page_url := result.get("telegraph_url"):
        return await add_summary_url(page_url, message) or message
    return message
src/preview/xiaohongshu.py
@@ -1,6 +1,8 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
+import json
from datetime import datetime
+from pathlib import Path
from zoneinfo import ZoneInfo
import yaml
@@ -10,8 +12,9 @@ from loguru import logger
from pyrogram.client import Client
from pyrogram.types import Message
+from ai.utils import trim_none
from bridge.social import send_to_social_media_bridge
-from config import PROVIDER, PROXY, TZ
+from config import AI, PROVIDER, PROXY, TZ
from database.r2 import get_cf_r2
from messages.database import copy_messages_from_db, save_messages
from messages.progress import modify_progress
@@ -19,7 +22,9 @@ from messages.sender import send2tg
from messages.utils import summay_media
from networking import download_file, download_first_success_urls, download_media, hx_req
from others.emoji import emojify
-from utils import true
+from preview.utils import add_summary_url
+from summarize.summarize import summarize
+from utils import nowstr, true
async def preview_xhs(
@@ -31,6 +36,8 @@ async def preview_xhs(
*,
is_xhs_link: bool = False,
xhs_provider: str = PROVIDER.XHS,
+ summary_xhs: bool = False,
+ summary_xhs_model: str = AI.AI_SUMMARY_MODEL_ALIAS,
show_author: bool = True,
show_title: bool = True,
show_pubdate: bool = True,
@@ -78,6 +85,7 @@ async def preview_xhs(
await modify_progress(text="❌小红书解析失败, 请稍候再尝试", force_update=True, **kwargs)
return
await modify_progress(text="✅解析成功, 正在处理...", **kwargs)
+ note["url"] = url
media: list[dict] = []
if note.get("type") == "video":
video_urls = [] # Extract all urls, but prefer H264
@@ -107,16 +115,16 @@ async def preview_xhs(
texts = ""
if true(show_author) and (author := glom(note, Coalesce("user.nickname", "user.nickName"), default="")):
- texts += f"🍠[{author}]({url})\n"
+ texts += f"🍠**[{author}]({url})**\n"
if true(show_pubdate) and note.get("time"):
dt = datetime.fromtimestamp(float(note["time"]) / 1000).astimezone(ZoneInfo(TZ))
- texts += f"🕒{dt:%Y-%m-%d %H:%M:%S}"
+ texts += f"🕒{dt:%Y-%m-%d %H:%M:%S}\n"
+ if true(show_statistics) and xhs_info.get("statistics"):
+ texts += f"{xhs_info['statistics']}"
if true(show_ip) and note.get("ipLocation"):
texts += f"📍{note['ipLocation']}\n"
else:
texts += "\n"
- if true(show_statistics) and xhs_info.get("statistics"):
- texts += f"{xhs_info['statistics']}\n"
if true(show_title) and note.get("title"):
texts += f"📝**{note['title']}**\n"
desc = note.get("desc", "").replace("[话题]#", "")
@@ -125,9 +133,23 @@ async def preview_xhs(
comments = get_xhs_comments(xhs_info.get("soup")) # Not implemented yet
await modify_progress(text=f"⏬正在下载:\n{summay_media(media)}", force_update=True, **kwargs)
media = await download_media(media, **kwargs)
- sent_messages = await send2tg(client, message, texts=emojify(texts), media=media, comments=comments, **kwargs)
+ sent_messages = await send2tg(client, message, texts=emojify(texts), media=media, comments=comments, keep_file=True, **kwargs)
await modify_progress(del_status=True, **kwargs)
+ # Summary xhs
+ # find the first message that has a caption
+ caption_msg = None
+ index = -1
+ for idx, m in enumerate(sent_messages):
+ if isinstance(m, Message) and (m.caption or m.text):
+ caption_msg = m
+ index = idx
+ break
+ if summary_xhs and caption_msg:
+ edited_msg = await summarize_xhs(caption_msg, note, media, summary_xhs_model)
+ sent_messages[index] = edited_msg
await save_messages(messages=sent_messages, key=db_key)
+ # Clean up
+ [Path(glom(x, Coalesce("photo", "video", "audio"))).unlink(missing_ok=True) for x in media]
async def get_xhs_info(url: str, retry: int = 0, *, use_mobile: bool = False) -> dict:
@@ -184,3 +206,45 @@ def get_xhs_comments(soup: BeautifulSoup | None) -> list[str]:
if not soup:
return []
return []
+
+
async def summarize_xhs(message: Message, note: dict, media_list: list[dict], model: str) -> Message:
    """Summarize a Xiaohongshu note and attach the summary link to the sent message.

    Args:
        message: The already-sent message that should receive the summary link.
        note: Parsed note data; ``url`` is required. ``title``, ``user``,
            ``ipLocation``, ``desc``, ``time`` and ``noteId`` are optional.
        media_list: Downloaded media dicts carrying a ``photo`` and/or ``video`` path.
        model: Model alias forwarded to ``summarize``.

    Returns:
        The result of ``add_summary_url`` when a summary page was produced and
        that call returned something truthy, otherwise ``message`` unchanged.
    """
    data = {
        "platform": "小红书",
        "title": note.get("title"),
        "author_name": glom(note, Coalesce("user.nickname", "user.nickName"), default=None),
        "url": note["url"],
        "location": note.get("ipLocation"),
    }
    if desc := (note.get("desc") or "").replace("[话题]#", ""):
        data["description"] = desc
    # Publish time: ``note["time"]`` is handled as a millisecond timestamp, the same
    # way the preview caption formats it. Without this, "created_at" was never set
    # and the summary was always dated "now".
    if note.get("time"):
        try:
            dt = datetime.fromtimestamp(float(note["time"]) / 1000).astimezone(ZoneInfo(TZ))
            data["created_at"] = f"{dt:%Y-%m-%d %H:%M:%S}"
        except (TypeError, ValueError, OverflowError, OSError) as e:
            logger.warning(f"Invalid xhs note time {note.get('time')!r}: {e}")
    data = trim_none(data)
    sources = []
    min_text_length = 1000  # skip short tweets
    min_video_duration = None
    for media in media_list:
        if media.get("photo"):
            sources.append({"type": "image", "path": media["photo"]})
        if media.get("video"):
            min_text_length = None  # always summarize video
            min_video_duration = 120
            sources.append({"type": "video", "path": media["video"]})
    sources.append({"type": "text", "text": json.dumps(data, ensure_ascii=False)})
    # The author is stored under "author_name"; reading "author" always fell back to "Anonymous".
    author_name = data.get("author_name") or "Anonymous"
    pid = note.get("noteId", "小红书")
    summary = await summarize(
        sources=sources,
        model=model,
        title=f"🍠{author_name} - {pid}",
        author=author_name,
        url=data["url"],
        date=data.get("created_at") or nowstr(TZ),
        min_text_length=min_text_length,
        min_video_duration=min_video_duration,
        max_video_duration=3600,  # skip long videos more than 1 hour
    )
    telegraph_url = summary.get("telegraph_url")
    if not telegraph_url:
        return message
    return await add_summary_url(telegraph_url, message) or message
src/subtitles/subtitle.py
@@ -7,7 +7,6 @@ from loguru import logger
from pyrogram.client import Client
from pyrogram.types import InputMediaDocument, LinkPreviewOptions, Message
-from ai.summary import summarize
from asr.voice_recognition import asr_file
from config import AI, ASR, DOWNLOAD_DIR, PREFIX, READING_SPEED, TEXT_LENGTH, cache
from messages.parser import parse_msg
@@ -18,6 +17,7 @@ from networking import match_social_media_link
from preview.bilibili import get_bilibili_vinfo
from preview.youtube import get_youtube_vinfo
from subtitles.base import fetch_subtitle, match_url
+from summarize.summarize import summarize
from utils import count_subtitles, readable_time
from ytdlp.download import ytdlp_download
@@ -40,7 +40,7 @@ async def get_subtitle(
message: Message,
*,
ai_summary: bool = True,
- summary_model_id: str = AI.SUBTITLE_SUMMARY_MODEL_ALIAS,
+ summary_subtitle_model: str = AI.SUBTITLE_SUMMARY_MODEL_ALIAS,
enable_corrector: bool = False,
**kwargs,
):
@@ -70,7 +70,7 @@ async def get_subtitle(
description = glom(vinfo, Coalesce("description", "desc"), default="")
caption = f"{vinfo['emoji']}[{vinfo['author']}]({vinfo['channel']})\n🕒{vinfo['pubdate']}\n📝[{vinfo['title']}]({url})"
msg = f"🔍**正在获取字幕:**\n{caption}"[:TEXT_LENGTH]
- status_msg: Message = (await send2tg(client, message, texts=msg, **kwargs))[0] # ty:ignore[invalid-assignment]
+ status_msg: Message = (await send2tg(client, message, texts=msg, **kwargs))[0]
kwargs["progress"] = status_msg
this_info = parse_msg(message, silent=True)
@@ -120,7 +120,7 @@ async def get_subtitle(
full = glom(res, Coalesce("full", "subtitles", "summary"), default="")
# Send subtitle txt
with BytesIO(full.encode("utf-8")) as f:
- status_msg = await status_msg.edit_media(file_name=f"{vinfo['title']}.txt", media=InputMediaDocument(f, caption=caption))
+ status_msg = await status_msg.edit_media(media=InputMediaDocument(f, caption=caption))
if ai_summary and isinstance(status_msg, Message):
# use real subtitle (without AI summary by Bilibili)
@@ -129,20 +129,21 @@ async def get_subtitle(
if description.strip():
prompt += f"节目简介: {description}"
summary = await summarize(
- transcripts=subtitles,
- reference=prompt,
- model=summary_model_id,
+ sources=[{"type": "system_prompt", "text": prompt}, {"type": "transcripts", "text": subtitles}],
+ model=summary_subtitle_model,
title=vinfo["title"],
author=vinfo["author"],
url=url,
date=vinfo["pubdate"],
description=description,
+ min_text_length=200,
+ force_r2_page=kwargs.get("force_r2_page", False),
)
if not summary.get("texts"):
return
telegraph_url = summary.get("telegraph_url") or ""
- link_preview = LinkPreviewOptions(is_disabled=False, url=telegraph_url) if telegraph_url else LinkPreviewOptions(is_disabled=True)
+ link_preview = LinkPreviewOptions(is_disabled=False, show_above_text=True, url=telegraph_url) if telegraph_url else LinkPreviewOptions(is_disabled=True)
if await count_without_entities(summary["texts"]) <= TEXT_LENGTH:
await status_msg.reply_text(blockquote(summary["texts"]), quote=True, link_preview_options=link_preview)
elif telegraph_url:
src/summarize/main.py
@@ -0,0 +1,215 @@
+#!/venv/bin/python
+# -*- coding: utf-8 -*-
+import asyncio
+import json
+import warnings
+from pathlib import Path
+from typing import Any
+
+from bs4 import BeautifulSoup, MarkupResemblesLocatorWarning
+from glom import Coalesce, glom
+from loguru import logger
+from pyrogram.client import Client
+from pyrogram.types import Message
+
+from ai.texts.contexts import MARKDOWN_EXT, TXT_EXT, full_chain_contexts, is_multi_user_chat, message_bytes
+from config import AI, ASR, PREFIX, PROXY, TZ
+from database.r2 import get_cf_r2
+from messages.database import copy_messages_from_db
+from messages.help import social_media_help
+from messages.parser import parse_msg
+from messages.sender import send2tg, send_blockquote_texts
+from messages.utils import delete_message, equal_prefix, set_reaction, startswith_prefix
+from networking import match_social_media_link
+from others.download_external import AUDIO_FORMAT, VIDEO_FORMAT
+from preview.arxiv import preview_arxiv
+from preview.bilibili import make_bvid_clickable, preview_bilibili
+from preview.douyin import preview_douyin
+from preview.instagram import preview_instagram
+from preview.twitter import preview_twitter
+from preview.v2ex import preview_v2ex
+from preview.wechat import preview_wechat
+from preview.weibo import preview_weibo
+from preview.xiaohongshu import preview_xhs
+from summarize.summarize import summarize
+from utils import convert2md, nowstr, read_text, soup_to_text, ts_to_dt
+from ytdlp.download import ytdlp_download
+from ytdlp.utils import ProxyError, get_subtitles
+
+
+# ruff: noqa: RET502,RET503
async def ai_summary(client: Client, message: Message, summary_model_id: str = AI.AI_SUMMARY_MODEL_ALIAS, *, mermaid: bool = False, **kwargs) -> Any:
    """Handle the AI-summary command for a message or for the message it replies to.

    Known social-media links are delegated to the matching ``preview_*`` handler
    with every ``summary_*`` switch turned on. Everything else (files, plain
    text, yt-dlp supported videos) is collected into sources and summarized here.

    Args:
        client: The Pyrogram client.
        message: The incoming command message.
        summary_model_id: Model alias forwarded to ``summarize``.
        mermaid: Forwarded to ``summarize``.
        **kwargs: Extra options forwarded to the preview handlers and senders.
    """
    if not startswith_prefix(message.content, prefix=PREFIX.AI_SUMMARY):
        return
    this_msg = message
    if equal_prefix(message.content, PREFIX.AI_SUMMARY):
        # Bare command: summarize the replied-to message, or show help when there is none.
        if not message.reply_to_message:
            return await send2tg(client, message, texts=social_media_help(message), **kwargs)
        message = message.reply_to_message

    chains = await full_chain_contexts(client, message, order="asc")  # old to new
    file_bytes = sum(message_bytes(m) for m in chains)
    max_context_bytes = 512 * 1024 * 1024
    if file_bytes > max_context_bytes:
        logger.warning(f"file_bytes: {file_bytes} > 512MB, skip")
        await this_msg.reply_text("❌上下文大小超过512MB,不支持总结")
        await asyncio.sleep(5)
        await delete_message(message)
        return
    await set_reaction(client, this_msg, "👌")
    matched = await match_social_media_link(str(message.content))
    summary_switches = dict.fromkeys(
        [
            "summary_twitter",
            "summary_douyin",
            "summary_xhs",
            "summary_weibo",
            "summary_wechat",
            "summary_instagram",
            "summary_v2ex",
            "summary_ytdlp",
        ],
        True,
    )
    kwargs |= summary_switches | {"enable_corrector": False} | matched

    platform = matched["platform"]
    previewers = {
        "arxiv": preview_arxiv,
        "douyin": preview_douyin,
        "tiktok": preview_douyin,
        "instagram": preview_instagram,
        "x": preview_twitter,
        "twitter": preview_twitter,
        "fxtwitter": preview_twitter,
        "fixupx": preview_twitter,
        "weibo": preview_weibo,
        "xiaohongshu": preview_xhs,
        "wechat": preview_wechat,
        "v2ex": preview_v2ex,
    }
    previewer = previewers.get(platform)
    if previewer is None and platform.startswith("bilibili-"):
        # this is not bilibili video, for videos, use yt-dlp
        previewer = preview_bilibili
    if previewer is not None:
        return await previewer(client, message, **kwargs)

    sources = await get_sources(client, chains)
    info = {}
    if platform in ["bilibili", "youtube", "ytdlp"]:
        r2 = await get_cf_r2(matched["db_key"])
        cached_text = "".join(glom(r2, "data.*.text", default=[]))
        # Reuse a stored result that already contains an AI summary.
        if "🤖AI导读" in cached_text and await copy_messages_from_db(client, message, key=matched["db_key"], kv=r2, **kwargs):
            await set_reaction(client, this_msg, "🎉")
            return
        if info := await download_ytdlp(**kwargs):
            sources.extend(info.get("sources", []))

    logger.debug(f"Summary sources: {sources}")
    summary = await summarize(
        sources=sources,
        model=summary_model_id,
        title=info.get("title") or "AI导读",
        author=info.get("author") or "Anonymous",
        url=matched.get("url"),
        date=info.get("created_at") or nowstr(TZ),
        description=info.get("description"),
        force_r2_page=bool(kwargs.get("force_r2_page")),
        mermaid=mermaid,
    )
    if not summary.get("texts"):
        await set_reaction(client, this_msg, "💔")
        return
    await send_blockquote_texts(client, message, texts=summary["texts"], **kwargs)
    await set_reaction(client, this_msg, "🎉")
    return
+
+
async def get_sources(client: Client, chains: list[Message]) -> list[dict]:
    """Convert a chain of messages into the source dicts consumed by ``summarize``.

    Media and documents are downloaded and classified by MIME type or file
    suffix. Message text becomes a ``text`` source, and a YouTube link found in
    that text adds an extra ``youtube`` source.

    Args:
        client: The Pyrogram client used to download media.
        chains: The messages to convert, in the order they should appear.

    Returns:
        A list of source dicts (``image`` / ``video`` / ``audio`` / ``file`` /
        ``text`` / ``youtube``).
    """
    sources = []
    add_sender = is_multi_user_chat(chains)
    for msg in chains:
        info = parse_msg(msg, silent=True, use_cache=False)
        # Sender names are only attached when ``is_multi_user_chat`` says so.
        meta: dict = {"message_sender": info["full_name"]} if add_sender else {}

        if msg.audio or msg.photo or msg.video or msg.document:
            fpath = await client.download_media(msg)
            # ``download_media`` returns None when the download fails; guard before
            # building a Path so a failed download does not raise TypeError.
            # NOTE: as before, the text/caption of such a message is skipped as well.
            if not isinstance(fpath, str) or not Path(fpath).is_file():
                logger.warning(f"Failed to download media of message {msg.id}, skip")
                continue
            if msg.photo:
                sources.append({"type": "image", "path": fpath})
            elif msg.video:
                sources.append({"type": "video", "path": fpath, "mime_type": msg.video.mime_type})
            elif msg.audio:
                sources.append({"type": "audio", "path": fpath, "mime_type": msg.audio.mime_type})
            elif msg.document:
                mime = glom(msg, "document.mime_type", default="") or ""
                fname = glom(msg, "document.file_name", default="") or ""
                suffix = Path(fname).suffix
                if mime.startswith("image/"):
                    sources.append({"type": "image", "path": fpath, "mime_type": mime})
                elif mime.startswith("audio/") or suffix in AUDIO_FORMAT:
                    sources.append({"type": "audio", "path": fpath, "mime_type": mime})
                elif mime.startswith("video/") or suffix in VIDEO_FORMAT:
                    sources.append({"type": "video", "path": fpath, "mime_type": mime})
                elif mime.startswith("text/") or suffix in TXT_EXT:
                    txt = {"file_name": fname, "file_content": read_text(fpath)}
                    sources.append({"type": "text", "text": json.dumps(meta | txt, ensure_ascii=False)})
                elif suffix in MARKDOWN_EXT:
                    txt = {"file_name": fname, "file_content": convert2md(path=fpath)}
                    sources.append({"type": "text", "text": json.dumps(meta | txt, ensure_ascii=False)})
                else:
                    sources.append({"type": "file", "path": fpath, "mime_type": mime})
        if txt := glom(msg, Coalesce("content.html", "content", "text", "caption"), default=""):
            texts = json.dumps(meta | {"message": txt}, ensure_ascii=False) if add_sender else txt
            sources.append({"type": "text", "text": texts})
            matched = await match_social_media_link(txt)
            if matched["platform"] == "youtube":
                sources.append({"type": "youtube", "url": matched["url"]})
    return sources
+
+
async def download_ytdlp(url: str, **kwargs) -> dict:
    """Download a video with yt-dlp and build its summary info.

    On ``ProxyError`` the download is retried once through
    ``PROXY.YTDLP_FALLBACK`` (when configured).

    Args:
        url: The video URL.
        **kwargs: Options forwarded to ``ytdlp_download``; ``platform`` is required.

    Returns:
        The dict built by ``ytdlp_info``, or an empty dict when no video file
        was produced or the proxy failed.
    """
    kwargs |= {"ytdlp_download_video": True, "show_progress": False}
    try:
        resp = await ytdlp_download(url, **kwargs)
        if resp["video_path"].is_file():
            return await ytdlp_info(resp, url, kwargs["platform"])
    except ProxyError:
        logger.error(f"🚫{kwargs['platform']}代理错误")
        fallback = PROXY.YTDLP_FALLBACK
        # Retry only if the fallback proxy is not already in use; otherwise a
        # failing fallback proxy would make this recurse forever.
        if fallback and kwargs.get("proxy") != fallback:
            logger.warning(f"🔄使用备用代理{fallback}")
            kwargs |= {"proxy": fallback}
            return await download_ytdlp(url, **kwargs)
    return {}
+
+
async def ytdlp_info(info: dict, url: str, platform: str) -> dict:
    """Build summary metadata and sources from a yt-dlp download result.

    Args:
        info: Result of ``ytdlp_download``; must contain ``video_path`` and
            ``audio_path`` (path objects). ``author``, ``title``, ``pubdate``,
            ``timestamp``, ``upload_date`` and ``description`` are optional.
        url: The video URL.
        platform: Platform name; its title-cased form is the fallback title.

    Returns:
        A dict with ``platform``, ``author``, ``title``, ``url``, ``created_at``,
        optionally ``description``, and ``sources`` (the list for ``summarize``).
    """
    data = {
        "platform": platform.title(),
        "author": info.get("author") or "Anonymous",
        "title": info.get("title") or platform.title(),
        "url": url,
    }
    sources = []
    video = info["video_path"]
    audio = info["audio_path"]
    # The audio track is passed to get_subtitles when present, else the video file.
    asr_path = audio if audio.is_file() else video
    if video.is_file():
        sources.append({"type": "video", "path": video.as_posix()})
    elif audio.is_file():
        sources.append({"type": "audio", "path": audio.as_posix()})

    if subtitles := await get_subtitles(asr_path, url, asr_engine=ASR.DEFAULT_ENGINE, vinfo=info, enable_corrector=False):
        sources.append({"type": "transcripts", "text": subtitles})

    # date
    if info.get("pubdate"):
        data["created_at"] = info["pubdate"].removeprefix("🕒")
    elif dt := ts_to_dt(info.get("timestamp")):
        data["created_at"] = f"{dt:%Y-%m-%d %H:%M:%S}"
    elif info.get("upload_date"):
        # Fixed typo: this used to read info["update_date"] and raised KeyError.
        data["created_at"] = info["upload_date"]
    else:
        data["created_at"] = nowstr(TZ)

    # desc
    if (desc := info.get("description")) and (desc != "-"):
        warnings.simplefilter("ignore", MarkupResemblesLocatorWarning)
        soup = BeautifulSoup(desc, "html.parser")
        desc_text = soup_to_text(soup)
        data["description"] = make_bvid_clickable(desc_text)
    sources.append({"type": "text", "text": json.dumps(data, ensure_ascii=False)})
    data["sources"] = sources
    return data
src/summarize/summarize.py
@@ -0,0 +1,310 @@
+#!/venv/bin/python
+# -*- coding: utf-8 -*-
+from datetime import datetime
+from pathlib import Path
+
+from loguru import logger
+from pyrogram.types import Chat, Message
+from pyrogram.types.messages_and_media.message import Str
+
+from ai.main import ai_text_generation
+from asr.utils import audio_duration as get_media_duration
+from config import PREFIX
+from publish import telegraph_aipage
+from schema import AIPage, ContentExtraction
+from summarize.utils import generate_mermaid, parse_summary_sources, publish_mermaid
+from utils import count_subtitles, digest, read_text, to_dt
+
+
async def summarize(
    sources: list[dict] | None = None,
    model: str = "gemini",
    title: str | None = None,
    author: str | None = None,
    url: str | None = None,
    date: str | datetime | None = None,
    description: str | dict | None = None,
    ttl: str | None = None,
    *,
    force_r2_page: bool = False,
    mermaid: bool = False,
    min_text_length: int | None = None,  # minimum text length to summarize
    min_audio_duration: float | None = None,  # minimum audio duration to summarize
    min_video_duration: float | None = None,  # minimum video duration to summarize
    max_audio_duration: float | None = None,  # maximum audio duration to summarize
    max_video_duration: float | None = None,  # maximum video duration to summarize
    max_audio_bytes: int | None = None,  # maximum audio bytes to summarize
    max_video_bytes: int | None = None,  # maximum video bytes to summarize
    min_num_image: int | None = None,  # minimum number of images to summarize
    max_num_image: int | None = None,  # maximum number of images to summarize
    min_num_video: int | None = None,  # minimum number of videos to summarize
    max_num_video: int | None = None,  # maximum number of videos to summarize
    min_num_audio: int | None = None,  # minimum number of audios to summarize
    max_num_audio: int | None = None,  # maximum number of audios to summarize
    skip_max_video_duration: float | None = None,  # skip max video duration if it is greater than this value
    skip_max_audio_duration: float | None = None,  # skip max audio duration if it is greater than this value
    skip_max_video_bytes: int | None = None,  # skip max video bytes if it is greater than this value
    skip_max_audio_bytes: int | None = None,  # skip max audio bytes if it is greater than this value
) -> dict:
    r"""Summarize the given sources with an AI model and publish the result as a page.

    The ``skip_max_*`` limits drop individual oversized media (see
    ``filter_sources``); the ``min_*`` / ``max_*`` limits decide whether the
    remaining sources are summarized at all (see ``is_eligible``).

    Args:
        sources (list[dict] | None): The sources to summarize.
            # text
            {"type": "system_prompt", "text": "This is system prompt."}
            {"type": "text", "text": "Hello."}
            {"type": "text", "path": "/path/to/file.txt"}
            {"type": "transcripts", "text": "[00:00] Hello\n[00:01] a sentence"}

            # media
            {"type": "image", "path": "/path/to/image.jpg", "mime_type": "image/jpeg"}
            {"type": "video", "path": "/path/to/video.mp4", "mime_type": "video/mp4"}
            {"type": "audio", "path": "/path/to/audio.mp3", "mime_type": "audio/mpeg"}

            # file
            {"type": "file", "path": "/path/to/file.pdf", "mime_type": "application/pdf"}

            # special
            {"type": "youtube", "url": "https://www.youtube.com/watch?v=videoid"}
        model: Model alias placed after ``@`` in the generated prompt message.
        title: Page title; defaults to "AI导读".
        author: Page author.
        url: Source URL stored on the page.
        date: Publish date, converted with ``to_dt``.
        description: Optional description stored on the page.
        ttl: Forwarded to ``telegraph_aipage``.
        force_r2_page: Forwarded to ``telegraph_aipage`` as ``force_r2``.
        mermaid: Forwarded to ``parse_summary_sources``.

    Returns:
        An empty dict when there is nothing to summarize, the sources are not
        eligible, or the model returned no text. Otherwise the dict returned by
        ``ai_text_generation`` with ``texts`` replaced by the formatted summary
        and ``telegraph_url`` set to the published page URL (or None).
    """
    sources = filter_sources(
        sources,
        skip_max_video_duration=skip_max_video_duration,
        skip_max_audio_duration=skip_max_audio_duration,
        skip_max_video_bytes=skip_max_video_bytes,
        skip_max_audio_bytes=skip_max_audio_bytes,
    )
    if not sources:
        return {}
    title = title or "AI导读"
    if not is_eligible(
        sources,
        min_text_length=min_text_length,
        min_audio_duration=min_audio_duration,
        min_video_duration=min_video_duration,
        max_audio_duration=max_audio_duration,
        max_video_duration=max_video_duration,
        max_audio_bytes=max_audio_bytes,
        max_video_bytes=max_video_bytes,
        min_num_image=min_num_image,
        max_num_image=max_num_image,
        min_num_video=min_num_video,
        max_num_video=max_num_video,
        min_num_audio=min_num_audio,
        max_num_audio=max_num_audio,
    ):
        return {}
    texts, transcripts, schema = parse_summary_sources(sources, mermaid=mermaid)
    # A synthetic message, keyed by a digest of the sources, carries the prompt to the text-generation entry point.
    checksum = int(digest(sources, to_int=True))
    ai_msg = Message(id=checksum, chat=Chat(id=checksum), text=Str(f"{PREFIX.AI_TEXT_GENERATION} @{model} {texts or 'Summarize'}"))
    summary = await ai_text_generation("fake", message=ai_msg, **schema)  # type: ignore
    if not summary.get("texts", ""):
        return {}
    raw_texts = summary["texts"]
    texts, mermaid_img_url, mermaid_pako_url = await parse_summary(raw_texts)

    # parse_summary tolerates invalid JSON by returning the raw text, so validate
    # defensively here too: an unparsable answer skips page publishing instead of
    # raising and losing the text the caller could still show.
    try:
        extraction = ContentExtraction.model_validate_json(raw_texts)
    except Exception as e:
        logger.error(f"Summary is not a valid ContentExtraction, skip publishing page: {e}")
        extraction = None

    telegraph_url = None
    if extraction is not None:
        page = AIPage(
            title=title,
            author=author,
            url=url,
            date=to_dt(date),
            description=description,
            summary=extraction,
            transcripts=transcripts,
            mermaid_img=mermaid_img_url,
            mermaid_url=mermaid_pako_url,
        )
        telegraph_url = await telegraph_aipage(page, ttl=ttl, force_r2=force_r2_page)

    if telegraph_url:
        summary["telegraph_url"] = telegraph_url
        summary["texts"] = f"**🤖[AI导读]({telegraph_url})**\n" + texts
    else:
        summary["telegraph_url"] = None
        summary["texts"] = "**🤖AI导读**\n" + texts
    return summary
+
+
async def parse_summary(texts: str) -> tuple[str, str, str]:
    """Render the model's JSON summary as readable text and publish its mindmap.

    Args:
        texts: The JSON string expected to validate as ``ContentExtraction``.

    Returns:
        ``(summary_texts, mermaid_img_url, mermaid_pako_url)``. When anything
        fails, the input is returned unchanged with two empty URLs.
    """
    img_url, pako_url = "", ""
    try:
        extraction = ContentExtraction.model_validate_json(texts)
        if extraction.mindmap:
            diagram = generate_mermaid(extraction.mindmap)
            img_url, pako_url = await publish_mermaid(diagram)

        pieces = [f"{extraction.overview}\n⚡️**章节速览**"]
        for sec in extraction.sections:
            pieces.append(f"\n{sec.emoji}**{sec.title}**")
            if sec.start:
                # Drop the leading "00:" from timestamps longer than five characters.
                stamp = sec.start
                if len(stamp) > 5:
                    stamp = stamp.removeprefix("00:")
                pieces.append(f" [{stamp}]")
            pieces.append(f"\n{sec.content}")
        parsed = "".join(pieces)
        logger.success(parsed)
    except Exception as e:
        logger.error(f"Error parsing summary: {e}")
        return texts, "", ""
    else:
        return parsed, img_url, pako_url
+
+
+def filter_sources(
+ sources: list[dict] | None,
+ skip_max_video_duration: float | None = None, # skip max video duration if it is greater than this value
+ skip_max_audio_duration: float | None = None, # skip max audio duration if it is greater than this value
+ skip_max_video_bytes: int | None = None, # skip max video bytes if it is greater than this value
+ skip_max_audio_bytes: int | None = None, # skip max audio bytes if it is greater than this value
+) -> list[dict]:
+ """Filter the sources by the given conditions.
+
+ Returns:
+ The filtered sources.
+ """
+ if not sources:
+ return []
+ filtered = []
+ for source in sources:
+ if skip_max_video_duration is not None and source["type"] == "video" and Path(source["path"]).is_file():
+ duration = get_media_duration(source["path"])
+ size = Path(source["path"]).stat().st_size
+ if duration > skip_max_video_duration:
+ logger.warning(f"Skip video {source['path']} due to duration {duration} > {skip_max_video_duration}")
+ continue
+ if isinstance(skip_max_video_bytes, int) and size > skip_max_video_bytes:
+ logger.warning(f"Skip video {source['path']} due to size {size} > {skip_max_video_bytes}")
+ continue
+ elif skip_max_audio_duration is not None and source["type"] == "audio" and Path(source["path"]).is_file():
+ duration = get_media_duration(source["path"])
+ size = Path(source["path"]).stat().st_size
+ if duration > skip_max_audio_duration:
+ logger.warning(f"Skip audio {source['path']} due to duration {duration} > {skip_max_audio_duration}")
+ continue
+ if isinstance(skip_max_audio_bytes, int) and size > skip_max_audio_bytes:
+ logger.warning(f"Skip audio {source['path']} due to size {size} > {skip_max_audio_bytes}")
+ continue
+ filtered.append(source)
+ return filtered
+
+
+def is_eligible(
+ sources: list[dict],
+ *,
+ min_text_length: int | None = None, # minimum text length to summarize
+ min_audio_duration: float | None = None, # minimum audio duration to summarize
+ min_video_duration: float | None = None, # minimum video duration to summarize
+ max_audio_duration: float | None = None, # maximum audio duration to summarize
+ max_video_duration: float | None = None, # maximum video duration to summarize
+ max_audio_bytes: int | None = None, # maximum audio bytes to summarize
+ max_video_bytes: int | None = None, # maximum video bytes to summarize
+ min_num_image: int | None = None, # minimum number of images to summarize
+ max_num_image: int | None = None, # maximum number of images to summarize
+ min_num_video: int | None = None, # minimum number of videos to summarize
+ max_num_video: int | None = None, # maximum number of videos to summarize
+ min_num_audio: int | None = None, # minimum number of audios to summarize
+ max_num_audio: int | None = None, # maximum number of audios to summarize
+) -> bool:
+ r"""Check if the source is eligible for summarization.
+
+ Args:
+ sources (list[dict] | None): The sources to summary.
+ # text
+ {"type": "system_prompt", "text": "This is system prompt."}
+ {"type": "text", "text": "Hello."}
+ {"type": "text", "path": "/path/to/file.txt"}
+ {"type": "transcripts", "text": "[00:00] Hello\n[00:01] a sentence"}
+
+ # media
+ {"type": "image", "path": "/path/to/image.jpg", "mime_type (optional)": "image/jpeg"}
+ {"type": "video", "path": "/path/to/video.mp4", "mime_type (optional)": "video/mp4", "duration (optional)": 10.0, "size (optional)": 9999}
+ {"type": "audio", "path": "/path/to/audio.mp3", "mime_type (optional)": "audio/mpeg", "duration (optional)": 10.0, "size (optional)": 9999}
+
+ # file
+ {"type": "file", "path": "/path/to/file.pdf", "mime_type (optional)": "application/pdf"}
+
+ """
+ text_length = 0
+ audio_duration = 0
+ video_duration = 0
+ audio_bytes = 0
+ video_bytes = 0
+ num_image = 0
+ num_video = 0
+ num_audio = 0
+ # check text length
+ if isinstance(min_text_length, int):
+ for source in sources:
+ if source["type"] in ["text", "transcripts"] and source.get("text"):
+ text_length += count_subtitles(source["text"])
+ elif source["type"] == "text" and source.get("path"):
+ text_length += len(read_text(source["path"]))
+ if text_length < int(min_text_length):
+ logger.warning(f"Text length is too short: {text_length} < {min_text_length}")
+ return False
+
+ # check duration
+ if any(x is not None for x in [min_audio_duration, max_audio_duration, min_video_duration, max_video_duration]):
+ for source in sources:
+ if source["type"] == "audio" and Path(source["path"]).is_file():
+ audio_duration += get_media_duration(source["path"])
+ elif source["type"] == "video" and Path(source["path"]).is_file():
+ video_duration += get_media_duration(source["path"])
+
+ if min_video_duration is not None and video_duration < min_video_duration:
+ logger.warning(f"Video duration is too short: {video_duration} < {min_video_duration}")
+ return False
+ if max_video_duration is not None and video_duration > max_video_duration:
+ logger.warning(f"Video duration is too long: {video_duration} > {max_video_duration}")
+ return False
+ if min_audio_duration is not None and audio_duration < min_audio_duration:
+ logger.warning(f"Audio duration is too short: {audio_duration} < {min_audio_duration}")
+ return False
+ if max_audio_duration is not None and audio_duration > max_audio_duration:
+ logger.warning(f"Audio duration is too long: {audio_duration} > {max_audio_duration}")
+ return False
+
+ # check size
+ if isinstance(max_audio_bytes, int) or isinstance(max_video_bytes, int):
+ for source in sources:
+ if source["type"] == "audio" and Path(source["path"]).is_file():
+ audio_bytes += Path(source["path"]).stat().st_size
+ elif source["type"] == "video" and Path(source["path"]).is_file():
+ video_bytes += Path(source["path"]).stat().st_size
+
+ if isinstance(max_audio_bytes, int) and audio_bytes > max_audio_bytes:
+ logger.warning(f"Audio bytes is too large: {audio_bytes} > {max_audio_bytes}")
+ return False
+ if isinstance(max_video_bytes, int) and video_bytes > max_video_bytes:
+ logger.warning(f"Video bytes is too large: {video_bytes} > {max_video_bytes}")
+ return False
+
+ # check number of images, videos, and audios
+ if any(x is not None for x in [min_num_image, max_num_image, min_num_video, max_num_video, min_num_audio, max_num_audio]):
+ for source in sources:
+ if source["type"] == "image" and Path(source["path"]).is_file():
+ num_image += 1
+ elif source["type"] == "video" and Path(source["path"]).is_file():
+ num_video += 1
+ elif source["type"] == "audio" and Path(source["path"]).is_file():
+ num_audio += 1
+ if isinstance(min_num_image, int) and num_image < min_num_image:
+ logger.warning(f"Too few images to summarize: {num_image} < {min_num_image}")
+ return False
+ if isinstance(max_num_image, int) and num_image > max_num_image:
+ logger.warning(f"Too many images to summarize: {num_image} > {max_num_image}")
+ return False
+ if isinstance(min_num_video, int) and num_video < min_num_video:
+ logger.warning(f"Too few videos to summarize: {num_video} < {min_num_video}")
+ return False
+ if isinstance(max_num_video, int) and num_video > max_num_video:
+ logger.warning(f"Too many videos to summarize: {num_video} > {max_num_video}")
+ return False
+ if isinstance(min_num_audio, int) and num_audio < min_num_audio:
+ logger.warning(f"Too few audios to summarize: {num_audio} < {min_num_audio}")
+ return False
+ if isinstance(max_num_audio, int) and num_audio > max_num_audio:
+ logger.warning(f"Too many audios to summarize: {num_audio} > {max_num_audio}")
+ return False
+ return True
src/summarize/utils.py
@@ -0,0 +1,174 @@
+#!/venv/bin/python
+# -*- coding: utf-8 -*-
+import base64
+import json
+import zlib
+from pathlib import Path
+
+from config import DB, DOWNLOAD_DIR
+from database.r2 import set_cf_r2
+from networking import download_file, shorten_url
+from schema import MindMap, get_schema
+from utils import digest, read_text
+
+
+def system_prompt(sys: str | None = None) -> str:
+ prompt = """你是一位专业的内容提炼大师,任务是基于用户提供的资料,生成用户无需阅读完整原文档就能清晰理解主要事件、观点、结论的内容,生成符合指定JSON格式的全文总结、分片内容和思维导图。
+
+## 核心规则
+1. 语言要求:无论原资料使用何种语言(中文、英文或其他语言),输出的所有内容**需以简体中文为主**,包括JSON结构中的文本、总结、分片内容及思维导图节点;若资料中存在特定领域的专业术语(如技术、学术等领域的外文术语),可保留原外文术语,无需强制译为中文,避免强行翻译导致信息失真。
+2. 信息忠实性:提炼内容需完全忠实于原资料,不得添加个人观点、推测或无关信息。
+3. 广告过滤规则:若资料中包含与主内容完全无关的广告(如播客/B站视频的植入广告、商业推广等,特征为:内容独立于节目主题、无实质信息关联、去掉后不影响对主内容的理解),需直接忽略该部分内容,不得将广告信息纳入任何提炼结果中。
+"""
+ if sys:
+ prompt += f"\n{sys}"
+ return prompt.strip()
+
+
def parse_summary_sources(sources: list[dict], *, mermaid: bool = False) -> tuple[str, str, dict]:
    r"""Split the sources into plain text, transcripts and model request options.

    Args:
        sources (list[dict] | None): The sources to summarize.
            # text
            {"type": "system_prompt", "text": "This is system prompt."}
            {"type": "text", "text": "Hello."}
            {"type": "text", "path": "/path/to/file.txt"}
            {"type": "transcripts", "text": "[00:00] Hello\n[00:01] a sentence"}

            # media
            {"type": "image", "path": "/path/to/image.jpg", "mime_type": "image/jpeg"}
            {"type": "video", "path": "/path/to/video.mp4", "mime_type": "video/mp4"}
            {"type": "audio", "path": "/path/to/audio.mp3", "mime_type": "audio/mpeg"}

            # file
            {"type": "file", "path": "/path/to/file.pdf", "mime_type": "application/pdf"}

            # special
            {"type": "youtube", "url": "https://www.youtube.com/watch?v=videoid"}
        mermaid: Whether the model should also produce the ``mindmap`` field.

    Returns:
        ``(texts, transcripts, options)``: ``texts`` is every text/transcripts
        source joined by newlines, ``transcripts`` is the text of the first
        transcripts source, and ``options`` holds the system prompt, the JSON
        schema settings and the media items passed on as additional contexts.
    """
    sources = sources or []

    sys_items = [x.get("text", "") for x in sources if x["type"] == "system_prompt"]
    sys_prompt = system_prompt("\n".join(sys_items))
    if not mermaid:
        # No mind map requested: remove the phrases that ask for one.
        sys_prompt = sys_prompt.replace("和思维导图", "").replace("及思维导图节点", "")

    transcripts = next((x.get("text", "") for x in sources if x["type"] == "transcripts"), "")

    text_parts = []
    for source in sources:
        if source["type"] in ["transcripts", "text"] and source.get("text"):
            text_parts.append(source["text"])
        elif source["type"] == "text" and source.get("path"):
            text_parts.append(read_text(source["path"]))
    texts = "\n".join(text_parts)

    media_items = [x for x in sources if x["type"] in ["image", "video", "audio", "file", "youtube"]]

    schema = get_schema("content_extraction")
    if not mermaid:
        # Tolerate a schema without the property, and never leave a dangling
        # "required" entry for a property that no longer exists.
        schema["properties"].pop("mindmap", None)
        required = schema.get("required")
        if required and "mindmap" in required:
            required.remove("mindmap")
    elif "mindmap" in schema["properties"] and "mindmap" not in schema["required"]:
        schema["required"].append("mindmap")

    return (
        texts.strip(),
        transcripts.strip(),
        {
            "gemini_generate_content_config": {"system_instruction": sys_prompt, "responseMimeType": "application/json", "responseJsonSchema": schema},
            "openai_responses_config": {
                "instructions": sys_prompt,
                "text": {
                    "format": {
                        "type": "json_schema",
                        "name": "ContentExtraction",
                        "strict": True,
                        "description": "精准提炼资料的核心主题、关键观点、主要结论及各片段核心内容,确保输出内容全面覆盖资料的关键信息,用户仅通过总结即可掌握信息全貌。",
                        "schema": schema,
                    }
                },
            },
            "openai_system_prompt": sys_prompt,
            "openai_completions_config": {
                "response_format": {
                    "type": "json_schema",
                    "strict": True,
                    "json_schema": {
                        "name": "ContentExtraction",
                        "schema": schema,
                        "strict": True,
                    },
                }
            },
            "additional_contexts": media_items,
            "gemini_append_grounding": False,
            "openai_enable_tool_call": False,
            "openai_append_tool_results": False,
            "silent": True,
        },
    )
+
+
def generate_mermaid(mindmap: "MindMap") -> str:
    """Generate Mermaid flowchart code from a MindMap.

    Node ids are structural, so they stay unique for any number of topics,
    subtopics and leafs: ``A`` is the root, ``T<n>`` a topic, ``T<n>L<k>`` a
    leaf attached directly to a topic, ``T<n>S<m>`` a subtopic and
    ``T<n>S<m>L<k>`` a leaf of a subtopic.

    Args:
        mindmap: Object exposing ``main_title`` and ``topics``; each topic
            exposes ``title``, ``leafs`` and ``sub_tocpics`` (attribute name
            as spelled on the model), each subtopic ``title`` and ``leafs``.

    Returns:
        Mermaid code.

    Example:
        graph LR
            A["main_title"]
            A --> T1["topic.title"]
            T1 --> T1L1["topic leaf"]
            T1 --> T1S1["subtopic.title"]
            T1S1 --> T1S1L1["subtopic leaf"]
            A --> T2["topic.title"]
    """
    header = "---\nconfig:\n theme: neo\n look: neo\n---\ngraph LR"
    indent = "    "  # four spaces

    def label(text: str) -> str:
        # A raw double quote would end the quoted label early; Mermaid's
        # entity code for it is `#quot;`.
        return '"' + text.replace('"', "#quot;") + '"'

    # Declare the root on its own line so it exists even without topics.
    lines = [f"{indent}A[{label(mindmap.main_title)}]"]
    for topic_no, topic in enumerate(mindmap.topics, start=1):
        topic_id = f"T{topic_no}"
        lines.append(f"{indent}A --> {topic_id}[{label(topic.title)}]")

        # leafs attached directly to the topic
        for leaf_no, leaf in enumerate(topic.leafs or [], start=1):
            lines.append(f"{indent}{topic_id} --> {topic_id}L{leaf_no}[{label(leaf)}]")

        # subtopics and their leafs
        for sub_no, sub in enumerate(topic.sub_tocpics or [], start=1):
            sub_id = f"{topic_id}S{sub_no}"
            lines.append(f"{indent}{topic_id} --> {sub_id}[{label(sub.title)}]")
            for leaf_no, leaf in enumerate(sub.leafs or [], start=1):
                lines.append(f"{indent}{sub_id} --> {sub_id}L{leaf_no}[{label(leaf)}]")

    return "\n".join([header, *lines]).strip()
+
+
async def publish_mermaid(mermaid: str) -> tuple[str, str]:
    """Render the Mermaid code to a JPEG, upload it to R2 and build share links.

    The image is fetched from mermaid.ink into ``DOWNLOAD_DIR``, uploaded under
    a ``TTL/365d/`` key and removed locally afterwards. The second link opens
    the diagram in mermaid.live (pako-compressed state in the URL fragment)
    and embeds the image URL as a Mermaid comment.

    Args:
        mermaid: Mermaid source code containing a ``graph LR`` line.

    Returns:
        ``(image_url, pako_url)``, or ``("", "")`` when the image could not be
        downloaded.
    """
    b64_str = base64.urlsafe_b64encode(mermaid.encode("utf-8")).decode("ascii")
    save_path = Path(DOWNLOAD_DIR) / f"{digest(mermaid)}.jpg"
    r2_key = f"TTL/365d/{save_path.name}"
    render_url = f"https://mermaid.ink/img/{b64_str}?type=jpeg&theme=forest&width=2160"
    try:
        # Bail out before touching the URLs: continuing after a failed download
        # would use links that were never built (e.g. when a stale file from an
        # earlier run is still present at save_path).
        if not await download_file(render_url, path=save_path, suffix=".jpg"):
            return "", ""
        if not save_path.is_file():
            return "", ""

        img_url = await shorten_url(f"{DB.CF_R2_PUBLIC_URL}/{r2_key}", alias=str(digest(mermaid, 16)))
        annotated = mermaid.replace("\ngraph LR", f"\n%% {img_url}\ngraph LR")

        # generate pako url for mermaid image
        json_str = json.dumps({"code": annotated.strip()}, separators=(",", ":"))
        compressed_bytes = zlib.compress(json_str.encode("utf-8"), level=9)
        pako_b64_str = base64.urlsafe_b64encode(compressed_bytes).decode("utf-8").rstrip("=")
        pako_url = await shorten_url(f"https://mermaid.live/view#pako:{pako_b64_str}", alias=str(digest(pako_b64_str, 16)))

        await set_cf_r2(r2_key, data=save_path.read_bytes(), mime_type="image/jpeg", silent=True)
        return img_url, pako_url
    finally:
        # Never leave the temporary image behind, even when a step above raises.
        save_path.unlink(missing_ok=True)
src/ytdlp/download.py
@@ -81,7 +81,7 @@ async def ytdlp_download(
for fmt_id in [x.strip() for x in format_id.split("+") if x.strip()]: # ['299', '140']
video_ext = info["video_path"].suffix # .mp4
Path(DOWNLOAD_DIR).joinpath(f"{info['id']}.f{fmt_id}{video_ext}").unlink(missing_ok=True)
- # summary
+
await modify_progress(text=msg.strip(), force_update=True, **kwargs)
return info
src/ytdlp/main.py
@@ -10,7 +10,6 @@ from loguru import logger
from pyrogram.client import Client
from pyrogram.types import Message
-from ai.summary import summarize
from config import AI, ASR, CAPTION_LENGTH, MAX_FILE_BYTES, YTDLP_RE_ENCODING_MAX_FILE_BYTES
from database.r2 import get_cf_r2
from messages.database import copy_messages_from_db, save_messages
@@ -22,7 +21,8 @@ from multimedia import convert_to_h264
from preview.bilibili import get_bilibili_comments, get_bilibili_vinfo, make_bvid_clickable
from preview.youtube import get_youtube_comments, get_youtube_vinfo
from publish import publish_telegraph
-from utils import readable_size, soup_to_text, to_int, true, ts_to_dt, unicode_to_ascii
+from summarize.summarize import summarize
+from utils import convert2html, readable_size, soup_to_text, to_int, true, ts_to_dt, unicode_to_ascii
from ytdlp.download import ytdlp_download
from ytdlp.utils import append_tag, cleanup_ytdlp, generate_prompt, get_subtitles, platform_emoji
@@ -46,8 +46,8 @@ async def preview_ytdlp(
ytdlp_video_target: str | int | None = None,
ytdlp_audio_target: str | int | None = None,
ytdlp_send_subtitle: bool = False,
- ytdlp_send_summary: bool = False,
- summary_model_id: str = AI.SUBTITLE_SUMMARY_MODEL_ALIAS,
+ summary_ytdlp: bool = False,
+ summary_ytdlp_model: str = AI.SUBTITLE_SUMMARY_MODEL_ALIAS,
enable_corrector: bool = True,
show_author: bool = True,
show_title: bool = True,
@@ -76,8 +76,8 @@ async def preview_ytdlp(
ytdlp_video_target (str | int, optional): The target chat id to send video.
ytdlp_audio_target (str | int, optional): The target chat id to send audio.
ytdlp_send_subtitle (bool, optional): Send subtitle. Defaults to False.
- ytdlp_send_summary (bool, optional): Send AI summary. Defaults to False.
- summary_model_id (str, optional): The model id to use for AI summary.
+ summary_ytdlp (bool, optional): Send AI summary. Defaults to False.
+ summary_ytdlp_model (str, optional): The model id to use for AI summary.
to_telegraph (bool, optional): Whether to publish the subtitle or transcription to telegraph.
"""
logger.trace(f"{url=} {kwargs=}")
@@ -138,23 +138,30 @@ async def preview_ytdlp(
# get subtitles
subtitles = ""
- if true(ytdlp_send_subtitle) or true(ytdlp_send_summary):
+ if true(ytdlp_send_subtitle) or true(summary_ytdlp):
fpath = info["audio_path"] if info["audio_path"].is_file() else info["video_path"]
asr_engine = kwargs.get("asr_engine", "uncensored") if platform == "youtube" else ASR.DEFAULT_ENGINE
subtitles = await get_subtitles(fpath, url, asr_engine, info, enable_corrector=enable_corrector)
# get ai summary
telegraph_ai = ""
- if subtitles and true(ytdlp_send_summary):
+ if subtitles and true(summary_ytdlp):
+ desc = info.get("description", "")
+ desc_html = desc if desc.startswith("<") else convert2html(desc)
+ if platform == "bilibili":
+ desc_html = f'<iframe src="https://player.bilibili.com/player.html?isOutside=true&bvid={bvid}&p=1&autoplay=0&poster=1&danmaku=1" frameborder="0" scrolling="no" border="0" framespacing="0" allowfullscreen="true" style="width: 100%; aspect-ratio: 16/9;"></iframe>{desc_html}'
+ elif platform == "youtube":
+ desc_html = f'<iframe src="https://www.youtube.com/embed/{vid}" frameborder="0" scrolling="no" border="0" framespacing="0" allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture; web-share" referrerpolicy="strict-origin-when-cross-origin" allowfullscreen="true" style="width: 100%; aspect-ratio: 16/9;"></iframe>{desc_html}'
+ desc_page = {"emoji": "🎬", "name": "视频详情", "html": desc_html}
summary = await summarize(
- transcripts=subtitles,
- reference=generate_prompt(info),
- model=summary_model_id,
+ sources=[{"type": "system_prompt", "text": generate_prompt(info)}, {"type": "transcripts", "text": subtitles}],
+ model=summary_ytdlp_model,
title=info.get("title"),
- description=info.get("description"),
+ description=desc_page,
author=info.get("author"),
url=url,
date=glom(info, Coalesce("pubdate", "upload_date"), default=""),
+ min_text_length=200,
)
telegraph_ai = summary.get("telegraph_url", "")
src/ytdlp/utils.py
@@ -12,7 +12,7 @@ from pyrogram.types import Message
from yt_dlp.utils import YoutubeDLError
from asr.voice_recognition import asr_file
-from config import CAPTION_LENGTH, COOKIE, DOWNLOAD_DIR, PROXY
+from config import CAPTION_LENGTH, DOWNLOAD_DIR, PROXY
from cookies import ytdlp_bilibili_cookie
from messages.utils import smart_split
from multimedia import convert_img_to_telegram_format, generate_cover
@@ -79,7 +79,7 @@ async def get_ytdlp_opts(platform: Literal["youtube", "bilibili", "ytdlp"] | Non
"color": "no_color-tty",
"logger": logger,
}
- if platform == "bilibili" and COOKIE.YTDLP_BILIBILI_USE_COOKIE:
+ if platform == "bilibili":
cookiefile = await ytdlp_bilibili_cookie()
logger.trace(f"Use cookie file: {cookiefile}")
ytdlp_opts["cookiefile"] = cookiefile
src/config.py
@@ -180,7 +180,6 @@ class COOKIE: # See: https://github.com/easychen/CookieCloud
CLOUD_SERVER = os.getenv("COOKIE_CLOUD_SERVER", "")
CLOUD_KEY = os.getenv("COOKIE_CLOUD_KEY", "")
CLOUD_PASS = os.getenv("COOKIE_CLOUD_PASS", "")
- YTDLP_BILIBILI_USE_COOKIE = os.getenv("YTDLP_BILIBILI_USE_COOKIE", "0").lower() in ["1", "y", "yes", "t", "true", "on"]
class TID: # see more TID usecase in `src/permission.py`
src/networking.py
@@ -464,12 +464,10 @@ async def match_social_media_link(text: str, *, flatten_first: bool = True) -> d
# https://arxiv.org/abs/2301.12345
# https://arxiv.org/pdf/2301.12345v3
if matched := re.search(r"(https?://)?arxiv\.org/(abs|pdf)/(\d{4}\.\d{4,5}(?:v\d+)?)", text):
- url = matched.group(0)
arxiv_id = matched.group(3)
if "v" not in arxiv_id:
arxiv_id += "v1"
- url += "v1"
- return {"url": url, "arxiv_id": arxiv_id, "db_key": f"arxiv.org/abs/{arxiv_id}", "platform": "arxiv"}
+ return {"url": f"https://arxiv.org/abs/{arxiv_id}", "arxiv_id": arxiv_id, "db_key": f"arxiv.org/abs/{arxiv_id}", "platform": "arxiv"}
# if all above pre-defined patterns failed, try to match ytdlp link
if urls := match_urls(text):
src/permission.py
@@ -65,6 +65,7 @@ def set_permission(message: Message) -> dict:
"google_search": True,
"show_progress": True,
"detail_progress": True,
+ "social": True,
"douyin": True,
"tiktok": True,
"instagram": True,
src/publish.py
@@ -144,8 +144,10 @@ async def publish_neocities(title: str, html: str | None = None, author: str | N
return f"https://t.me/iv?url={quote_plus(pub_url)}&rhash={TOKEN.NEOCITIES_IV_HASH}" if TOKEN.NEOCITIES_IV_HASH else pub_url
-async def telegraph_aipage(page: AIPage, ttl: str | None = None) -> str:
+async def telegraph_aipage(page: AIPage, ttl: str | None = None, *, force_r2: bool = False) -> str:
"""Publish AI Page to Telegraph."""
+ if force_r2:
+ return await r2_aipage(page, ttl=ttl, rformat="url")
anchor = lambda s: s.replace(" ", "-")
nodes = []
@@ -171,13 +173,21 @@ async def telegraph_aipage(page: AIPage, ttl: str | None = None) -> str:
nodes.append({"tag": "p", "children": [overview]})
# Description
- if page.description:
- desc = convert2md(html=page.description)
+ description = page.description
+ if description and isinstance(description, str):
+ desc = convert2md(html=description)
desc_html = convert2html(remove_consecutive_newlines(desc, newline_level=2))
- desc_nodes = html_to_nodes(desc_html)
+ desc_nodes = html_to_nodes(adjust_tags(desc_html))
nodes.append({"tag": "h4", "children": ["📖原始简介"]})
nodes.extend(desc_nodes)
-
+ elif description and isinstance(description, dict) and description.get("html"):
+ desc_emoji = description.get("emoji", "📖")
+ desc_title = description.get("name", "原始简介")
+ desc_md = convert2md(html=description["html"]) if description["html"].startswith("<") else description["html"]
+ desc_html = convert2html(remove_consecutive_newlines(desc_md, newline_level=2))
+ desc_nodes = html_to_nodes(adjust_tags(desc_html))
+ nodes.append({"tag": "h4", "children": [desc_emoji + desc_title]})
+ nodes.extend(desc_nodes)
# Sections
for section in sections:
nodes.append({"tag": "h4", "children": [section.emoji + section.title]})
@@ -222,11 +232,18 @@ async def r2_aipage(page: AIPage, ttl: str | None = None, *, expand_transcript:
desc_tag = ""
desc_head = ""
- if page.description:
- sidebars += """<li><a href="#description" onclick="navClick(event)"><span class="sidebar-icon">📖</span><span class="sidebar-label">原始简介</span></a></li>"""
- desc_html = page.description if page.description.startswith("<") else convert2html(page.description)
+ description = page.description
+ if description and isinstance(description, str):
+ sidebars += '<li><a href="#description" onclick="navClick(event)"><span class="sidebar-icon">📖</span><span class="sidebar-label">原始简介</span></a></li>'
+ desc_html = description if description.startswith("<") else convert2html(description)
desc_tag = f'<div class="card description"><div class="card-label" id="description">📖原始简介</div>{desc_html}</div>'
- if page.description and overview:
+ elif description and isinstance(description, dict) and description.get("html"):
+ desc_emoji = description.get("emoji", "📖")
+ desc_title = description.get("name", "原始简介")
+ desc_html = description["html"] if description["html"].startswith("<") else convert2html(description["html"])
+ sidebars += f'<li><a href="#description" onclick="navClick(event)"><span class="sidebar-icon">{desc_emoji}</span><span class="sidebar-label">{desc_title}</span></a></li>'
+ desc_tag = f'<div class="card description"><div class="card-label" id="description">{desc_emoji}{desc_title}</div>{desc_html}</div>'
+ if description and overview:
desc_head = f"""<meta property="og:description" content="{glom(page, "summary.overview", default="")}">"""
sections_tag = ""
@@ -247,8 +264,8 @@ async def r2_aipage(page: AIPage, ttl: str | None = None, *, expand_transcript:
transcriptions = ""
if transcripts:
- sidebars += """<li><a href="#transcript" onclick="navClick(event)"><span class="sidebar-icon">🔤</span><span class="sidebar-label">完整字幕</span></a></li>"""
- transcriptions += '<div class="card" id="transcript" style="margin-top: 24px;"><button class="transcript-toggle" aria-expanded="false" onclick="toggleTranscript(this)">展开字幕 <span class="arrow">▾</span></button><div class="transcript-content" id="transcriptions">'
+ sidebars += """<li class="nav-transcript"><a href="#transcript" onclick="navClick(event)"><span class="sidebar-icon">🔤</span><span class="sidebar-label">完整字幕</span></a></li>"""
+ transcriptions += '<div class="card transcript-card" id="transcript"><button class="transcript-toggle" aria-expanded="false" onclick="toggleTranscript(this)">展开字幕 <span class="arrow">▾</span></button><div class="transcript-content" id="transcriptions">'
for sentence in transcripts:
transcriptions += f'<p><span class="ts">{sentence.start}</span>{sentence.content}</p>'
transcriptions += "</div></div>"
@@ -266,11 +283,29 @@ async def r2_aipage(page: AIPage, ttl: str | None = None, *, expand_transcript:
theme_icon = '<button class="icon-theme" id="icon-theme" onclick="toggleTheme()" aria-label="切换主题"><svg class="icon-sun" viewBox="0 0 24 24"><circle cx="12" cy="12" r="5" /><path d="M12 1v2M12 21v2M4.22 4.22l1.42 1.42M18.36 18.36l1.42 1.42M1 12h2M21 12h2M4.22 19.78l1.42-1.42M18.36 5.64l1.42-1.42" /></svg><svg class="icon-moon" viewBox="0 0 24 24" style="display:none"><path d="M21 12.79A9 9 0 1 1 11.21 3 7 7 0 0 0 21 12.79z" /></svg></button>'
mermaid_icon = '<button class="icon-mindmap" id="icon-mindmap" onclick="toggleMindmapPanel()" aria-label="思维导图"><svg viewBox="0 0 24 24"><circle cx="4" cy="12" r="2" /><path d="M6 12h6M12 12l8-8M12 12h8M12 12l8 8" /></svg></button>'if page.mermaid_img else "" # fmt: skip
- mermaid_desktop = f'<div class="mindmap-panel" id="mindmap-panel"><div class="mindmap-panel-content"><img src="{page.mermaid_img}" alt="思维导图"><a href="{page.mermaid_url}" target="_blank" class="mindmap-link">查看完整思维导图</a></div></div>' if page.mermaid_img else "" # fmt: skip
-
- mermaid_mobile = f'<div class="card mindmap-card mindmap-mobile" id="mindmap-mobile"><button class="transcript-toggle" aria-expanded="false" onclick="toggleMindmap(this)">展开思维导图 <span class="arrow">▾</span></button><div class="mindmap-body" id="mindmap-body"><img src="{page.mermaid_img}" alt="思维导图"><a href="{page.mermaid_url}" target="_blank" class="mindmap-link">查看完整思维导图</a></div></div>' if page.mermaid_img else "" # fmt: skip
-
- html_str = f"""<!DOCTYPE html><html lang="zh-CN"><head><meta charset="UTF-8"><meta name="viewport" content="width=device-width, initial-scale=1.0"><meta property="article:published_time" content="{utc_date:%Y-%m-%dT%H:%M:%SZ}"><meta property="og:title" content="{page.title}">{desc_head}<title>{page.title}</title><link rel="stylesheet" href="{DB.CF_R2_PUBLIC_URL}/telegraph.css"><script src="{DB.CF_R2_PUBLIC_URL}/telegraph.js" defer></script></head><body>
+ mermaid_desktop = f'<div class="mindmap-panel" id="mindmap-panel"><div class="mindmap-panel-content"><img data-src="{page.mermaid_img}" alt="思维导图"><a href="{page.mermaid_url}" target="_blank" class="mindmap-link">查看完整思维导图</a></div></div>' if page.mermaid_img else "" # fmt: skip
+
+ mermaid_mobile = f'<div class="card mindmap-card mindmap-mobile" id="mindmap-mobile"><button class="transcript-toggle" aria-expanded="false" onclick="toggleMindmap(this)">展开思维导图 <span class="arrow">▾</span></button><div class="mindmap-body" id="mindmap-body"><img data-src="{page.mermaid_img}" alt="思维导图"><a href="{page.mermaid_url}" target="_blank" class="mindmap-link">查看完整思维导图</a></div></div>' if page.mermaid_img else "" # fmt: skip
+
+ html_str = f"""<!DOCTYPE html>
+<html lang="zh-CN">
+<head>
+ <meta charset="UTF-8">
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
+ <meta property="article:published_time" content="{utc_date:%Y-%m-%dT%H:%M:%SZ}">
+ <meta property="og:title" content="{page.title}">{desc_head}
+ <meta property="og:site_name" content="🤖AI导读" />
+ <link rel="icon" type="image/png" href="{DB.CF_R2_PUBLIC_URL}/favicon/favicon-96x96.png" sizes="96x96" />
+ <link rel="icon" type="image/svg+xml" href="{DB.CF_R2_PUBLIC_URL}/favicon/favicon.svg" />
+ <link rel="shortcut icon" href="{DB.CF_R2_PUBLIC_URL}/favicon/favicon.ico" />
+ <link rel="apple-touch-icon" sizes="180x180" href="{DB.CF_R2_PUBLIC_URL}/favicon/apple-touch-icon.png" />
+ <meta name="apple-mobile-web-app-title" content="R2" />
+ <link rel="manifest" href="{DB.CF_R2_PUBLIC_URL}/favicon/site.webmanifest" />
+ <title>{page.title}</title>
+ <link rel="stylesheet" href="{DB.CF_R2_PUBLIC_URL}/telegraph.css">
+ <script src="{DB.CF_R2_PUBLIC_URL}/telegraph.js" defer></script>
+</head>
+<body>
<!-- Icon -->
{sidebar_icon}
@@ -280,6 +315,8 @@ async def r2_aipage(page: AIPage, ttl: str | None = None, *, expand_transcript:
{theme_icon}
{sidebars}
+ <div class="resize-handle resize-handle-left" id="resize-handle-left"></div>
+ <div class="resize-handle resize-handle-right" id="resize-handle-right"></div>
<div class="container">
<header class="header"><h1 class="header-title"><a href="{url}" target="_blank">{page.title}</a></h1>{author_tag}</header>
@@ -299,7 +336,7 @@ async def r2_aipage(page: AIPage, ttl: str | None = None, *, expand_transcript:
{transcriptions}
</div>
-
+ {'<div class="transcript-panel" id="transcript-panel"></div>' if transcriptions else ""}
{mermaid_desktop}
</body>
src/schema.py
@@ -5,103 +5,32 @@ import jsonref
from pydantic import BaseModel, ConfigDict, Field
-def mermaid_syntax() -> str:
- return """
-# Mermaid Graph - Basic Syntax
+class SubTopic(BaseModel):
+ """思维导图二级话题."""
-Graph is composed of **nodes** (geometric shapes) and **edges** (arrows or lines). The Mermaid code defines how nodes and edges are made and accommodates different arrow types, multi-directional arrows, and any linking to and from subgraphs.
+ model_config = ConfigDict(extra="ignore")
-## A node (default)
+ title: str = Field(description="二级话题标题")
+ leafs: list[str] = Field(default=[], description="叶子节点内容列表")
-```mermaid
-graph LR
- id
-```
-```note
-The id is what is displayed in the box.
-```
+class Topic(BaseModel):
+ """思维导图一级话题."""
-### A node with text
+ model_config = ConfigDict(extra="ignore")
-It is also possible to set text in the box that differs from the id. If this is done several times, it is the last text
-found for the node that will be used. Also if you define edges for the node later on, you can omit text definitions. The
-one previously defined will be used when rendering the box.
+ title: str = Field(description="一级话题标题")
+ sub_tocpics: list[SubTopic] | None = Field(default=None, description="二级话题列表(可为空)")
+ leafs: list[str] | None = Field(default=None, description="一级话题的叶子节点内容列表(可为空)")
-```mermaid
-graph LR
- id1[This is the text in the box]
-```
-## Node shapes
+class MindMap(BaseModel):
+ """思维导图根节点."""
-### A node with round edges
+ model_config = ConfigDict(extra="ignore")
-```mermaid
-graph LR
- id1(This is the text in the box)
-```
-
-## Links between nodes
-
-Nodes can be connected with links/edges. It is possible to have different types of links or attach a text string to a link.
-
-### A link with arrow head
-
-```mermaid
-graph LR
- A-->B
-```
-
-### An open link
-
-```mermaid
-graph LR
- A --- B
-```
-
-### Text on links
-
-```mermaid
-graph LR
- A---|This is the text|B
-```
-
-### A link with arrow head and text
-
-```mermaid
-graph LR
- A-->|text|B
-```
-
-### Dotted link
-
-```mermaid
-graph LR
- A-.->B;
-```
-
-### Dotted link with text
-
-```mermaid
-graph LR
- A-. text .-> B
-```
-
-### Thick link
-
-```mermaid
-graph LR
- A ==> B
-```
-
-### Thick link with text
-
-```mermaid
-graph LR
- A == text ==> B
-```
-""".strip()
+ main_title: str = Field(description="根节点标题")
+ topics: list[Topic] = Field(description="一级话题列表")
class Section(BaseModel):
@@ -124,14 +53,7 @@ class ContentExtraction(BaseModel):
title="分片内容",
description="需将文档划分为逻辑连贯的片段(如按章节、主题、时间线划分);每个片段需拟定**简洁准确**的标题(体现片段核心)、匹配1个相关emoji;并说明该片段的核心内容。",
)
- mermaid: str = Field(
- title="思维导图",
- pattern=r"^graph LR",
- description=f"以Mermaid graph格式表示的全文思维导图,以'graph LR'开头。需清晰呈现文档的逻辑结构(如核心主题→子主题→关键观点/结论),节点层级明确,便于用户快速梳理文档框架。\n{mermaid_syntax()}",
- examples=[
- "graph LR\n A[核心主题] --> B[子标题1]\n A --> C[子标题2]\n A --> D[子标题3]\n A --> E[子标题4]\n\n\n B --> B1[二级标题1-1]\n B --> B2[二级标题1-2]\n B1 --> B11[核心观点1-1-1]\n B1 --> B12[核心观点1-1-2]\n B2 --> B21[争议点1-2-1]\n B2 --> B22[争议点1-2-2]\n\n\n C --> C1[关键数据2-1]\n C --> C2[主要结论2-1]\n C --> C3[补充结论2-2]\n\n D --> D1[核心问题3-1]\n D --> D2[潜在风险3-2]\n D --> D3[影响因素3-3]\n\n E --> E1[发展趋势4-1]\n E --> E2[行动建议4-2]\n E --> E3[未来结论4-3]",
- ],
- )
+ mindmap: MindMap | None = Field(default=None, description="思维导图")
class Sentence(BaseModel):
@@ -145,7 +67,7 @@ class AIPage(BaseModel):
title: str = Field(default="AI导读", description="标题")
url: str | None = Field(default=None, description="原始链接")
author: str | None = Field(default=None, description="作者")
- description: str | None = Field(default=None, description="原始描述")
+ description: str | dict | None = Field(default=None, description="原始描述")
date: datetime | None = Field(default_factory=lambda: datetime.now(UTC), description="发布日期")
summary: ContentExtraction | None = Field(default=None, description="AI总结")
transcripts: str | list[Sentence] | None = Field(default=None, description="转录稿")
src/utils.py
@@ -121,11 +121,47 @@ def true(value: Any) -> bool:
return True
-def digest(s: Any, length: int = 32) -> str:
- raw_bytes = hashlib.shake_256(str(s).encode()).digest(length * 2)
# Characters rejected by at least one mainstream filesystem, plus ASCII control codes 0-31.
_ILLEGAL_FILENAME_CHARS = re.compile(r'[\\/:*?"<>|\x00-\x1f]')

# Device names that Windows reserves regardless of letter case or extension.
_WINDOWS_RESERVED_NAMES = frozenset(
    {"CON", "PRN", "AUX", "NUL"}
    | {f"COM{i}" for i in range(1, 10)}
    | {f"LPT{i}" for i in range(1, 10)}
)


def sanitize_filename(filename: str, replacement: str = "_") -> str:
    """Return *filename* rewritten so it is usable on Windows, Linux and macOS.

    Steps, in order:

    1. Every illegal character (``\\ / : * ? " < > |`` and ASCII control
       characters 0-31) is substituted with *replacement*.
    2. Trailing dots and spaces are removed, because Windows silently drops
       them and the name would then differ between platforms.
    3. If the part before the first dot is a Windows reserved device name
       (``CON``, ``NUL``, ``COM1`` ...), *replacement* is appended to that
       part. Windows reserves these names with or without an extension, so
       ``CON`` becomes ``CON_`` and ``con.txt`` becomes ``con_.txt``.
    4. A name that ended up empty is replaced by *replacement* itself.

    Args:
        filename: The candidate file name (a single path component).
        replacement: Text substituted for illegal characters and used to
            disarm reserved names.

    Returns:
        The sanitized file name.
    """
    cleaned = _ILLEGAL_FILENAME_CHARS.sub(replacement, filename)
    cleaned = cleaned.rstrip(". ")

    # Windows matches reserved device names against the text before the first
    # dot, so "NUL.tar.gz" is just as inaccessible as a bare "NUL".
    stem, dot, rest = cleaned.partition(".")
    if stem.upper() in _WINDOWS_RESERVED_NAMES:
        cleaned = f"{stem}{replacement}{dot}{rest}"

    # Nothing usable survived the cleaning (e.g. "" or "...").
    return cleaned or replacement
+
+
+def digest(s: Any, length: int = 32, *, to_int: bool = False) -> str | int:
+ hasher = hashlib.shake_256()
+ if isinstance(s, Path) and s.is_file():
+ with open(s, "rb") as f:
+ for chunk in iter(lambda: f.read(65536), b""):
+ hasher.update(chunk)
+ elif isinstance(s, bytes):
+ hasher.update(s)
+ else:
+ hasher.update(str(s).encode())
+
+ raw_bytes = hasher.digest(length * 2)
b64_str = base64.urlsafe_b64encode(raw_bytes).decode("ascii")
b64_str = b64_str.replace("=", "").replace("-", "").replace("_", "")
- return b64_str[:length]
+ if to_int:
+ b64_str = int.from_bytes(b64_str.encode("ascii"), byteorder="big")
+ return str(b64_str)[:length]
def remove_none_values(d: dict | list) -> dict:
@@ -608,7 +644,7 @@ def cleanup_old_files(root: Path | str | None = None, duration: int = CLEAN_OLD_
for path in root.glob("*"):
if not path.is_file():
continue
- if all(now - x > duration for x in [path.stat().st_atime, path.stat().st_ctime, path.stat().st_mtime]):
+ if all(now - x > duration for x in [path.stat().st_atime, path.stat().st_mtime]):
logger.warning(f"Deleting old file: {path}")
path.unlink(missing_ok=True)
@@ -633,6 +669,8 @@ def convert2md(*, html: str | None = None, path: str | Path | None = None) -> st
def convert2html(texts: str = "") -> str:
"""Convert texts to html format."""
+ if not isinstance(texts, str) or not str(texts).strip():
+ return ""
texts = markdown.markdown(texts)
return texts.replace("\n", "<br>")