main
1import json
2import re
3from contextlib import suppress
4
5from pyrogram.types import Chat, Message
6from pyrogram.types.messages_and_media.message import Str
7
8from ai.main import ai_text_generation
9from config import PREFIX
10from utils import rand_number
11
12# ruff: noqa: RUF001
# Field definitions for one correction entry: the index of the transcript
# segment being corrected and its replacement text.
_CORRECTION_PROPERTIES = {
    "idx": {
        "description": "Index of the transcription item",
        "title": "Index",
        "type": "integer",
    },
    "corrected": {
        "description": "Corrected text",
        "title": "Corrected",
        "type": "string",
    },
}

# JSON Schema for the model's structured output: an array of
# {"idx": int, "corrected": str} objects with no additional keys.
# Every declared property is listed as required.
# NOTE(review): OpenAI strict structured outputs document that the root schema
# must be an object; an array root may be rejected by that backend. Confirm
# which provider ai_text_generation routes this to before relying on it.
JSON_SCHEMA = {
    "title": "List of Correction",
    "type": "array",
    "items": {
        "type": "object",
        "title": "Correction",
        "properties": _CORRECTION_PROPERTIES,
        "required": list(_CORRECTION_PROPERTIES),
        "additionalProperties": False,
    },
}
27
28
29async def asr_corrector(inputs: str, reference: str | None = None, corrector_model: str = "asr-corrector") -> str:
30 """Correct ASR results.
31
32 Example:
33 [00:00] hello
34 [00:01] world
35
36 Args:
37 inputs (str): original ASR results.
38
39 Returns:
40 str: corrected ASR results.
41 """
42 SYSTEM_PROMPT = """# 身份与职责
43你是专注于ASR转录稿校对的专业助手,服务于需要精准文本转化的用户,核心职责是识别并修正转录稿中的特定错误类型,确保输出内容准确反映原始语音信息。
44
45# 校对规则
46## 必做事项
471. 逐行检查提供的转录文稿中的每一项,识别两类错误:
48 - 转录错误:语音内容被错误转换(如“苹果”转成“平果”)
49 - 口语重复:无意义的重复表述(如“这个这个方案”)
50 - 标点错误:标点符号缺失或错误(如“是吗?”转成“是吗。”)
512. 仅保留错误项,正确项不纳入输出
52
53## 约束条件
541. 保留原始文本中的emoji表情
552. 不处理除指定三类错误之外的其他错误(如逻辑错误)
56
57# 输入处理
581. 优先读取用户提供的转录稿中的内容,格式为JSON数组,每个项包含idx和text两个字段
592. 若提供<reference>{{reference}}</reference>,可作为错误判断的辅助参考(如专业术语、专有名词)
603. 若输入为空或格式错误,输出空列表
61
62# 执行步骤
631. 初始化空列表用于存储错误项
642. 遍历转录稿中的每一项:
65 a. 检查text字段是否存在转录错误
66 b. 检查text字段是否存在口语重复
67 c. 若存在任意一种错误,将修改后的结果加入到输出列表中,格式为{"idx": int, "corrected": str}
68
69# 输出规范
701. 输出格式为JSON数组,每项包含idx和corrected两个字段
712. 仅输出存在错误的项,正确项不显示
723. 语言保持与原始文本一致的口语化风格
734. 错误项数量无限制,完整呈现所有识别到的错误
74
75示例输入:
76[
77 {"idx": 0, "corrected": "平果"},
78 {"idx": 1, "corrected": "这个这个方案"}
79]
80
81示例输出:
82[
83 {"idx": 0, "corrected": "苹果"},
84 {"idx": 1, "corrected": "这个方案"}
85]
86"""
87 if reference:
88 SYSTEM_PROMPT += f"\n<reference>{reference}</reference>"
89 if not inputs:
90 return inputs
91 # match [mm:ss] or [hh:mm:ss]
92 pattern = r"(\[(?:\d{2}:)?\d{2}:\d{2}\])\s*(.*)"
93 matches = re.findall(pattern, inputs)
94 texts = json.dumps([{"idx": idx, "text": item[1]} for idx, item in enumerate(matches)], ensure_ascii=False)
95 ai = await ai_text_generation(
96 "fake-client", # type: ignore
97 Message(id=rand_number(), chat=Chat(id=rand_number()), text=Str(f"{PREFIX.AI_TEXT_GENERATION} @{corrector_model} {texts}")),
98 openai_responses_config={
99 "instructions": SYSTEM_PROMPT,
100 "text": {
101 "format": {
102 "type": "json_schema",
103 "name": "ASRCorrection",
104 "strict": True,
105 "description": "A list of ASR correction",
106 "schema": JSON_SCHEMA,
107 }
108 },
109 },
110 gemini_generate_content_config={
111 "system_instruction": SYSTEM_PROMPT,
112 "responseMimeType": "application/json",
113 "responseJsonSchema": JSON_SCHEMA,
114 },
115 openai_append_tool_results=False,
116 gemini_append_grounding=False,
117 cache_response_ttl=0,
118 silent=True,
119 )
120 with suppress(Exception):
121 for output in json.loads(ai["texts"]):
122 idx = output["idx"]
123 matches[idx] = (matches[idx][0], output["corrected"])
124 return "\n".join([f"{item[0]} {item[1]}" for item in matches])
125 return inputs