Commit de190f9

benny-dou <60535774+benny-dou@users.noreply.github.com>
2026-06-08 07:00:27
chore(asr): improve ASR schema
1 parent f97976a
Changed files (2)
src/asr/corrector.py
@@ -7,33 +7,9 @@ from pyrogram.types.messages_and_media.message import Str
 
 from ai.main import ai_text_generation
 from config import PREFIX
+from schema import TranscriptionCorrection, get_schema
 from utils import rand_number
 
-# ruff: noqa: RUF001
-JSON_SCHEMA = {
-    "title": "List of Correction",
-    "type": "array",
-    "items": {
-        "type": "object",
-        "title": "Correction",
-        "description": "A list of transcription correction, only include the correction that is not correct.",
-        "properties": {
-            "idx": {"description": "Index of the transcription item", "title": "Index", "type": "integer"},
-            "corrected": {"description": "Corrected text", "title": "Corrected", "type": "string"},
-        },
-        "required": ["idx", "corrected"],
-        "additionalProperties": False,
-    },
-}
-
-JSON_SCHEMA_CHAT = {  # OpenAI Chat Completions API requires that the root JSON Schema has type: "object".
-    "title": "Transcription Correction",
-    "type": "object",
-    "required": ["corrections"],
-    "additionalProperties": False,
-    "properties": {"corrections": JSON_SCHEMA},
-}
-
 
 async def asr_corrector(inputs: str, reference: str | None = None, corrector_model: str = "asr-corrector") -> str:
     """Correct ASR results.
@@ -66,7 +42,6 @@ async def asr_corrector(inputs: str, reference: str | None = None, corrector_mod
 # 输入处理
 1. 优先读取用户提供的转录稿中的内容,格式为JSON数组,每个项包含idx和text两个字段
 2. 若提供<reference>{{reference}}</reference>,可作为错误判断的辅助参考(如专业术语、专有名词)
-3. 若输入为空或格式错误,输出空列表
 
 # 执行步骤
 1. 初始化空列表用于存储错误项
@@ -76,10 +51,11 @@ async def asr_corrector(inputs: str, reference: str | None = None, corrector_mod
    c. 若存在任意一种错误,将修改后的结果加入到输出列表中,格式为{"idx": int, "corrected": str}
 
 # 输出规范
-1. 输出格式为JSON数组,每项包含idx和corrected两个字段
-2. 仅输出存在错误的项,正确项不显示
-3. 语言保持与原始文本一致的口语化风格
-4. 错误项数量无限制,完整呈现所有识别到的错误
+1. 输出格式为JSON对象,格式为{"corrections": []},其中corrections是一个数组,数组每项包含idx和corrected两个字段
+2. 若输入为空或格式错误,输出corrections为空数组
+3. 仅输出存在错误的项,正确项不显示
+4. 语言保持与原始文本一致的口语化风格
+5. 错误项数量无限制,完整呈现所有识别到的错误
 
 示例输入:
 [
@@ -88,10 +64,11 @@ async def asr_corrector(inputs: str, reference: str | None = None, corrector_mod
 ]
 
 示例输出:
-[
+{"corrections": [
   {"idx": 0, "corrected": "苹果"},
   {"idx": 1, "corrected": "这个方案"}
-]
+]}
+
 """
     if reference:
         SYSTEM_PROMPT += f"\n<reference>{reference}</reference>"
@@ -111,15 +88,14 @@ async def asr_corrector(inputs: str, reference: str | None = None, corrector_mod
                     "type": "json_schema",
                     "name": "TranscriptionCorrection",
                     "strict": True,
-                    "description": "A list of transcription correction",
-                    "schema": JSON_SCHEMA,
+                    "schema": get_schema("transcription_correction"),
                 }
             },
         },
         gemini_generate_content_config={
             "system_instruction": SYSTEM_PROMPT,
             "responseMimeType": "application/json",
-            "responseJsonSchema": JSON_SCHEMA,
+            "responseJsonSchema": get_schema("transcription_correction"),
         },
         openai_system_prompt=SYSTEM_PROMPT,
         openai_completions_config={
@@ -128,7 +104,7 @@ async def asr_corrector(inputs: str, reference: str | None = None, corrector_mod
                 "strict": True,
                 "json_schema": {
                     "name": "TranscriptionCorrection",
-                    "schema": JSON_SCHEMA_CHAT,
+                    "schema": get_schema("transcription_correction"),
                     "strict": True,
                 },
             }
@@ -140,11 +116,9 @@ async def asr_corrector(inputs: str, reference: str | None = None, corrector_mod
         silent=True,
     )
     with suppress(Exception):
-        corrections = json.loads(ai["texts"])
-        if not isinstance(corrections, list):  # response of chat completions
-            corrections = corrections["corrections"]
-        for output in corrections:
-            idx = output["idx"]
-            matches[idx] = (matches[idx][0], output["corrected"])
+        corrections = TranscriptionCorrection.model_validate_json(ai["texts"])
+        for output in corrections.corrections:
+            idx = output.idx
+            matches[idx] = (matches[idx][0], output.corrected)
         return "\n".join([f"{item[0]} {item[1]}" for item in matches])
     return inputs
src/schema.py
@@ -75,11 +75,36 @@ class AIPage(BaseModel):
     mermaid_url: str | None = Field(default=None, description="思维导图代码URL")
 
 
-def get_schema(name: Literal["content_extraction"] = "content_extraction") -> dict:
+class Correction(BaseModel):
+    """转录错误修正项.
+
+    只包含转录错误项,正确项不显示。
+    """
+
+    model_config = ConfigDict(str_strip_whitespace=True, extra="forbid")
+    idx: int = Field(title="索引", description="修正项在原始转录稿中的索引")
+    corrected: str = Field(description="修正后的文本")
+
+
+class TranscriptionCorrection(BaseModel):
+    """转录错误修正."""
+
+    model_config = ConfigDict(str_strip_whitespace=True, extra="forbid")
+    corrections: list[Correction] = Field(title="转录错误修正", description="转录错误修正项列表")
+
+
+def get_schema(name: Literal["content_extraction", "transcription_correction"] = "content_extraction") -> dict:
     if name == "content_extraction":
         schema = ContentExtraction.model_json_schema()
+    elif name == "transcription_correction":
+        schema = TranscriptionCorrection.model_json_schema()
     else:
         return {}
     inlined_schema = jsonref.replace_refs(schema, proxies=False)
     inlined_schema.pop("$defs", None)
     return inlined_schema
+
+
+if __name__ == "__main__":
+    print(get_schema("content_extraction"))
+    print(get_schema("transcription_correction"))