Commit de190f9
Changed files (2)
src
src/asr/corrector.py
@@ -7,33 +7,9 @@ from pyrogram.types.messages_and_media.message import Str
from ai.main import ai_text_generation
from config import PREFIX
+from schema import TranscriptionCorrection, get_schema
from utils import rand_number
-# ruff: noqa: RUF001
-JSON_SCHEMA = {
- "title": "List of Correction",
- "type": "array",
- "items": {
- "type": "object",
- "title": "Correction",
- "description": "A list of transcription correction, only include the correction that is not correct.",
- "properties": {
- "idx": {"description": "Index of the transcription item", "title": "Index", "type": "integer"},
- "corrected": {"description": "Corrected text", "title": "Corrected", "type": "string"},
- },
- "required": ["idx", "corrected"],
- "additionalProperties": False,
- },
-}
-
-JSON_SCHEMA_CHAT = { # OpenAI Chat Completions API requires that the root JSON Schema has type: "object".
- "title": "Transcription Correction",
- "type": "object",
- "required": ["corrections"],
- "additionalProperties": False,
- "properties": {"corrections": JSON_SCHEMA},
-}
-
async def asr_corrector(inputs: str, reference: str | None = None, corrector_model: str = "asr-corrector") -> str:
"""Correct ASR results.
@@ -66,7 +42,6 @@ async def asr_corrector(inputs: str, reference: str | None = None, corrector_mod
# 输入处理
1. 优先读取用户提供的转录稿中的内容,格式为JSON数组,每个项包含idx和text两个字段
2. 若提供<reference>{{reference}}</reference>,可作为错误判断的辅助参考(如专业术语、专有名词)
-3. 若输入为空或格式错误,输出空列表
# 执行步骤
1. 初始化空列表用于存储错误项
@@ -76,10 +51,11 @@ async def asr_corrector(inputs: str, reference: str | None = None, corrector_mod
c. 若存在任意一种错误,将修改后的结果加入到输出列表中,格式为{"idx": int, "corrected": str}
# 输出规范
-1. 输出格式为JSON数组,每项包含idx和corrected两个字段
-2. 仅输出存在错误的项,正确项不显示
-3. 语言保持与原始文本一致的口语化风格
-4. 错误项数量无限制,完整呈现所有识别到的错误
+1. 输出格式为JSON对象,格式为{"corrections": []},其中corrections是一个数组,数组每项包含idx和corrected两个字段
+2. 若输入为空或格式错误,输出corrections为空数组
+3. 仅输出存在错误的项,正确项不显示
+4. 语言保持与原始文本一致的口语化风格
+5. 错误项数量无限制,完整呈现所有识别到的错误
示例输入:
[
@@ -88,10 +64,11 @@ async def asr_corrector(inputs: str, reference: str | None = None, corrector_mod
]
示例输出:
-[
+{"corrections": [
{"idx": 0, "corrected": "苹果"},
{"idx": 1, "corrected": "这个方案"}
-]
+]}
+
"""
if reference:
SYSTEM_PROMPT += f"\n<reference>{reference}</reference>"
@@ -111,15 +88,14 @@ async def asr_corrector(inputs: str, reference: str | None = None, corrector_mod
"type": "json_schema",
"name": "TranscriptionCorrection",
"strict": True,
- "description": "A list of transcription correction",
- "schema": JSON_SCHEMA,
+ "schema": get_schema("transcription_correction"),
}
},
},
gemini_generate_content_config={
"system_instruction": SYSTEM_PROMPT,
"responseMimeType": "application/json",
- "responseJsonSchema": JSON_SCHEMA,
+ "responseJsonSchema": get_schema("transcription_correction"),
},
openai_system_prompt=SYSTEM_PROMPT,
openai_completions_config={
@@ -128,7 +104,7 @@ async def asr_corrector(inputs: str, reference: str | None = None, corrector_mod
"strict": True,
"json_schema": {
"name": "TranscriptionCorrection",
- "schema": JSON_SCHEMA_CHAT,
+ "schema": get_schema("transcription_correction"),
"strict": True,
},
}
@@ -140,11 +116,9 @@ async def asr_corrector(inputs: str, reference: str | None = None, corrector_mod
silent=True,
)
with suppress(Exception):
- corrections = json.loads(ai["texts"])
- if not isinstance(corrections, list): # response of chat completions
- corrections = corrections["corrections"]
- for output in corrections:
- idx = output["idx"]
- matches[idx] = (matches[idx][0], output["corrected"])
+ corrections = TranscriptionCorrection.model_validate_json(ai["texts"])
+ for output in corrections.corrections:
+ idx = output.idx
+ matches[idx] = (matches[idx][0], output.corrected)
return "\n".join([f"{item[0]} {item[1]}" for item in matches])
return inputs
src/schema.py
@@ -75,11 +75,36 @@ class AIPage(BaseModel):
mermaid_url: str | None = Field(default=None, description="思维导图代码URL")
-def get_schema(name: Literal["content_extraction"] = "content_extraction") -> dict:
+class Correction(BaseModel):
+ """转录错误修正项.
+
+ 只包含转录错误项,正确项不显示。
+ """
+
+ model_config = ConfigDict(str_strip_whitespace=True, extra="forbid")
+ idx: int = Field(title="索引", description="修正项在原始转录稿中的索引")
+ corrected: str = Field(description="修正后的文本")
+
+
+class TranscriptionCorrection(BaseModel):
+ """转录错误修正."""
+
+ model_config = ConfigDict(str_strip_whitespace=True, extra="forbid")
+ corrections: list[Correction] = Field(title="转录错误修正", description="转录错误修正项列表")
+
+
+def get_schema(name: Literal["content_extraction", "transcription_correction"] = "content_extraction") -> dict:
if name == "content_extraction":
schema = ContentExtraction.model_json_schema()
+ elif name == "transcription_correction":
+ schema = TranscriptionCorrection.model_json_schema()
else:
return {}
inlined_schema = jsonref.replace_refs(schema, proxies=False)
inlined_schema.pop("$defs", None)
return inlined_schema
+
+
+if __name__ == "__main__":
+ print(get_schema("content_extraction"))
+ print(get_schema("transcription_correction"))