Commit 968a928
Changed files (7)
src
src/asr/corrector.py
@@ -0,0 +1,129 @@
+import json
+import re
+from contextlib import suppress
+
+from pyrogram.types import Chat, Message
+from pyrogram.types.messages_and_media.message import Str
+
+from ai.main import ai_text_generation
+from config import PREFIX
+from utils import rand_number
+
+# ruff: noqa: RUF001
# JSON schema enforced on the corrector model's structured output: an array of
# {"idx": <transcript item index>, "corrected": <fixed text>} objects, used both
# as the OpenAI Responses `json_schema` format and Gemini's `responseJsonSchema`.
JSON_SCHEMA = {
    "title": "List of Correction",
    "type": "array",
    "items": {
        "type": "object",
        "title": "Correction",
        "properties": {
            "idx": {"description": "Index of the transcription item", "title": "Index", "type": "integer"},
            "corrected": {"description": "Corrected text", "title": "Corrected", "type": "string"},
        },
        "required": ["idx", "corrected"],
        "additionalProperties": False,
    },
}
+
+
+async def asr_corrector(inputs: str, reference: str | None = None, corrector_model: str = "asr-corrector") -> str:
+ """Correct ASR results.
+
+ Example:
+ [00:00] hello
+ [00:01] world
+
+ Args:
+ inputs (str): original ASR results.
+
+ Returns:
+ str: corrected ASR results.
+ """
+ SYSTEM_PROMPT = """# 身份与职责
+你是专注于ASR转录稿校对的专业助手,服务于需要精准文本转化的用户,核心职责是识别并修正转录稿中的特定错误类型,确保输出内容准确反映原始语音信息。
+
+# 校对规则
+## 必做事项
+1. 逐行检查提供的转录文稿中的每一项,识别两类错误:
+ - 转录错误:语音内容被错误转换(如“苹果”转成“平果”)
+ - 口语重复:无意义的重复表述(如“这个这个方案”)
+ - 标点错误:标点符号缺失或错误(如“是吗?”转成“是吗。”)
+2. 仅保留错误项,正确项不纳入输出
+
+## 约束条件
+1. 保留原始文本中的emoji表情
+2. 不处理除指定三类错误之外的其他错误(如逻辑错误)
+
+# 输入处理
+1. 优先读取用户提供的转录稿中的内容,格式为JSON数组,每个项包含idx和text两个字段
+2. 若提供<reference>{{reference}}</reference>,可作为错误判断的辅助参考(如专业术语、专有名词)
+3. 若输入为空或格式错误,输出空列表
+
+# 执行步骤
+1. 初始化空列表用于存储错误项
+2. 遍历转录稿中的每一项:
+ a. 检查text字段是否存在转录错误
+ b. 检查text字段是否存在口语重复
+ c. 若存在任意一种错误,将修改后的结果加入到输出列表中,格式为{"idx": int, "corrected": str}
+
+# 输出规范
+1. 输出格式为JSON数组,每项包含idx和corrected两个字段
+2. 仅输出存在错误的项,正确项不显示
+3. 语言保持与原始文本一致的口语化风格
+4. 错误项数量无限制,完整呈现所有识别到的错误
+
+示例输入:
+[
+ {"idx": 0, "corrected": "平果"},
+ {"idx": 1, "corrected": "这个这个方案"}
+]
+
+示例输出:
+[
+ {"idx": 0, "corrected": "苹果"},
+ {"idx": 1, "corrected": "这个方案"}
+]
+"""
+ if reference:
+ SYSTEM_PROMPT += f"\n<reference>{reference}</reference>"
+ if not inputs:
+ return inputs
+ # match [mm:ss] or [hh:mm:ss]
+ pattern = r"(\[(?:\d{2}:)?\d{2}:\d{2}\])\s*(.*)"
+ matches = re.findall(pattern, inputs)
+ texts = json.dumps([{"idx": idx, "text": item[1]} for idx, item in enumerate(matches)], ensure_ascii=False)
+ ai = await ai_text_generation(
+ "fake-client", # type: ignore
+ Message(id=rand_number(), chat=Chat(id=rand_number()), text=Str(f"{PREFIX.AI_TEXT_GENERATION} @{corrector_model} {texts}")),
+ openai_responses_config={
+ "instructions": SYSTEM_PROMPT,
+ "max_output_tokens": 65536,
+ "extra_body": {"thinking": {"type": "enabled"}},
+ "tools": [{"type": "web_search", "max_keyword": 5, "limit": 10}],
+ "max_tool_calls": 10,
+ "text": {
+ "format": {
+ "type": "json_schema",
+ "name": "ASRCorrection",
+ "strict": True,
+ "description": "A list of ASR correction",
+ "schema": JSON_SCHEMA,
+ }
+ },
+ },
+ gemini_generate_content_config={
+ "max_output_tokens": 65536,
+ "thinking_config": {"include_thoughts": True, "thinking_level": "high"},
+ "tools": [{"google_search": {}}, {"code_execution": {}}],
+ "system_instruction": SYSTEM_PROMPT,
+ "responseMimeType": "application/json",
+ "responseJsonSchema": JSON_SCHEMA,
+ },
+ silent=True,
+ )
+ with suppress(Exception):
+ for output in json.loads(ai["texts"]):
+ idx = output["idx"]
+ matches[idx] = (matches[idx][0], output["corrected"])
+ return "\n".join([f"{item[0]} {item[1]}" for item in matches])
+ return inputs
src/asr/voice_recognition.py
@@ -11,6 +11,7 @@ from pyrogram.types import Message
from asr.ali import ali_asr
from asr.cloudflare import cloudflare_asr
+from asr.corrector import asr_corrector
from asr.deepgram import deepgram_asr
from asr.gemini import gemini_asr
from asr.groq import groq_asr
@@ -191,6 +192,9 @@ async def asr_file(
prompt: str = "",
*,
tencent_language: str = "16k_zh-PY",
+ enable_corrector: bool = False,
+ corrector_model: str = "asr-corrector",
+ corrector_reference: str | None = None,
delete_local_file: bool = True,
delete_gemini_file: bool = True,
**kwargs,
@@ -201,8 +205,7 @@ async def asr_file(
return {"error": f"{path} is not exist"}
duration = audio_duration(path)
engine = auto_choose_asr_engine(duration=duration, engine=engine)
-
- log = f"{engine.capitalize()} ASR, 时长: {readable_time(duration)}\n{path.name}"
+ log = f"{engine.capitalize()} ASR, 时长: {readable_time(duration)} {path.name}"
logger.debug(log)
await modify_progress(message=kwargs.get("progress"), text=log, force_update=True)
res = {}
@@ -232,4 +235,6 @@ async def asr_file(
path.unlink(missing_ok=True)
elif path.is_file():
res["audio_file"] = path
+ if enable_corrector or corrector_reference:
+ res["texts"] = await asr_corrector(res["texts"], corrector_reference, corrector_model)
return res
src/podcast/asr.py
@@ -9,11 +9,13 @@ from glom import Coalesce, glom
from asr.utils import audio_duration
from asr.voice_recognition import asr_file
from config import PODCAST
+from messages.utils import remove_img_tag
from networking import match_social_media_link
+from podcast.utils import get_pubdate
from preview.bilibili import get_bilibili_vinfo
from preview.youtube import get_youtube_vinfo
from subtitles.base import fetch_subtitle
-from utils import rand_string, readable_time, strings_list
+from utils import convert_md, rand_string, readable_time, remove_consecutive_newlines, strings_list
async def get_transcripts(
@@ -27,8 +29,12 @@ async def get_transcripts(
If the link of this entry has embedded subtitles (YouTube, Bilibili links), use it directly.
Otherwise, generate the transcript via ASR.
"""
+ desc = convert_md(html=glom(entry, Coalesce("content.0.value", "summary"), default=""))
+ desc, _ = remove_img_tag(desc)
+ desc = remove_consecutive_newlines(desc, newline_level=2)
+ reference = f"本次转录稿为播客栏目《{feed_title}》的一期节目。\n该期节目标题: [{entry['title']}]({entry['link']})\n播出日期: {get_pubdate(entry):%Y-%m-%d}\n节目简介: {desc}"
if urlparse(entry["link"]).netloc in ["www.youtube.com", "www.bilibili.com"]: # get subtitle from API first
- res = await fetch_subtitle(entry["link"])
+ res = await fetch_subtitle(entry["link"], reference=reference)
if res.get("subtitles"):
return res["subtitles"]
@@ -37,10 +43,9 @@ async def get_transcripts(
# So we need to copy the file to another path before generating the transcript.
duration = await get_duration(audio_path, entry)
tmp_path = backup_audio(audio_path)
- desc = glom(entry, Coalesce("content.0.value", "summary"), default="")
prompt = f"请转录播客栏目《{feed_title}》的一期节目的音频。\n该期节目标题: {entry['title']}\n节目时长: {readable_time(duration)}\n节目简介: {desc}"
engine = get_asr_engine(feed_title, feed_url)
- asr_res = await asr_file(tmp_path, prompt=prompt, engine=engine, silent=True)
+ asr_res = await asr_file(tmp_path, prompt=prompt, engine=engine, corrector_reference=reference, silent=True)
Path(tmp_path).unlink(missing_ok=True)
return asr_res.get("texts", "")
src/subtitles/base.py
@@ -8,6 +8,7 @@ from pyrogram.types import Message
from youtube_transcript_api import IpBlocked, RequestBlocked, YouTubeTranscriptApi
from youtube_transcript_api.proxies import GenericProxyConfig
+from asr.corrector import asr_corrector
from config import PREFIX, PROXY, READING_SPEED, cache
from messages.parser import parse_msg
from messages.utils import startswith_prefix
@@ -49,7 +50,7 @@ async def match_url(client: Client, message: Message) -> str:
@cache.memoize(ttl=120)
-async def fetch_subtitle(url: str) -> dict:
+async def fetch_subtitle(url: str, reference: str = "") -> dict:
"""Fetch subtitles from Bilibili or YouTube.
Returns:
@@ -62,7 +63,11 @@ async def fetch_subtitle(url: str) -> dict:
subtitles = []
matched = await match_social_media_link(url)
if matched["platform"] == "bilibili":
- return await bilibili_subtitle_and_summary(url)
+ resp = await bilibili_subtitle_and_summary(url)
+ if resp.get("subtitles"):
+ resp["subtitles"] = await asr_corrector(resp["subtitles"], reference)
+ return resp
+
video_id = matched["vid"]
try:
proxy = GenericProxyConfig(http_url=PROXY.SUBTITLE, https_url=PROXY.SUBTITLE) if PROXY.SUBTITLE else None
@@ -76,7 +81,10 @@ async def fetch_subtitle(url: str) -> dict:
logger.error(f"Failed to get subtitle: {e}")
if not subtitles:
return {"error": "❌下载内嵌字幕失败\n🔄尝试使用语音转文字获取字幕"}
- return to_transcription(subtitles)
+ resp = to_transcription(subtitles)
+ if resp.get("subtitles"):
+ resp["subtitles"] = await asr_corrector(resp["subtitles"], reference)
+ return resp
def to_transcription(subtitles: list[dict]) -> dict:
src/subtitles/subtitle.py
@@ -84,7 +84,8 @@ async def get_subtitle(
this_info = parse_msg(message, silent=True)
reply_info = parse_msg(message.reply_to_message, silent=True) if message.reply_to_message else {}
# Fetch subtitle via API
- res = await fetch_subtitle(url)
+ reference = f"本次转录稿为{matched['platform'].title()}平台作者【{vinfo['author']}】的一期节目。\n该期节目标题: [{vinfo['title']}]({url})\n播出日期: {vinfo['pubdate']}\n节目简介: {description}"
+ res = await fetch_subtitle(url, reference)
if error := res.get("error", ""): # API failed
asr_engine = ASR.DEFAULT_ENGINE
if platform == "youtube": # bypass censorship
@@ -95,7 +96,7 @@ async def get_subtitle(
media_path = f"{DOWNLOAD_DIR}/{this_info['file_name'] or reply_info.get('file_name', '')}"
fpath: str = await client.download_media(msg, media_path) # type: ignore
prompt = f"请转录{matched['platform'].title()}视频作者【{vinfo['author']}】的一期节目的音频。\n该期节目标题: {vinfo['title']}\n节目简介: {description}"
- res = await asr_file(fpath, engine=asr_engine, prompt=prompt, message=message, silent=True, **kwargs)
+ res = await asr_file(fpath, engine=asr_engine, prompt=prompt, message=message, corrector_reference=reference, silent=True, **kwargs)
if res.get("error"):
await modify_progress(text=res["error"], force_update=True, **kwargs)
return
src/ytdlp/main.py
@@ -7,7 +7,6 @@ from typing import Literal
import markdown
from bs4 import BeautifulSoup, MarkupResemblesLocatorWarning
-from glom import Coalesce, glom
from loguru import logger
from pyrogram.client import Client
from pyrogram.types import Message
@@ -27,7 +26,7 @@ from preview.youtube import get_youtube_comments, get_youtube_vinfo
from publish import publish_telegraph
from utils import count_subtitles, rand_number, readable_size, readable_time, soup_to_text, to_int, true, ts_to_dt, unicode_to_ascii
from ytdlp.download import ytdlp_download
-from ytdlp.utils import append_subtitle, cleanup_ytdlp, get_subtitles, platform_emoji
+from ytdlp.utils import append_subtitle, cleanup_ytdlp, generate_prompt, get_subtitles, platform_emoji
async def preview_ytdlp(
@@ -135,13 +134,13 @@ async def preview_ytdlp(
if true(ytdlp_send_subtitle) or true(ytdlp_send_summary):
fpath = info["audio_path"] if info["audio_path"].is_file() else info["video_path"]
asr_engine = kwargs.get("asr_engine", "uncensored") if platform == "youtube" else ASR.DEFAULT_ENGINE
- if sub := await get_subtitles(fpath, url, asr_engine):
+ if sub := await get_subtitles(fpath, url, asr_engine, info):
subtitles = f"🔤<b>字幕:</b>\n{sub}"
# get ai summary
summary = ""
if subtitles and true(ytdlp_send_summary):
- prompt = generate_prompt(info)
+ prompt = generate_prompt(info, target="summary")
ai_msg = Message( # Construct a message for AI
id=rand_number(),
chat=message.chat,
@@ -265,23 +264,6 @@ async def generate_captions(
return results
-def generate_prompt(info: dict) -> str:
- """Generate prompt for AI summary."""
- prompt = f"以上是{info['extractor'].title()}视频"
- if author := info.get("author"):
- prompt += f"作者【{author}】"
- prompt += "的一期节目的文字稿。该期节目详情如下:\n"
- if title := info.get("title"):
- prompt += f"节目标题: {title}\n"
- if pubdate := glom(info, Coalesce("pubdate", "upload_date"), default=""):
- prompt += f"发布日期: {pubdate}\n"
-
- if desc := info.get("description"):
- prompt += f"节目简介: {desc}\n"
- prompt += "\n请解读本期节目内容。要求: 直接输出节目内容解读, 以“该节目讲述了”开头"
- return prompt
-
-
def get_target_chats(message: Message, video_target: str | int | None = None, audio_target: str | int | None = None, **kwargs) -> tuple[int | str, int | str]:
"""Get target chats of video and audio messages.
src/ytdlp/utils.py
@@ -5,7 +5,7 @@ from pathlib import Path
from typing import Literal
from urllib.parse import urlparse
-from glom import glom
+from glom import Coalesce, glom
from loguru import logger
from pyrogram.enums import ParseMode
from pyrogram.types import Message
@@ -190,15 +190,16 @@ def find_thumbnail(video_path: str | Path, audio_path: str | Path) -> str | None
return None
-async def get_subtitles(audio_path: str | Path, url: str, asr_engine: str) -> str:
+async def get_subtitles(audio_path: str | Path, url: str, asr_engine: str, vinfo: dict) -> str:
# send subtitles
subtitles = ""
matched = await match_social_media_link(url)
+ reference = generate_prompt(vinfo, "correction")
if matched["platform"] in ["bilibili", "youtube"]: # get subtitle from API first
- res = await fetch_subtitle(url=url)
+ res = await fetch_subtitle(url=url, reference=reference)
subtitles = res.get("subtitles", "") # only subtitles, no Bilibili's AI summary
if not subtitles:
- res = await asr_file(audio_path, asr_engine, silent=True)
+ res = await asr_file(audio_path, asr_engine, corrector_reference=reference, silent=True)
subtitles = res.get("texts", "")
if count_subtitles(subtitles) < 20:
subtitles = "" # ignore too short transcription
@@ -246,6 +247,22 @@ async def append_subtitle(name: str, sent_messages: dict) -> dict:
return modified
def generate_prompt(info: dict, target: Literal["summary", "correction"]) -> str:
    """Build the AI prompt for a content summary or an ASR-correction reference.

    Args:
        info: video metadata; expects "extractor", optionally "author", "title",
            "pubdate"/"upload_date" and "description".
        target: "summary" for an interpretation request, "correction" for the
            background header fed to the ASR corrector.

    Returns:
        str: the assembled prompt text.
    """
    platform = info["extractor"].title()
    parts = [f"以上是{platform}视频" if target == "summary" else f"本次转录稿为{platform}平台"]
    if author := info.get("author"):
        parts.append(f"作者【{author}】")
    parts.append("的一期节目的文字稿。该期节目详情如下:\n")
    # Append each metadata line only when its value is non-empty.
    details = (
        ("节目标题", info.get("title")),
        ("发布日期", glom(info, Coalesce("pubdate", "upload_date"), default="")),
        ("节目简介", info.get("description")),
    )
    parts.extend(f"{label}: {value}\n" for label, value in details if value)
    if target == "summary":
        parts.append("\n请解读本期节目内容。要求: 直接输出节目内容解读, 以“该节目讲述了”开头")
    return "".join(parts)
+
+
def cleanup_ytdlp(vid: str):
if not vid:
return