Commit 968a928

benny-dou <60535774+benny-dou@users.noreply.github.com>
2026-04-10 06:21:46
feat(asr): add ASR corrector
1 parent 55d10fd
Changed files (7)
src/asr/corrector.py
@@ -0,0 +1,129 @@
+import json
+import re
+from contextlib import suppress
+
+from pyrogram.types import Chat, Message
+from pyrogram.types.messages_and_media.message import Str
+
+from ai.main import ai_text_generation
+from config import PREFIX
+from utils import rand_number
+
+# ruff: noqa: RUF001
+JSON_SCHEMA = {
+    "title": "List of Correction",
+    "type": "array",
+    "items": {
+        "type": "object",
+        "title": "Correction",
+        "properties": {
+            "idx": {"description": "Index of the transcription item", "title": "Index", "type": "integer"},
+            "corrected": {"description": "Corrected text", "title": "Corrected", "type": "string"},
+        },
+        "required": ["idx", "corrected"],
+        "additionalProperties": False,
+    },
+}
+
+
+async def asr_corrector(inputs: str, reference: str | None = None, corrector_model: str = "asr-corrector") -> str:
+    """Correct ASR results.
+
+    Example:
+        [00:00] hello
+        [00:01] world
+
+    Args:
+        inputs (str): original ASR results.
+
+    Returns:
+        str: corrected ASR results.
+    """
+    SYSTEM_PROMPT = """# 身份与职责
+你是专注于ASR转录稿校对的专业助手,服务于需要精准文本转化的用户,核心职责是识别并修正转录稿中的特定错误类型,确保输出内容准确反映原始语音信息。
+
+# 校对规则
+## 必做事项
+1. 逐行检查提供的转录文稿中的每一项,识别两类错误:
+   - 转录错误:语音内容被错误转换(如“苹果”转成“平果”)
+   - 口语重复:无意义的重复表述(如“这个这个方案”)
+   - 标点错误:标点符号缺失或错误(如“是吗?”转成“是吗。”)
+2. 仅保留错误项,正确项不纳入输出
+
+## 约束条件
+1. 保留原始文本中的emoji表情
+2. 不处理除指定三类错误之外的其他错误(如逻辑错误)
+
+# 输入处理
+1. 优先读取用户提供的转录稿中的内容,格式为JSON数组,每个项包含idx和text两个字段
+2. 若提供<reference>{{reference}}</reference>,可作为错误判断的辅助参考(如专业术语、专有名词)
+3. 若输入为空或格式错误,输出空列表
+
+# 执行步骤
+1. 初始化空列表用于存储错误项
+2. 遍历转录稿中的每一项:
+   a. 检查text字段是否存在转录错误
+   b. 检查text字段是否存在口语重复
+   c. 若存在任意一种错误,将修改后的结果加入到输出列表中,格式为{"idx": int, "corrected": str}
+
+# 输出规范
+1. 输出格式为JSON数组,每项包含idx和corrected两个字段
+2. 仅输出存在错误的项,正确项不显示
+3. 语言保持与原始文本一致的口语化风格
+4. 错误项数量无限制,完整呈现所有识别到的错误
+
+示例输入:
+[
+  {"idx": 0, "corrected": "平果"},
+  {"idx": 1, "corrected": "这个这个方案"}
+]
+
+示例输出:
+[
+  {"idx": 0, "corrected": "苹果"},
+  {"idx": 1, "corrected": "这个方案"}
+]
+"""
+    if reference:
+        SYSTEM_PROMPT += f"\n<reference>{reference}</reference>"
+    if not inputs:
+        return inputs
+    # match [mm:ss] or [hh:mm:ss]
+    pattern = r"(\[(?:\d{2}:)?\d{2}:\d{2}\])\s*(.*)"
+    matches = re.findall(pattern, inputs)
+    texts = json.dumps([{"idx": idx, "text": item[1]} for idx, item in enumerate(matches)], ensure_ascii=False)
+    ai = await ai_text_generation(
+        "fake-client",  # type: ignore
+        Message(id=rand_number(), chat=Chat(id=rand_number()), text=Str(f"{PREFIX.AI_TEXT_GENERATION} @{corrector_model} {texts}")),
+        openai_responses_config={
+            "instructions": SYSTEM_PROMPT,
+            "max_output_tokens": 65536,
+            "extra_body": {"thinking": {"type": "enabled"}},
+            "tools": [{"type": "web_search", "max_keyword": 5, "limit": 10}],
+            "max_tool_calls": 10,
+            "text": {
+                "format": {
+                    "type": "json_schema",
+                    "name": "ASRCorrection",
+                    "strict": True,
+                    "description": "A list of ASR correction",
+                    "schema": JSON_SCHEMA,
+                }
+            },
+        },
+        gemini_generate_content_config={
+            "max_output_tokens": 65536,
+            "thinking_config": {"include_thoughts": True, "thinking_level": "high"},
+            "tools": [{"google_search": {}}, {"code_execution": {}}],
+            "system_instruction": SYSTEM_PROMPT,
+            "responseMimeType": "application/json",
+            "responseJsonSchema": JSON_SCHEMA,
+        },
+        silent=True,
+    )
+    with suppress(Exception):
+        for output in json.loads(ai["texts"]):
+            idx = output["idx"]
+            matches[idx] = (matches[idx][0], output["corrected"])
+        return "\n".join([f"{item[0]} {item[1]}" for item in matches])
+    return inputs
src/asr/voice_recognition.py
@@ -11,6 +11,7 @@ from pyrogram.types import Message
 
 from asr.ali import ali_asr
 from asr.cloudflare import cloudflare_asr
+from asr.corrector import asr_corrector
 from asr.deepgram import deepgram_asr
 from asr.gemini import gemini_asr
 from asr.groq import groq_asr
@@ -191,6 +192,9 @@ async def asr_file(
     prompt: str = "",
     *,
     tencent_language: str = "16k_zh-PY",
+    enable_corrector: bool = False,
+    corrector_model: str = "asr-corrector",
+    corrector_reference: str | None = None,
     delete_local_file: bool = True,
     delete_gemini_file: bool = True,
     **kwargs,
@@ -201,8 +205,7 @@ async def asr_file(
         return {"error": f"{path} is not exist"}
     duration = audio_duration(path)
     engine = auto_choose_asr_engine(duration=duration, engine=engine)
-
-    log = f"{engine.capitalize()} ASR, 时长: {readable_time(duration)}\n{path.name}"
+    log = f"{engine.capitalize()} ASR, 时长: {readable_time(duration)} {path.name}"
     logger.debug(log)
     await modify_progress(message=kwargs.get("progress"), text=log, force_update=True)
     res = {}
@@ -232,4 +235,6 @@ async def asr_file(
             path.unlink(missing_ok=True)
         elif path.is_file():
             res["audio_file"] = path
+    if enable_corrector or corrector_reference:
+        res["texts"] = await asr_corrector(res["texts"], corrector_reference, corrector_model)
     return res
src/podcast/asr.py
@@ -9,11 +9,13 @@ from glom import Coalesce, glom
 from asr.utils import audio_duration
 from asr.voice_recognition import asr_file
 from config import PODCAST
+from messages.utils import remove_img_tag
 from networking import match_social_media_link
+from podcast.utils import get_pubdate
 from preview.bilibili import get_bilibili_vinfo
 from preview.youtube import get_youtube_vinfo
 from subtitles.base import fetch_subtitle
-from utils import rand_string, readable_time, strings_list
+from utils import convert_md, rand_string, readable_time, remove_consecutive_newlines, strings_list
 
 
 async def get_transcripts(
@@ -27,8 +29,12 @@ async def get_transcripts(
     If the link of this entry has embedded subtitles (YouTube, Bilibili links), use it directly.
     Otherwise, generate the transcript via ASR.
     """
+    desc = convert_md(html=glom(entry, Coalesce("content.0.value", "summary"), default=""))
+    desc, _ = remove_img_tag(desc)
+    desc = remove_consecutive_newlines(desc, newline_level=2)
+    reference = f"本次转录稿为播客栏目《{feed_title}》的一期节目。\n该期节目标题: [{entry['title']}]({entry['link']})\n播出日期: {get_pubdate(entry):%Y-%m-%d}\n节目简介: {desc}"
     if urlparse(entry["link"]).netloc in ["www.youtube.com", "www.bilibili.com"]:  # get subtitle from API first
-        res = await fetch_subtitle(entry["link"])
+        res = await fetch_subtitle(entry["link"], reference=reference)
         if res.get("subtitles"):
             return res["subtitles"]
 
@@ -37,10 +43,9 @@ async def get_transcripts(
     # So we need to copy the file to another path before generating the transcript.
     duration = await get_duration(audio_path, entry)
     tmp_path = backup_audio(audio_path)
-    desc = glom(entry, Coalesce("content.0.value", "summary"), default="")
     prompt = f"请转录播客栏目《{feed_title}》的一期节目的音频。\n该期节目标题: {entry['title']}\n节目时长: {readable_time(duration)}\n节目简介: {desc}"
     engine = get_asr_engine(feed_title, feed_url)
-    asr_res = await asr_file(tmp_path, prompt=prompt, engine=engine, silent=True)
+    asr_res = await asr_file(tmp_path, prompt=prompt, engine=engine, corrector_reference=reference, silent=True)
     Path(tmp_path).unlink(missing_ok=True)
     return asr_res.get("texts", "")
 
src/subtitles/base.py
@@ -8,6 +8,7 @@ from pyrogram.types import Message
 from youtube_transcript_api import IpBlocked, RequestBlocked, YouTubeTranscriptApi
 from youtube_transcript_api.proxies import GenericProxyConfig
 
+from asr.corrector import asr_corrector
 from config import PREFIX, PROXY, READING_SPEED, cache
 from messages.parser import parse_msg
 from messages.utils import startswith_prefix
@@ -49,7 +50,7 @@ async def match_url(client: Client, message: Message) -> str:
 
 
 @cache.memoize(ttl=120)
-async def fetch_subtitle(url: str) -> dict:
+async def fetch_subtitle(url: str, reference: str = "") -> dict:
     """Fetch subtitles from Bilibili or YouTube.
 
     Returns:
@@ -62,7 +63,11 @@ async def fetch_subtitle(url: str) -> dict:
     subtitles = []
     matched = await match_social_media_link(url)
     if matched["platform"] == "bilibili":
-        return await bilibili_subtitle_and_summary(url)
+        resp = await bilibili_subtitle_and_summary(url)
+        if resp.get("subtitles"):
+            resp["subtitles"] = await asr_corrector(resp["subtitles"], reference)
+        return resp
+
     video_id = matched["vid"]
     try:
         proxy = GenericProxyConfig(http_url=PROXY.SUBTITLE, https_url=PROXY.SUBTITLE) if PROXY.SUBTITLE else None
@@ -76,7 +81,10 @@ async def fetch_subtitle(url: str) -> dict:
         logger.error(f"Failed to get subtitle: {e}")
     if not subtitles:
         return {"error": "❌下载内嵌字幕失败\n🔄尝试使用语音转文字获取字幕"}
-    return to_transcription(subtitles)
+    resp = to_transcription(subtitles)
+    if resp.get("subtitles"):
+        resp["subtitles"] = await asr_corrector(resp["subtitles"], reference)
+    return resp
 
 
 def to_transcription(subtitles: list[dict]) -> dict:
src/subtitles/subtitle.py
@@ -84,7 +84,8 @@ async def get_subtitle(
     this_info = parse_msg(message, silent=True)
     reply_info = parse_msg(message.reply_to_message, silent=True) if message.reply_to_message else {}
     # Fetch subtitle via API
-    res = await fetch_subtitle(url)
+    reference = f"本次转录稿为{matched['platform'].title()}平台作者【{vinfo['author']}】的一期节目。\n该期节目标题: [{vinfo['title']}]({url})\n播出日期: {vinfo['pubdate']}\n节目简介: {description}"
+    res = await fetch_subtitle(url, reference)
     if error := res.get("error", ""):  # API failed
         asr_engine = ASR.DEFAULT_ENGINE
         if platform == "youtube":  # bypass censorship
@@ -95,7 +96,7 @@ async def get_subtitle(
             media_path = f"{DOWNLOAD_DIR}/{this_info['file_name'] or reply_info.get('file_name', '')}"
             fpath: str = await client.download_media(msg, media_path)  # type: ignore
             prompt = f"请转录{matched['platform'].title()}视频作者【{vinfo['author']}】的一期节目的音频。\n该期节目标题: {vinfo['title']}\n节目简介: {description}"
-            res = await asr_file(fpath, engine=asr_engine, prompt=prompt, message=message, silent=True, **kwargs)
+            res = await asr_file(fpath, engine=asr_engine, prompt=prompt, message=message, corrector_reference=reference, silent=True, **kwargs)
             if res.get("error"):
                 await modify_progress(text=res["error"], force_update=True, **kwargs)
                 return
src/ytdlp/main.py
@@ -7,7 +7,6 @@ from typing import Literal
 
 import markdown
 from bs4 import BeautifulSoup, MarkupResemblesLocatorWarning
-from glom import Coalesce, glom
 from loguru import logger
 from pyrogram.client import Client
 from pyrogram.types import Message
@@ -27,7 +26,7 @@ from preview.youtube import get_youtube_comments, get_youtube_vinfo
 from publish import publish_telegraph
 from utils import count_subtitles, rand_number, readable_size, readable_time, soup_to_text, to_int, true, ts_to_dt, unicode_to_ascii
 from ytdlp.download import ytdlp_download
-from ytdlp.utils import append_subtitle, cleanup_ytdlp, get_subtitles, platform_emoji
+from ytdlp.utils import append_subtitle, cleanup_ytdlp, generate_prompt, get_subtitles, platform_emoji
 
 
 async def preview_ytdlp(
@@ -135,13 +134,13 @@ async def preview_ytdlp(
     if true(ytdlp_send_subtitle) or true(ytdlp_send_summary):
         fpath = info["audio_path"] if info["audio_path"].is_file() else info["video_path"]
         asr_engine = kwargs.get("asr_engine", "uncensored") if platform == "youtube" else ASR.DEFAULT_ENGINE
-        if sub := await get_subtitles(fpath, url, asr_engine):
+        if sub := await get_subtitles(fpath, url, asr_engine, info):
             subtitles = f"🔤<b>字幕:</b>\n{sub}"
 
     # get ai summary
     summary = ""
     if subtitles and true(ytdlp_send_summary):
-        prompt = generate_prompt(info)
+        prompt = generate_prompt(info, target="summary")
         ai_msg = Message(  # Construct a message for AI
             id=rand_number(),
             chat=message.chat,
@@ -265,23 +264,6 @@ async def generate_captions(
     return results
 
 
-def generate_prompt(info: dict) -> str:
-    """Generate prompt for AI summary."""
-    prompt = f"以上是{info['extractor'].title()}视频"
-    if author := info.get("author"):
-        prompt += f"作者【{author}】"
-    prompt += "的一期节目的文字稿。该期节目详情如下:\n"
-    if title := info.get("title"):
-        prompt += f"节目标题: {title}\n"
-    if pubdate := glom(info, Coalesce("pubdate", "upload_date"), default=""):
-        prompt += f"发布日期: {pubdate}\n"
-
-    if desc := info.get("description"):
-        prompt += f"节目简介: {desc}\n"
-    prompt += "\n请解读本期节目内容。要求: 直接输出节目内容解读, 以“该节目讲述了”开头"
-    return prompt
-
-
 def get_target_chats(message: Message, video_target: str | int | None = None, audio_target: str | int | None = None, **kwargs) -> tuple[int | str, int | str]:
     """Get target chats of video and audio messages.
 
src/ytdlp/utils.py
@@ -5,7 +5,7 @@ from pathlib import Path
 from typing import Literal
 from urllib.parse import urlparse
 
-from glom import glom
+from glom import Coalesce, glom
 from loguru import logger
 from pyrogram.enums import ParseMode
 from pyrogram.types import Message
@@ -190,15 +190,16 @@ def find_thumbnail(video_path: str | Path, audio_path: str | Path) -> str | None
     return None
 
 
-async def get_subtitles(audio_path: str | Path, url: str, asr_engine: str) -> str:
+async def get_subtitles(audio_path: str | Path, url: str, asr_engine: str, vinfo: dict) -> str:
     # send subtitles
     subtitles = ""
     matched = await match_social_media_link(url)
+    reference = generate_prompt(vinfo, "correction")
     if matched["platform"] in ["bilibili", "youtube"]:  # get subtitle from API first
-        res = await fetch_subtitle(url=url)
+        res = await fetch_subtitle(url=url, reference=reference)
         subtitles = res.get("subtitles", "")  # only subtitles, no Bilibili's AI summary
     if not subtitles:
-        res = await asr_file(audio_path, asr_engine, silent=True)
+        res = await asr_file(audio_path, asr_engine, corrector_reference=reference, silent=True)
         subtitles = res.get("texts", "")
         if count_subtitles(subtitles) < 20:
             subtitles = ""  # ignore too  short transcription
@@ -246,6 +247,22 @@ async def append_subtitle(name: str, sent_messages: dict) -> dict:
     return modified
 
 
+def generate_prompt(info: dict, target: Literal["summary", "correction"]) -> str:
+    """Generate prompt for AI summary or correction."""
+    prompt = f"以上是{info['extractor'].title()}视频" if target == "summary" else f"本次转录稿为{info['extractor'].title()}平台"
+    if author := info.get("author"):
+        prompt += f"作者【{author}】"
+    prompt += "的一期节目的文字稿。该期节目详情如下:\n"
+    if title := info.get("title"):
+        prompt += f"节目标题: {title}\n"
+    if pubdate := glom(info, Coalesce("pubdate", "upload_date"), default=""):
+        prompt += f"发布日期: {pubdate}\n"
+    if desc := info.get("description"):
+        prompt += f"节目简介: {desc}\n"
+    prompt += "\n请解读本期节目内容。要求: 直接输出节目内容解读, 以“该节目讲述了”开头" if target == "summary" else ""
+    return prompt
+
+
 def cleanup_ytdlp(vid: str):
     if not vid:
         return