Commit b133a33

benny-dou <60535774+benny-dou@users.noreply.github.com>
2026-06-02 09:50:40
feat(asr): support OpenAI Chat Completions API for ASR correction
1 parent ee1ad36
Changed files (1)
src
src/asr/corrector.py
@@ -16,6 +16,7 @@ JSON_SCHEMA = {
     "items": {
         "type": "object",
         "title": "Correction",
+        "description": "A list of transcription correction, only include the correction that is not correct.",
         "properties": {
             "idx": {"description": "Index of the transcription item", "title": "Index", "type": "integer"},
             "corrected": {"description": "Corrected text", "title": "Corrected", "type": "string"},
@@ -25,6 +26,14 @@ JSON_SCHEMA = {
     },
 }
 
+JSON_SCHEMA_CHAT = {  # OpenAI Chat Completions API requires that the root JSON Schema has type: "object".
+    "title": "Transcription Correction",
+    "type": "object",
+    "required": ["corrections"],
+    "additionalProperties": False,
+    "properties": {"corrections": JSON_SCHEMA},
+}
+
 
 async def asr_corrector(inputs: str, reference: str | None = None, corrector_model: str = "asr-corrector") -> str:
     """Correct ASR results.
@@ -100,9 +109,9 @@ async def asr_corrector(inputs: str, reference: str | None = None, corrector_mod
             "text": {
                 "format": {
                     "type": "json_schema",
-                    "name": "ASRCorrection",
+                    "name": "TranscriptionCorrection",
                     "strict": True,
-                    "description": "A list of ASR correction",
+                    "description": "A list of transcription correction",
                     "schema": JSON_SCHEMA,
                 }
             },
@@ -112,13 +121,29 @@ async def asr_corrector(inputs: str, reference: str | None = None, corrector_mod
             "responseMimeType": "application/json",
             "responseJsonSchema": JSON_SCHEMA,
         },
+        openai_system_prompt=SYSTEM_PROMPT,
+        openai_completions_config={
+            "response_format": {
+                "type": "json_schema",
+                "strict": True,
+                "json_schema": {
+                    "name": "TranscriptionCorrection",
+                    "schema": JSON_SCHEMA_CHAT,
+                    "strict": True,
+                },
+            }
+        },
+        openai_enable_tool_call=False,
         openai_append_tool_results=False,
         gemini_append_grounding=False,
         cache_response_ttl=0,
         silent=True,
     )
     with suppress(Exception):
-        for output in json.loads(ai["texts"]):
+        corrections = json.loads(ai["texts"])
+        if not isinstance(corrections, list):  # response of chat completions
+            corrections = corrections["corrections"]
+        for output in corrections:
             idx = output["idx"]
             matches[idx] = (matches[idx][0], output["corrected"])
         return "\n".join([f"{item[0]} {item[1]}" for item in matches])