main
1import json
2import re
3from contextlib import suppress
4
5from pyrogram.types import Chat, Message
6from pyrogram.types.messages_and_media.message import Str
7
8from ai.main import ai_text_generation
9from config import PREFIX
10from schema import TranscriptionCorrection, get_schema
11from utils import rand_number
12
13
14async def asr_corrector(inputs: str, reference: str | None = None, corrector_model: str = "asr-corrector") -> str:
15 """Correct ASR results.
16
17 Example:
18 [00:00] hello
19 [00:01] world
20
21 Args:
22 inputs (str): original ASR results.
23
24 Returns:
25 str: corrected ASR results.
26 """
27 SYSTEM_PROMPT = """# 身份与职责
28你是专注于ASR转录稿校对的专业助手,服务于需要精准文本转化的用户,核心职责是识别并修正转录稿中的特定错误类型,确保输出内容准确反映原始语音信息。
29
30# 校对规则
31## 必做事项
321. 逐行检查提供的转录文稿中的每一项,识别两类错误:
33 - 转录错误:语音内容被错误转换(如“苹果”转成“平果”)
34 - 口语重复:无意义的重复表述(如“这个这个方案”)
35 - 标点错误:标点符号缺失或错误(如“是吗?”转成“是吗。”)
362. 仅保留错误项,正确项不纳入输出
37
38## 约束条件
391. 保留原始文本中的emoji表情
402. 不处理除指定三类错误之外的其他错误(如逻辑错误)
41
42# 输入处理
431. 优先读取用户提供的转录稿中的内容,格式为JSON数组,每个项包含idx和text两个字段
442. 若提供<reference>{{reference}}</reference>,可作为错误判断的辅助参考(如专业术语、专有名词)
45
46# 执行步骤
471. 初始化空列表用于存储错误项
482. 遍历转录稿中的每一项:
49 a. 检查text字段是否存在转录错误
50 b. 检查text字段是否存在口语重复
51 c. 若存在任意一种错误,将修改后的结果加入到输出列表中,格式为{"idx": int, "corrected": str}
52
53# 输出规范
541. 输出格式为JSON对象,格式为{"corrections": []},其中corrections是一个数组,数组每项包含idx和corrected两个字段
552. 若输入为空或格式错误,输出corrections为空数组
563. 仅输出存在错误的项,正确项不显示
574. 语言保持与原始文本一致的口语化风格
585. 错误项数量无限制,完整呈现所有识别到的错误
59
60示例输入:
61[
62 {"idx": 0, "corrected": "平果"},
63 {"idx": 1, "corrected": "这个这个方案"}
64]
65
66示例输出:
67{"corrections": [
68 {"idx": 0, "corrected": "苹果"},
69 {"idx": 1, "corrected": "这个方案"}
70]}
71
72"""
73 if reference:
74 SYSTEM_PROMPT += f"\n<reference>{reference}</reference>"
75 if not inputs:
76 return inputs
77 # match [mm:ss] or [hh:mm:ss]
78 pattern = r"(\[(?:\d{2}:)?\d{2}:\d{2}\])\s*(.*)"
79 matches = re.findall(pattern, inputs)
80 texts = json.dumps([{"idx": idx, "text": item[1]} for idx, item in enumerate(matches)], ensure_ascii=False)
81 ai = await ai_text_generation(
82 "fake-client", # type: ignore
83 Message(id=rand_number(), chat=Chat(id=rand_number()), text=Str(f"{PREFIX.AI_TEXT_GENERATION} @{corrector_model} {texts}")),
84 openai_responses_config={
85 "instructions": SYSTEM_PROMPT,
86 "text": {
87 "format": {
88 "type": "json_schema",
89 "name": "TranscriptionCorrection",
90 "strict": True,
91 "schema": get_schema("transcription_correction"),
92 }
93 },
94 },
95 gemini_generate_content_config={
96 "system_instruction": SYSTEM_PROMPT,
97 "responseMimeType": "application/json",
98 "responseJsonSchema": get_schema("transcription_correction"),
99 },
100 openai_system_prompt=SYSTEM_PROMPT,
101 openai_completions_config={
102 "response_format": {
103 "type": "json_schema",
104 "strict": True,
105 "json_schema": {
106 "name": "TranscriptionCorrection",
107 "schema": get_schema("transcription_correction"),
108 "strict": True,
109 },
110 }
111 },
112 openai_enable_tool_call=False,
113 openai_append_tool_results=False,
114 gemini_append_grounding=False,
115 cache_response_ttl=0,
116 silent=True,
117 )
118 with suppress(Exception):
119 corrections = TranscriptionCorrection.model_validate_json(ai["texts"])
120 for output in corrections.corrections:
121 idx = output.idx
122 matches[idx] = (matches[idx][0], output.corrected)
123 return "\n".join([f"{item[0]} {item[1]}" for item in matches])
124 return inputs