Commit d66d00b

benny-dou <60535774+benny-dou@users.noreply.github.com>
2025-04-27 08:12:10
chore(asr): better logging
1 parent ac090aa
src/asr/tecent_asr.py
@@ -7,11 +7,31 @@ from pathlib import Path
 
 import anyio
 
-from config import PROXY, TOKEN
+from config import ASR_MAX_DURATION, FILE_SERVER, PROXY, TOKEN
 from networking import hx_req
 from utils import nowdt
 
 
+def get_asr_method(duration: float, file_size: int) -> tuple[str, list[str]]:
+    """Get tencent ASR method and supported file types."""
+    if duration > ASR_MAX_DURATION:
+        return f"无法识别时长超过{ASR_MAX_DURATION}秒的音频, 当前音频时长: {duration}秒", []
+
+    asr_method = ""
+    if duration < 60 and file_size < 3 * 1024 * 1024:
+        asr_method = "single_sentence_asr"  # 一句话识别
+        supported_ext = ["aac", "amr", "m4a", "mp3", "oga", "ogg-opus", "ogg", "opus", "pcm", "silk", "speex", "wav"]
+    elif 60 <= duration <= 300 and file_size < 100 * 1024 * 1024:
+        asr_method = "flash_asr"  # 录音文件识别极速版
+        supported_ext = ["aac", "amr", "m4a", "mp3", "oga", "ogg-opus", "ogg", "opus", "pcm", "silk", "speex", "wav"]
+    elif FILE_SERVER:
+        asr_method = "async_asr"  # 录音文件识别 (异步请求)
+        supported_ext = ["3gp", "aac", "amr", "flac", "flv", "m4a", "mp3", "mp4", "oga", "ogg-opus", "ogg", "opus", "wav", "wma"]
+    elif not FILE_SERVER:
+        return "音频过长, 需使用音频URL格式调用ASR\n请联系管理员配置`FILE_SERVER`变量", []
+    return asr_method, supported_ext
+
+
 def sign(key, msg):
     return hmac.new(key, msg.encode("utf-8"), hashlib.sha256).digest()
 
src/asr/voice_recognition.py
@@ -11,14 +11,14 @@ from loguru import logger
 from pyrogram.client import Client
 from pyrogram.types import Message
 
-from asr.tecent_asr import create_async_asr, flash_asr, query_async_asr, single_sentence_asr
-from config import ASR_MAX_DURATION, CAPTION_LENGTH, FILE_SERVER, PREFIX
+from asr.tecent_asr import create_async_asr, flash_asr, get_asr_method, query_async_asr, single_sentence_asr
+from config import CAPTION_LENGTH, FILE_SERVER, PREFIX
 from messages.parser import parse_msg
 from messages.progress import modify_progress
 from messages.sender import send2tg
 from messages.utils import count_without_entities, equal_prefix, get_reply_to, startswith_prefix
 from multimedia import convert_to_audio
-from utils import to_int
+from utils import rand_string, to_int
 
 # ruff: noqa: RUF001
 
@@ -115,56 +115,42 @@ async def voice_to_text(
     if asr_engine not in ENGINE_MAP:
         await modify_progress(text=f"Unsupported ASR engine: {asr_engine}", force_update=True, **kwargs)
         return
-    voice_format = ""
-    path: str | Path = await trigger_message.download()  # type: ignore
-    if trigger_info["mtype"] == "voice":  # audio/ogg
-        voice_format = str(trigger_info["mime_type"]).split("/")[-1]  # set voice format
-    elif trigger_info["mtype"] in ["audio", "video"]:
-        path = convert_to_audio(path, ext="m4a")
-        voice_format = "m4a"
 
-    if not Path(path).expanduser().resolve().is_file():
+    duration = trigger_info["duration"]
+    asr_method, supported_ext = get_asr_method(duration, trigger_info["file_size"])
+    if asr_method not in ["single_sentence_asr", "flash_asr", "async_asr"]:
+        await modify_progress(text=asr_method, force_update=True, **kwargs)
+        return
+
+    path: str | Path = await trigger_message.download()  # type: ignore
+    path = Path(path).expanduser().resolve()
+    if not path.is_file():
         msg = "Failed to download audio, please try again later."
         logger.error(msg)
         await modify_progress(text=msg, force_update=True, **kwargs)
         return
-    path = Path(path).expanduser().resolve()
-
-    if not voice_format:
-        voice_format = path.suffix.removeprefix(".")
-
-    # fix format code
-    if voice_format in ["oga", "ogg", "opus"]:
-        voice_format = "ogg-opus"
-    if voice_format == "mp4":
+    voice_format = path.suffix.lstrip(".")
+    if voice_format not in supported_ext:
+        path = convert_to_audio(path, ext="m4a")
         voice_format = "m4a"
-    if voice_format not in ["m4a", "ogg-opus", "wav", "pcm", "speex", "silk", "mp3", "aac", "amr"]:
-        msg = f"Unsupported audio format: {voice_format}"
-        logger.error(msg)
-        await modify_progress(text=msg, force_update=True)
-        return
+    asr_method, supported_ext = get_asr_method(duration, file_size=Path(path).stat().st_size)  # match again based on converted file
+    path = path.rename(path.with_stem(rand_string()))  # sanitize filename. (for Tencent Signature v3)
 
-    # Skip very long
-    duration = trigger_info["duration"]
-    if duration > ASR_MAX_DURATION:
-        msg = f"无法识别时长超过{ASR_MAX_DURATION}秒的音频, 当前音频时长: {trigger_info['duration']}秒"
-        logger.error(msg)
-        await modify_progress(text=msg, force_update=True, **kwargs)
-        return
+    if voice_format in ["oga", "ogg", "opus"]:  # rename format
+        voice_format = "ogg-opus"
 
     logger.debug(f"Recognizing {voice_format} audio [{duration}s] by {asr_engine}: {path.as_posix()}")
     try:
-        if duration < 60 and trigger_info["file_size"] < 3 * 1024 * 1024:  # 时长不能超过60s, 文件大小不能超过3MB
+        if asr_method == "single_sentence_asr":
             resp = await single_sentence_asr(path, asr_engine, voice_format)
             texts = glom(resp, "Response.Result", default="❌无法识别").replace("。", "。\n")
-        elif 60 <= duration <= 300:
+        elif asr_method == "flash_asr":
             resp = await flash_asr(path, asr_engine, voice_format)
             texts = glom(resp, "flash_result.0.text", default="❌无法识别").replace("。", "。\n")
-        elif FILE_SERVER:
+        elif asr_method == "async_asr":
             resp = await create_async_asr(f"{FILE_SERVER}/{path.name}", asr_engine)
             task_id = resp["Response"]["Data"]["TaskId"]
-            logger.success(task_id)
-            # task_id = 11982077075
+            logger.success(f"ASR任务提交成功, TaskID: {task_id}")
             result = await query_async_asr(task_id)
             status = result["Response"]["Data"]["StatusStr"]
             while status in ["waiting", "doing"]:
@@ -177,8 +163,6 @@ async def voice_to_text(
                 texts = re.sub(r"\[.*?\]\s*", "", texts)
             else:
                 texts = glom(result, "Response.Data.ErrorMsg", default="❌无法识别")
-        else:
-            texts = "音频过长,请联系管理员配置`FILE_SERVER`变量"
 
         final = f"{BEGINNING}\n{texts}"
         logger.success(f"{final!r}")
@@ -195,7 +179,11 @@ async def voice_to_text(
                 await client.send_document(to_int(target_chat), f, file_name="语音识别结果.txt", reply_parameters=reply_parameters)
         await modify_progress(del_status=True, **kwargs)
     except Exception as e:
+        error = f"Failed to recognize audio: {e}"
+        if "resp" in locals() and resp.get("hx_error"):
+            error += f"\n{resp['hx_error']}"
         logger.error(f"Failed to recognize audio: {e}")
+        await modify_progress(text=error, force_update=True, **kwargs)
     finally:
         path.unlink(missing_ok=True)
     with contextlib.suppress(Exception):
src/networking.py
@@ -51,6 +51,7 @@ async def hx_req(
     silent: bool = False,
     mobile: bool = False,
     rformat: str = "json",  # "json", "text"
+    last_error: str = "",
 ) -> dict[str, Any]:
     """Request the given URL with the given method and return the response as a dictionary.
 
@@ -71,14 +72,14 @@ async def hx_req(
         silent (bool, optional): Whether to suppress the logs.
         mobile (bool, optional): Whether to use mobile headers.
         rformat (str, optional): The format of the response.
+        last_error (str, optional): Last error message.
 
     Returns:
         dict: {"success": bool, "data": response}
     """
     if retry > max_retry:
-        error = f"[{method}] Failed after {retry} retries: {url}"
-        logger.error(error)
-        return {"hx_error": error}
+        logger.error(f"[{method}] Failed after {retry} retries: {url}")
+        return {"hx_error": last_error}
     transport = AsyncCurlTransport(proxy=proxy, impersonate="safari_ios" if mobile else "chrome", default_headers=True, curl_options={CurlOpt.FRESH_CONNECT: True})
 
     if silent:
@@ -104,8 +105,15 @@ async def hx_req(
                 logger.trace(res)
             return res
     except Exception as e:
-        logger.error(f"{type(e).__name__}[{retry + 1}/{max_retry + 1}]: Failed to request {url}, {e}")
-        return await hx_req(url, method, headers=headers, cookies=cookies, params=params, post_json=post_json, proxy=proxy, follow_redirects=follow_redirects, check_keys=check_keys, check_kv=check_kv, timeout=timeout, retry=retry + 1, max_retry=max_retry, silent=silent, rformat=rformat)  # fmt: off
+        error = f"{type(e).__name__}[{retry + 1}/{max_retry + 1}]: Failed to request {url}, {e}"
+        if "res" in locals():
+            error += f"\n{res}"
+        elif "data" in locals():
+            error += f"\n{data}"
+        elif "response" in locals():
+            error += f"\n{response}"
+        logger.error(error)
+        return await hx_req(url, method, headers=headers, cookies=cookies, params=params, post_json=post_json, proxy=proxy, follow_redirects=follow_redirects, check_keys=check_keys, check_kv=check_kv, timeout=timeout, retry=retry + 1, max_retry=max_retry, silent=silent, rformat=rformat, last_error=error)  # fmt: off
 
 
 async def download_file(