main
  1import json
  2import re
  3from contextlib import suppress
  4
  5from pyrogram.types import Chat, Message
  6from pyrogram.types.messages_and_media.message import Str
  7
  8from ai.main import ai_text_generation
  9from config import PREFIX
 10from schema import TranscriptionCorrection, get_schema
 11from utils import rand_number
 12
 13
 14async def asr_corrector(inputs: str, reference: str | None = None, corrector_model: str = "asr-corrector") -> str:
 15    """Correct ASR results.
 16
 17    Example:
 18        [00:00] hello
 19        [00:01] world
 20
 21    Args:
 22        inputs (str): original ASR results.
 23
 24    Returns:
 25        str: corrected ASR results.
 26    """
 27    SYSTEM_PROMPT = """# 身份与职责
 28你是专注于ASR转录稿校对的专业助手,服务于需要精准文本转化的用户,核心职责是识别并修正转录稿中的特定错误类型,确保输出内容准确反映原始语音信息。
 29
 30# 校对规则
 31## 必做事项
 321. 逐行检查提供的转录文稿中的每一项,识别两类错误:
 33   - 转录错误:语音内容被错误转换(如“苹果”转成“平果”)
 34   - 口语重复:无意义的重复表述(如“这个这个方案”)
 35   - 标点错误:标点符号缺失或错误(如“是吗?”转成“是吗。”)
 362. 仅保留错误项,正确项不纳入输出
 37
 38## 约束条件
 391. 保留原始文本中的emoji表情
 402. 不处理除指定三类错误之外的其他错误(如逻辑错误)
 41
 42# 输入处理
 431. 优先读取用户提供的转录稿中的内容,格式为JSON数组,每个项包含idx和text两个字段
 442. 若提供<reference>{{reference}}</reference>,可作为错误判断的辅助参考(如专业术语、专有名词)
 45
 46# 执行步骤
 471. 初始化空列表用于存储错误项
 482. 遍历转录稿中的每一项:
 49   a. 检查text字段是否存在转录错误
 50   b. 检查text字段是否存在口语重复
 51   c. 若存在任意一种错误,将修改后的结果加入到输出列表中,格式为{"idx": int, "corrected": str}
 52
 53# 输出规范
 541. 输出格式为JSON对象,格式为{"corrections": []},其中corrections是一个数组,数组每项包含idx和corrected两个字段
 552. 若输入为空或格式错误,输出corrections为空数组
 563. 仅输出存在错误的项,正确项不显示
 574. 语言保持与原始文本一致的口语化风格
 585. 错误项数量无限制,完整呈现所有识别到的错误
 59
 60示例输入:
 61[
 62  {"idx": 0, "corrected": "平果"},
 63  {"idx": 1, "corrected": "这个这个方案"}
 64]
 65
 66示例输出:
 67{"corrections": [
 68  {"idx": 0, "corrected": "苹果"},
 69  {"idx": 1, "corrected": "这个方案"}
 70]}
 71
 72"""
 73    if reference:
 74        SYSTEM_PROMPT += f"\n<reference>{reference}</reference>"
 75    if not inputs:
 76        return inputs
 77    # match [mm:ss] or [hh:mm:ss]
 78    pattern = r"(\[(?:\d{2}:)?\d{2}:\d{2}\])\s*(.*)"
 79    matches = re.findall(pattern, inputs)
 80    texts = json.dumps([{"idx": idx, "text": item[1]} for idx, item in enumerate(matches)], ensure_ascii=False)
 81    ai = await ai_text_generation(
 82        "fake-client",  # type: ignore
 83        Message(id=rand_number(), chat=Chat(id=rand_number()), text=Str(f"{PREFIX.AI_TEXT_GENERATION} @{corrector_model} {texts}")),
 84        openai_responses_config={
 85            "instructions": SYSTEM_PROMPT,
 86            "text": {
 87                "format": {
 88                    "type": "json_schema",
 89                    "name": "TranscriptionCorrection",
 90                    "strict": True,
 91                    "schema": get_schema("transcription_correction"),
 92                }
 93            },
 94        },
 95        gemini_generate_content_config={
 96            "system_instruction": SYSTEM_PROMPT,
 97            "responseMimeType": "application/json",
 98            "responseJsonSchema": get_schema("transcription_correction"),
 99        },
100        openai_system_prompt=SYSTEM_PROMPT,
101        openai_completions_config={
102            "response_format": {
103                "type": "json_schema",
104                "strict": True,
105                "json_schema": {
106                    "name": "TranscriptionCorrection",
107                    "schema": get_schema("transcription_correction"),
108                    "strict": True,
109                },
110            }
111        },
112        openai_enable_tool_call=False,
113        openai_append_tool_results=False,
114        gemini_append_grounding=False,
115        cache_response_ttl=0,
116        silent=True,
117    )
118    with suppress(Exception):
119        corrections = TranscriptionCorrection.model_validate_json(ai["texts"])
120        for output in corrections.corrections:
121            idx = output.idx
122            matches[idx] = (matches[idx][0], output.corrected)
123        return "\n".join([f"{item[0]} {item[1]}" for item in matches])
124    return inputs