Commit b918cf5

benny-dou <60535774+benny-dou@users.noreply.github.com>
2025-05-20 17:04:40
chore(asr): better error handling
1 parent ce6fca4
src/asr/ali_asr.py
@@ -18,7 +18,7 @@ from database import delete_alist, upload_alist, upload_uguu
 from networking import hx_req
 
 
-async def ali_asr(path: str | Path) -> str:
+async def ali_asr(path: str | Path) -> dict:
     """Create Aliyun ASR Task.
 
     录音文件识别请求
@@ -31,7 +31,7 @@ async def ali_asr(path: str | Path) -> str:
     """
     api_keys = [x.strip() for x in ASR.ALI_API_KEY.split(",") if x.strip()]
     if not api_keys:
-        return "请配置阿里云语音识别的API Key"
+        return {"error": "请配置阿里云语音识别的API Key"}
     models = [x.strip() for x in ASR.ALI_MODEL.split(",") if x.strip()]
     model = random.choice(models)
     logger.debug(f"阿里云ASR {path} via model: {model}")
@@ -47,8 +47,10 @@ async def ali_asr(path: str | Path) -> str:
         if path.stat().st_size > 100 * 1024 * 1024:  # 100 MB
             path = downsampe_audio(path)
         url = await upload_uguu(path)  # max 100 MB for Uguu
-    else:
+    elif ASR.ALI_FS_ENGINE.lower() == "alist":
         url = await upload_alist(path)
+    else:
+        return {"error": f"Unsupported file server engine: {ASR.ALI_FS_ENGINE}"}
 
     payload = {"model": model, "input": {"file_urls": [url]}}
     res = await hx_req(
@@ -59,11 +61,13 @@ async def ali_asr(path: str | Path) -> str:
         timeout=600,
         check_keys=["output.task_id"],
     )
+    if res.get("hx_error"):
+        return {"error": res["hx_error"]}
     logger.success(f"ASR任务提交成功, TaskID: {res['output']['task_id']}")
     return await query_ali_asr(task_id=res["output"]["task_id"], api_key=api_key)
 
 
-async def query_ali_asr(task_id: str, api_key: str) -> str:
+async def query_ali_asr(task_id: str, api_key: str, query_times: int = 0) -> dict:
     """Query Ali ASR Task.
 
     录音文件识别结果查询
@@ -82,27 +86,32 @@ async def query_ali_asr(task_id: str, api_key: str) -> str:
         post_json=payload,
         check_keys=["output.task_status"],
     )
+    if result.get("hx_error"):
+        return {"error": result["hx_error"]}
     status = glom(result, "output.task_status")
-    query_times = 0
     while status in ["RUNNING", "PENDING"] and query_times < 600:  # max 10 minutes
         await asyncio.sleep(1)
         query_times += 1
-        logger.trace(f"Status:[{status}], Wating TaskID: {task_id}")
-        result = await query_ali_asr(task_id, api_key)
-        if isinstance(result, str):
+        logger.trace(f"Status:[{status} ({query_times}/600)], Waiting TaskID: {task_id}")
+        result = await query_ali_asr(task_id, api_key, query_times)
+        if result.get("texts") or result.get("error"):
             return result
         status = glom(result, "output.task_status")
-    await clean_alist(glom(result, "output.results.0.file_url", default=""))
+    if ASR.ALI_FS_ENGINE.lower() == "alist":
+        await clean_alist(glom(result, "output.results.0.file_url", default=""))
     if status == "SUCCEEDED":
         transcription_url = glom(result, "output.results.0.transcription_url")
-        trans_res = await hx_req(transcription_url, transport=AsyncHTTPTransport(), check_keys=["transcripts.0.sentences.0.text"])  # DO NOT use AsyncCurlTransport
+        trans_res = await hx_req(transcription_url, transport=AsyncHTTPTransport(), check_keys=["transcripts.0.sentences.0.text"])
+        if trans_res.get("hx_error"):
+            return {"error": trans_res["hx_error"]}
+        # DO NOT use AsyncCurlTransport
         sentence_start_ms = glom(trans_res, "transcripts.0.sentences.*.begin_time")
         sentences = glom(trans_res, "transcripts.0.sentences.*.text")
         return generate_ali_transcription(sentence_start_ms, sentences)
-    return "❌" + glom(result, "output.message", default="语音识别失败")
+    return {"error": "❌" + glom(result, "output.message", default="语音识别失败")}
 
 
-def ali_realtime_asr(model: str, path: str | Path, api_key: str) -> str:
+def ali_realtime_asr(model: str, path: str | Path, api_key: str) -> dict:
     # convert audio file
     sample_rate = 8000 if "8k" in model else 16000
     ext = "opus"
@@ -110,9 +119,11 @@ def ali_realtime_asr(model: str, path: str | Path, api_key: str) -> str:
     recognition = Recognition(model=model, format=ext, sample_rate=sample_rate, callback=RecognitionCallback(), api_key=api_key)
     result = recognition.call(Path(path).as_posix())
     if result.status_code != 200:
-        return f"❌语音识别失败: {result.message}"
+        return {"error": f"❌语音识别失败: {result.message}"}
     Path(path).unlink(missing_ok=True)
     data = result.get_sentence()
+    if not data:
+        return {"error": "⚠️该音频未识别到文字"}
     start_times = flatten(glom(data, "*.words.*.begin_time"))
     texts = flatten(glom(data, "*.words.*.text"))
     punctuations = flatten(glom(data, "*.words.*.punctuation"))
@@ -120,7 +131,7 @@ def ali_realtime_asr(model: str, path: str | Path, api_key: str) -> str:
     return generate_ali_transcription(start_times, sentences)
 
 
-def generate_ali_transcription(sentence_start_ms: list[int], sentences: list[str]) -> str:
+def generate_ali_transcription(sentence_start_ms: list[int], sentences: list[str]) -> dict:
     def clean_tags(text: str) -> str:
         """Clean sensevoice tags.
 
@@ -131,19 +142,23 @@ def generate_ali_transcription(sentence_start_ms: list[int], sentences: list[str
         return re.sub(r"<\|.*?\|>", "", text)
 
     res = ""
-    indexs = list(range(len(sentences)))
-    for idx, start_ms, sentence in zip(indexs, sentence_start_ms, sentences, strict=True):
-        text = clean_tags(sentence)
-        if not text:
-            continue
-        if idx == 0 or res.endswith((".", "。", "?", "?")):  # noqa: RUF001
-            start_seconds = float(start_ms) // 1000
-            minutes = int(start_seconds // 60)
-            seconds = int(start_seconds % 60)
-            res += f"\n[{minutes:02d}:{seconds:02d}] {text}"
-        else:
-            res += text
-    return res.strip()
+    try:
+        indexs = list(range(len(sentences)))
+        for idx, start_ms, sentence in zip(indexs, sentence_start_ms, sentences, strict=True):
+            text = clean_tags(sentence)
+            if not text:
+                continue
+            if idx == 0 or res.endswith((".", "。", "?", "?")):  # noqa: RUF001
+                start_seconds = float(start_ms) // 1000
+                minutes = int(start_seconds // 60)
+                seconds = int(start_seconds % 60)
+                res += f"\n[{minutes:02d}:{seconds:02d}] {text}"
+            else:
+                res += text
+    except Exception as e:
+        logger.error(e)
+        return {"error": str(e)}
+    return {"texts": res.strip()}
 
 
 async def clean_alist(url: str):
src/asr/deepgram.py
@@ -12,14 +12,14 @@ from networking import hx_req
 from utils import zhcn
 
 
-async def deepgram_asr(path: str | Path) -> str:
+async def deepgram_asr(path: str | Path) -> dict:
     """Deepgram ASR.
 
     https://developers.deepgram.com/docs/pre-recorded-audio
     """
     api_keys = [x.strip() for x in ASR.DEEPGRAM_API.split(",") if x.strip()]
     if not api_keys:
-        return "请配置DeepGram语音识别的API Key"
+        return {"error": "请配置DeepGram语音识别的API Key"}
     logger.debug(f"DeepGram ASR {path}")
     headers = {"Authorization": f"Token {random.choice(api_keys)}"}
     path = Path(path).expanduser().resolve()
@@ -35,18 +35,24 @@ async def deepgram_asr(path: str | Path) -> str:
             timeout=600,
             check_keys=["results.channels.0.alternatives.0.words"],
         )
-    start_seconds = flatten(glom(res, "results.channels.*.alternatives.0.words.*.start"))
-    sentences = flatten(glom(res, "results.channels.*.alternatives.0.words.*.punctuated_word"))
-    res = ""
-    indexs = list(range(len(sentences)))
-    for idx, start_time, sentence in zip(indexs, start_seconds, sentences, strict=True):
-        if not sentence:
-            continue
-        if idx == 0 or res.endswith((".", "。", "?", "?")):  # noqa: RUF001
-            start_seconds = float(start_time)
-            minutes = int(start_seconds // 60)
-            seconds = int(start_seconds % 60)
-            res += f"\n[{minutes:02d}:{seconds:02d}] {sentence}"
-        else:
-            res += sentence
-    return zhcn(res.strip())
+        if res.get("hx_error"):
+            return {"error": res["hx_error"]}
+    try:
+        start_seconds = flatten(glom(res, "results.channels.*.alternatives.0.words.*.start"))
+        sentences = flatten(glom(res, "results.channels.*.alternatives.0.words.*.punctuated_word"))
+        res = ""
+        indexs = list(range(len(sentences)))
+        for idx, start_time, sentence in zip(indexs, start_seconds, sentences, strict=True):
+            if not sentence:
+                continue
+            if idx == 0 or res.endswith((".", "。", "?", "?")):  # noqa: RUF001
+                start_seconds = float(start_time)
+                minutes = int(start_seconds // 60)
+                seconds = int(start_seconds % 60)
+                res += f"\n[{minutes:02d}:{seconds:02d}] {sentence}"
+            else:
+                res += sentence
+    except Exception as e:
+        logger.error(e)
+        return {"error": str(e)}
+    return {"texts": zhcn(res.strip())}
src/asr/gemini_asr.py
@@ -18,6 +18,7 @@ from llm.gemini import gemini_stream
 from llm.hooks import hook_gemini_httpoptions
 from llm.utils import shuffle_keys
 from messages.progress import modify_progress
+from utils import count_subtitles
 
 
 async def gemini_stream_asr(
@@ -69,6 +70,7 @@ Notes:
 - Focus on accuracy in capturing both the timing and the spoken content.
 - Maintain consistent formatting to ensure clarity and readability."""
     path = Path(path)
+    sent_messages = []
     status = None if silent else kwargs.get("progress")
     api_keys = shuffle_keys(GEMINI.API_KEY)
     if model_id is None:
@@ -104,7 +106,7 @@ Notes:
                 append_grounding=False,
                 **kwargs,
             )
-            if res.get("error"):
+            if res.get("error") or count_subtitles(res.get("texts", "")) == 0:
                 continue
             sent_messages = res.get("sent_messages", [])
             break
src/asr/tecent_asr.py
@@ -12,7 +12,7 @@ from loguru import logger
 
 from asr.utils import downsampe_audio
 from config import ASR, FILE_SERVER
-from database import upload_alist, upload_uguu
+from database import delete_alist, upload_alist, upload_uguu
 from networking import hx_req
 from utils import nowdt
 
@@ -70,7 +70,7 @@ def generate_tencent_cloud_headers(
     }
 
 
-async def tencent_single_asr(path: str | Path, engine: str, voice_format: str) -> str:
+async def tencent_single_asr(path: str | Path, engine: str, voice_format: str) -> dict:
     """Tencent Single Sentence ASR.
 
     一句话识别
@@ -91,11 +91,14 @@ async def tencent_single_asr(path: str | Path, engine: str, voice_format: str) -
         check_keys=["Response.WordList"],
     )
     if res["Response"]["WordList"] is None:
-        return "⚠️该音频未识别到文字"
+        return {"error": "⚠️该音频未识别到文字"}
+    if res.get("hx_error"):
+        return {"error": res["hx_error"]}
+
     return generate_tencent_transcription(sentence_start_ms=[0], words=[res["Response"]["WordList"]])
 
 
-async def tencent_flash_asr(path: str | Path, engine: str, voice_format: str) -> str:
+async def tencent_flash_asr(path: str | Path, engine: str, voice_format: str) -> dict:
     """Tencent Flash ASR.
 
     录音文件识别极速版
@@ -131,14 +134,14 @@ async def tencent_flash_asr(path: str | Path, engine: str, voice_format: str) ->
         )
         if error := res.get("hx_error", ""):
             if "audio data empty" in error:
-                return "⚠️该音频未识别到文字"
-            return error
+                return {"error": "⚠️该音频未识别到文字"}
+            return {"error": error}
         sentence_start_ms = flatten(glom(res, "flash_result.*.sentence_list.*.start_time"), levels=1)
         words = flatten(glom(res, "flash_result.*.sentence_list.*.word_list"), levels=1)
         return generate_tencent_transcription(sentence_start_ms, words)
 
 
-async def tencent_async_asr(path: str | Path, engine: str) -> str:
+async def tencent_async_asr(path: str | Path, engine: str) -> dict:
     """Create Tencent ASR Task.
 
     录音文件识别请求
@@ -151,8 +154,10 @@ async def tencent_async_asr(path: str | Path, engine: str) -> str:
         if path.stat().st_size > 100 * 1024 * 1024:  # 100 MB
             path = downsampe_audio(path)
         url = await upload_uguu(path)  # max 100 MB for Uguu
-    else:
+    elif ASR.TENCENT_FS_ENGINE.lower() == "alist":
         url = await upload_alist(path)
+    else:
+        return {"error": f"Unsupported file server engine: {ASR.TENCENT_FS_ENGINE}"}
 
     payload = f'{{"EngineModelType":"{engine}","ChannelNum":1,"ResTextFormat":2,"SourceType":0,"Url":"{url}"}}'
     headers = generate_tencent_cloud_headers(action="CreateRecTask", payload=payload)
@@ -165,12 +170,14 @@ async def tencent_async_asr(path: str | Path, engine: str) -> str:
         proxy=ASR.TENCENT_PROXY,
         check_keys=["Response.Data.TaskId"],
     )
+    if resp.get("hx_error"):
+        return {"error": resp["hx_error"]}
     task_id = resp["Response"]["Data"]["TaskId"]
     logger.success(f"ASR任务提交成功, TaskID: {task_id}")
-    return await tencent_query_asr(task_id)
+    return await tencent_query_asr(task_id, file_name=path.name)
 
 
-async def tencent_query_asr(task_id: int) -> str:
+async def tencent_query_asr(task_id: int, file_name: str, query_times: int = 0) -> dict:
     """Query Tencent ASR Task.
 
     录音文件识别结果查询
@@ -187,35 +194,44 @@ async def tencent_query_asr(task_id: int) -> str:
         proxy=ASR.TENCENT_PROXY,
         check_keys=["Response.Data.StatusStr"],
     )
+    if result.get("hx_error"):
+        return {"error": result["hx_error"]}
     status = glom(result, "Response.Data.StatusStr")
-    query_times = 0
     while status in ["waiting", "doing"] and query_times < 600:  # max 10 minutes
         await asyncio.sleep(1)
         query_times += 1
-        logger.trace(f"Status:[{status}], Wating TaskID: {task_id}")
-        result = await tencent_query_asr(task_id)
-        if isinstance(result, str):
+        logger.trace(f"Status: [{status} ({query_times}/600)], Waiting TaskID: {task_id}")
+        result = await tencent_query_asr(task_id, file_name, query_times)
+        if result.get("texts") or result.get("error"):
             return result
         status = glom(result, "Response.Data.StatusStr")
+    if ASR.TENCENT_FS_ENGINE.lower() == "alist":
+        await delete_alist(file_name)
     if status == "success":
+        if glom(result, "Response.Data.ResultDetail") is None:
+            return {"error": "⚠️该音频未识别到文字"}
         sentence_start_ms = glom(result, "Response.Data.ResultDetail.*.StartMs")
         words = glom(result, "Response.Data.ResultDetail.*.Words")
         return generate_tencent_transcription(sentence_start_ms, words)
-    return "❌" + glom(result, "Response.Data.ErrorMsg", default="语音识别失败")
+    return {"error": "❌" + glom(result, "Response.Data.ErrorMsg", default="语音识别失败")}
 
 
-def generate_tencent_transcription(sentence_start_ms: list[int], words: list[list[dict]]) -> str:
+def generate_tencent_transcription(sentence_start_ms: list[int], words: list[list[dict]]) -> dict:
     res = ""
-    for start_offset, items in zip(sentence_start_ms, words, strict=True):
-        for idx, item in enumerate(items):
-            sentence = glom(item, Coalesce("Word", "word"), default="")
-            if not sentence:
-                continue
-            if idx == 0 or res.endswith((".", "。", "?", "?")):  # noqa: RUF001
-                start_seconds = float(glom(item, Coalesce("StartTime", "OffsetStartMs", "start_time"), default=0) + float(start_offset)) // 1000
-                minutes = int(start_seconds // 60)
-                seconds = int(start_seconds % 60)
-                res += f"\n[{minutes:02d}:{seconds:02d}] {sentence}"
-            else:
-                res += sentence
-    return res.strip()
+    try:
+        for start_offset, items in zip(sentence_start_ms, words, strict=True):
+            for idx, item in enumerate(items):
+                sentence = glom(item, Coalesce("Word", "word"), default="")
+                if not sentence:
+                    continue
+                if idx == 0 or res.endswith((".", "。", "?", "?")):  # noqa: RUF001
+                    start_seconds = float(glom(item, Coalesce("StartTime", "OffsetStartMs", "start_time"), default=0) + float(start_offset)) // 1000
+                    minutes = int(start_seconds // 60)
+                    seconds = int(start_seconds % 60)
+                    res += f"\n[{minutes:02d}:{seconds:02d}] {sentence}"
+                else:
+                    res += sentence
+    except Exception as e:
+        logger.error(e)
+        return {"error": str(e)}
+    return {"texts": res.strip()}
src/asr/voice_recognition.py
@@ -175,7 +175,6 @@ async def asr_file(
     **kwargs,
 ) -> dict:
     """Get ASR results of an audio file."""
-    res = {}
     path = Path(path).expanduser().resolve()
     if not path.is_file():
         return {"error": f"{path} is not exist"}
@@ -205,19 +204,23 @@ async def asr_file(
 
     logger.debug(f"[{asr_method}] Recognizing {voice_format} audio [{duration}s] by {language}: {path.as_posix()}")
     try:
+        res = {}
         if asr_method == "tencent_single_asr":
-            res["texts"] = await tencent_single_asr(path, language, voice_format)
+            res = await tencent_single_asr(path, language, voice_format)
         elif asr_method == "tencent_flash_asr":
-            res["texts"] = await tencent_flash_asr(path, language, voice_format)
+            res = await tencent_flash_asr(path, language, voice_format)
         elif asr_method == "tencent_async_asr":
-            res["texts"] = await tencent_async_asr(path, language)
+            res = await tencent_async_asr(path, language)
         elif asr_method == "ali":
-            res["texts"] = await ali_asr(path)
+            res = await ali_asr(path)
         elif asr_method == "deepgram":
-            res["texts"] = await deepgram_asr(path)
+            res = await deepgram_asr(path)
         elif asr_method == "gemini":
-            res |= await gemini_stream_asr(path=path, voice_format=voice_format, **kwargs)
-        logger.success(f"{res['texts']!r}")
+            res = await gemini_stream_asr(path=path, voice_format=voice_format, **kwargs)
+        else:
+            return {"error": "ASR method not supported"}
+        if res.get("texts"):
+            logger.success(f"{res['texts']!r}")
     except Exception as e:
         error = f"Failed to recognize audio: {e}"
         logger.error(error)
src/preview/ytdlp.py
@@ -263,6 +263,8 @@ async def preview_ytdlp(
             ytdlp_transcription_engine = "gemini" if "youtube" in info["extractor"] else ytdlp_transcription_engine  # use gemini to bypass censorship
             res = await asr_file(audio_path, ytdlp_transcription_engine, duration, client=client, message=message, silent=True)
             subtitles = res.get("texts", "")
+            if count_subtitles(subtitles) < 20:
+                subtitles = ""  # ignore too  short transcription
         if subtitles:
             if len(subtitles) > TEXT_LENGTH or transcription_force_file:
                 caption = f"{emoji}[{info['author']}]({info['author_url']})\n🕒{create_time}"