Commit b133a33
Changed files (1)
src
asr
src/asr/corrector.py
@@ -16,6 +16,7 @@ JSON_SCHEMA = {
"items": {
"type": "object",
"title": "Correction",
+ "description": "A list of transcription correction, only include the correction that is not correct.",
"properties": {
"idx": {"description": "Index of the transcription item", "title": "Index", "type": "integer"},
"corrected": {"description": "Corrected text", "title": "Corrected", "type": "string"},
@@ -25,6 +26,14 @@ JSON_SCHEMA = {
},
}
+JSON_SCHEMA_CHAT = { # Chat Completions variant of JSON_SCHEMA: the OpenAI Chat Completions API requires that the root JSON Schema has type: "object", so JSON_SCHEMA is nested under one key instead of being the root.
+ "title": "Transcription Correction",
+ "type": "object",
+ "required": ["corrections"],  # the single wrapper key must always be present in the response
+ "additionalProperties": False,  # no keys other than those declared under "properties" are allowed
+ "properties": {"corrections": JSON_SCHEMA},  # callers unwrap the result via the "corrections" key
+}
+
async def asr_corrector(inputs: str, reference: str | None = None, corrector_model: str = "asr-corrector") -> str:
"""Correct ASR results.
@@ -100,9 +109,9 @@ async def asr_corrector(inputs: str, reference: str | None = None, corrector_mod
"text": {
"format": {
"type": "json_schema",
- "name": "ASRCorrection",
+ "name": "TranscriptionCorrection",
"strict": True,
- "description": "A list of ASR correction",
+ "description": "A list of transcription correction",
"schema": JSON_SCHEMA,
}
},
@@ -112,13 +121,29 @@ async def asr_corrector(inputs: str, reference: str | None = None, corrector_mod
"responseMimeType": "application/json",
"responseJsonSchema": JSON_SCHEMA,
},
+ openai_system_prompt=SYSTEM_PROMPT,
+ openai_completions_config={
+ "response_format": {
+ "type": "json_schema",
+ "strict": True,
+ "json_schema": {
+ "name": "TranscriptionCorrection",
+ "schema": JSON_SCHEMA_CHAT,
+ "strict": True,
+ },
+ }
+ },
+ openai_enable_tool_call=False,
openai_append_tool_results=False,
gemini_append_grounding=False,
cache_response_ttl=0,
silent=True,
)
with suppress(Exception):
- for output in json.loads(ai["texts"]):
+ corrections = json.loads(ai["texts"])
+ if not isinstance(corrections, list): # response of chat completions
+ corrections = corrections["corrections"]
+ for output in corrections:
idx = output["idx"]
matches[idx] = (matches[idx][0], output["corrected"])
return "\n".join([f"{item[0]} {item[1]}" for item in matches])