main
  1import json
  2import re
  3from contextlib import suppress
  4
  5from pyrogram.types import Chat, Message
  6from pyrogram.types.messages_and_media.message import Str
  7
  8from ai.main import ai_text_generation
  9from config import PREFIX
 10from utils import rand_number
 11
 12# ruff: noqa: RUF001
 13JSON_SCHEMA = {
 14    "title": "List of Correction",
 15    "type": "array",
 16    "items": {
 17        "type": "object",
 18        "title": "Correction",
 19        "properties": {
 20            "idx": {"description": "Index of the transcription item", "title": "Index", "type": "integer"},
 21            "corrected": {"description": "Corrected text", "title": "Corrected", "type": "string"},
 22        },
 23        "required": ["idx", "corrected"],
 24        "additionalProperties": False,
 25    },
 26}
 27
 28
 29async def asr_corrector(inputs: str, reference: str | None = None, corrector_model: str = "asr-corrector") -> str:
 30    """Correct ASR results.
 31
 32    Example:
 33        [00:00] hello
 34        [00:01] world
 35
 36    Args:
 37        inputs (str): original ASR results.
 38
 39    Returns:
 40        str: corrected ASR results.
 41    """
 42    SYSTEM_PROMPT = """# 身份与职责
 43你是专注于ASR转录稿校对的专业助手,服务于需要精准文本转化的用户,核心职责是识别并修正转录稿中的特定错误类型,确保输出内容准确反映原始语音信息。
 44
 45# 校对规则
 46## 必做事项
 471. 逐行检查提供的转录文稿中的每一项,识别两类错误:
 48   - 转录错误:语音内容被错误转换(如“苹果”转成“平果”)
 49   - 口语重复:无意义的重复表述(如“这个这个方案”)
 50   - 标点错误:标点符号缺失或错误(如“是吗?”转成“是吗。”)
 512. 仅保留错误项,正确项不纳入输出
 52
 53## 约束条件
 541. 保留原始文本中的emoji表情
 552. 不处理除指定三类错误之外的其他错误(如逻辑错误)
 56
 57# 输入处理
 581. 优先读取用户提供的转录稿中的内容,格式为JSON数组,每个项包含idx和text两个字段
 592. 若提供<reference>{{reference}}</reference>,可作为错误判断的辅助参考(如专业术语、专有名词)
 603. 若输入为空或格式错误,输出空列表
 61
 62# 执行步骤
 631. 初始化空列表用于存储错误项
 642. 遍历转录稿中的每一项:
 65   a. 检查text字段是否存在转录错误
 66   b. 检查text字段是否存在口语重复
 67   c. 若存在任意一种错误,将修改后的结果加入到输出列表中,格式为{"idx": int, "corrected": str}
 68
 69# 输出规范
 701. 输出格式为JSON数组,每项包含idx和corrected两个字段
 712. 仅输出存在错误的项,正确项不显示
 723. 语言保持与原始文本一致的口语化风格
 734. 错误项数量无限制,完整呈现所有识别到的错误
 74
 75示例输入:
 76[
 77  {"idx": 0, "corrected": "平果"},
 78  {"idx": 1, "corrected": "这个这个方案"}
 79]
 80
 81示例输出:
 82[
 83  {"idx": 0, "corrected": "苹果"},
 84  {"idx": 1, "corrected": "这个方案"}
 85]
 86"""
 87    if reference:
 88        SYSTEM_PROMPT += f"\n<reference>{reference}</reference>"
 89    if not inputs:
 90        return inputs
 91    # match [mm:ss] or [hh:mm:ss]
 92    pattern = r"(\[(?:\d{2}:)?\d{2}:\d{2}\])\s*(.*)"
 93    matches = re.findall(pattern, inputs)
 94    texts = json.dumps([{"idx": idx, "text": item[1]} for idx, item in enumerate(matches)], ensure_ascii=False)
 95    ai = await ai_text_generation(
 96        "fake-client",  # type: ignore
 97        Message(id=rand_number(), chat=Chat(id=rand_number()), text=Str(f"{PREFIX.AI_TEXT_GENERATION} @{corrector_model} {texts}")),
 98        openai_responses_config={
 99            "instructions": SYSTEM_PROMPT,
100            "text": {
101                "format": {
102                    "type": "json_schema",
103                    "name": "ASRCorrection",
104                    "strict": True,
105                    "description": "A list of ASR correction",
106                    "schema": JSON_SCHEMA,
107                }
108            },
109        },
110        gemini_generate_content_config={
111            "system_instruction": SYSTEM_PROMPT,
112            "responseMimeType": "application/json",
113            "responseJsonSchema": JSON_SCHEMA,
114        },
115        openai_append_tool_results=False,
116        gemini_append_grounding=False,
117        cache_response_ttl=0,
118        silent=True,
119    )
120    with suppress(Exception):
121        for output in json.loads(ai["texts"]):
122            idx = output["idx"]
123            matches[idx] = (matches[idx][0], output["corrected"])
124        return "\n".join([f"{item[0]} {item[1]}" for item in matches])
125    return inputs