Commit d66d00b
Changed files (3)
src/asr/tecent_asr.py
@@ -7,11 +7,31 @@ from pathlib import Path
import anyio
-from config import PROXY, TOKEN
+from config import ASR_MAX_DURATION, FILE_SERVER, PROXY, TOKEN
from networking import hx_req
from utils import nowdt
+def get_asr_method(duration: float, file_size: int) -> tuple[str, list[str]]:
+ """Get tencent ASR method and supported file types."""
+ if duration > ASR_MAX_DURATION:
+ return f"无法识别时长超过{ASR_MAX_DURATION}秒的音频, 当前音频时长: {duration}秒", []
+
+ asr_method = ""
+ if duration < 60 and file_size < 3 * 1024 * 1024:
+ asr_method = "single_sentence_asr" # 一句话识别
+ supported_ext = ["aac", "amr", "m4a", "mp3", "oga", "ogg-opus", "ogg", "opus", "pcm", "silk", "speex", "wav"]
+ elif 60 <= duration <= 300 and file_size < 100 * 1024 * 1024:
+ asr_method = "flash_asr" # 录音文件识别极速版
+ supported_ext = ["aac", "amr", "m4a", "mp3", "oga", "ogg-opus", "ogg", "opus", "pcm", "silk", "speex", "wav"]
+ elif FILE_SERVER:
+ asr_method = "async_asr" # 录音文件识别 (异步请求)
+ supported_ext = ["3gp", "aac", "amr", "flac", "flv", "m4a", "mp3", "mp4", "oga", "ogg-opus", "ogg", "opus", "wav", "wma"]
+ elif not FILE_SERVER:
+ return "音频过长, 需使用音频URL格式调用ASR\n请联系管理员配置`FILE_SERVER`变量", []
+ return asr_method, supported_ext
+
+
def sign(key, msg):
return hmac.new(key, msg.encode("utf-8"), hashlib.sha256).digest()
src/asr/voice_recognition.py
@@ -11,14 +11,14 @@ from loguru import logger
from pyrogram.client import Client
from pyrogram.types import Message
-from asr.tecent_asr import create_async_asr, flash_asr, query_async_asr, single_sentence_asr
-from config import ASR_MAX_DURATION, CAPTION_LENGTH, FILE_SERVER, PREFIX
+from asr.tecent_asr import create_async_asr, flash_asr, get_asr_method, query_async_asr, single_sentence_asr
+from config import CAPTION_LENGTH, FILE_SERVER, PREFIX
from messages.parser import parse_msg
from messages.progress import modify_progress
from messages.sender import send2tg
from messages.utils import count_without_entities, equal_prefix, get_reply_to, startswith_prefix
from multimedia import convert_to_audio
-from utils import to_int
+from utils import rand_string, to_int
# ruff: noqa: RUF001
@@ -115,56 +115,42 @@ async def voice_to_text(
if asr_engine not in ENGINE_MAP:
await modify_progress(text=f"Unsupported ASR engine: {asr_engine}", force_update=True, **kwargs)
return
- voice_format = ""
- path: str | Path = await trigger_message.download() # type: ignore
- if trigger_info["mtype"] == "voice": # audio/ogg
- voice_format = str(trigger_info["mime_type"]).split("/")[-1] # set voice format
- elif trigger_info["mtype"] in ["audio", "video"]:
- path = convert_to_audio(path, ext="m4a")
- voice_format = "m4a"
- if not Path(path).expanduser().resolve().is_file():
+ duration = trigger_info["duration"]
+ asr_method, supported_ext = get_asr_method(duration, trigger_info["file_size"])
+ if asr_method not in ["single_sentence_asr", "flash_asr", "async_asr"]:
+ await modify_progress(text=asr_method, force_update=True, **kwargs)
+ return
+
+ path: str | Path = await trigger_message.download() # type: ignore
+ path = Path(path).expanduser().resolve()
+ if not path.is_file():
msg = "Failed to download audio, please try again later."
logger.error(msg)
await modify_progress(text=msg, force_update=True, **kwargs)
return
- path = Path(path).expanduser().resolve()
-
- if not voice_format:
- voice_format = path.suffix.removeprefix(".")
-
- # fix format code
- if voice_format in ["oga", "ogg", "opus"]:
- voice_format = "ogg-opus"
- if voice_format == "mp4":
+ voice_format = path.suffix.lstrip(".")
+ if voice_format not in supported_ext:
+ path = convert_to_audio(path, ext="m4a")
voice_format = "m4a"
- if voice_format not in ["m4a", "ogg-opus", "wav", "pcm", "speex", "silk", "mp3", "aac", "amr"]:
- msg = f"Unsupported audio format: {voice_format}"
- logger.error(msg)
- await modify_progress(text=msg, force_update=True)
- return
+ asr_method, supported_ext = get_asr_method(duration, file_size=Path(path).stat().st_size) # match again based on converted file
+ path = path.rename(path.with_stem(rand_string())) # sanitize filename. (for Tencent Signature v3)
- # Skip very long
- duration = trigger_info["duration"]
- if duration > ASR_MAX_DURATION:
- msg = f"无法识别时长超过{ASR_MAX_DURATION}秒的音频, 当前音频时长: {trigger_info['duration']}秒"
- logger.error(msg)
- await modify_progress(text=msg, force_update=True, **kwargs)
- return
+ if voice_format in ["oga", "ogg", "opus"]: # rename format
+ voice_format = "ogg-opus"
logger.debug(f"Recognizing {voice_format} audio [{duration}s] by {asr_engine}: {path.as_posix()}")
try:
- if duration < 60 and trigger_info["file_size"] < 3 * 1024 * 1024: # 时长不能超过60s, 文件大小不能超过3MB
+ if asr_method == "single_sentence_asr":
resp = await single_sentence_asr(path, asr_engine, voice_format)
texts = glom(resp, "Response.Result", default="❌无法识别").replace("。", "。\n")
- elif 60 <= duration <= 300:
+ elif asr_method == "flash_asr":
resp = await flash_asr(path, asr_engine, voice_format)
texts = glom(resp, "flash_result.0.text", default="❌无法识别").replace("。", "。\n")
- elif FILE_SERVER:
+ elif asr_method == "async_asr":
resp = await create_async_asr(f"{FILE_SERVER}/{path.name}", asr_engine)
task_id = resp["Response"]["Data"]["TaskId"]
- logger.success(task_id)
- # task_id = 11982077075
+ logger.success(f"ASR任务提交成功, TaskID: {task_id}")
result = await query_async_asr(task_id)
status = result["Response"]["Data"]["StatusStr"]
while status in ["waiting", "doing"]:
@@ -177,8 +163,6 @@ async def voice_to_text(
texts = re.sub(r"\[.*?\]\s*", "", texts)
else:
texts = glom(result, "Response.Data.ErrorMsg", default="❌无法识别")
- else:
- texts = "音频过长,请联系管理员配置`FILE_SERVER`变量"
final = f"{BEGINNING}\n{texts}"
logger.success(f"{final!r}")
@@ -195,7 +179,11 @@ async def voice_to_text(
await client.send_document(to_int(target_chat), f, file_name="语音识别结果.txt", reply_parameters=reply_parameters)
await modify_progress(del_status=True, **kwargs)
except Exception as e:
+ error = f"Failed to recognize audio: {e}"
+ if "resp" in locals() and resp.get("hx_error"):
+ error += f"\n{resp['hx_error']}"
logger.error(f"Failed to recognize audio: {e}")
+ await modify_progress(text=error, force_update=True, **kwargs)
finally:
path.unlink(missing_ok=True)
with contextlib.suppress(Exception):
src/networking.py
@@ -51,6 +51,7 @@ async def hx_req(
silent: bool = False,
mobile: bool = False,
rformat: str = "json", # "json", "text"
+ last_error: str = "",
) -> dict[str, Any]:
"""Request the given URL with the given method and return the response as a dictionary.
@@ -71,14 +72,14 @@ async def hx_req(
silent (bool, optional): Whether to suppress the logs.
mobile (bool, optional): Whether to use mobile headers.
rformat (str, optional): The format of the response.
+ last_error (str, optional): Last error message.
Returns:
dict: {"success": bool, "data": response}
"""
if retry > max_retry:
- error = f"[{method}] Failed after {retry} retries: {url}"
- logger.error(error)
- return {"hx_error": error}
+ logger.error(f"[{method}] Failed after {retry} retries: {url}")
+ return {"hx_error": last_error}
transport = AsyncCurlTransport(proxy=proxy, impersonate="safari_ios" if mobile else "chrome", default_headers=True, curl_options={CurlOpt.FRESH_CONNECT: True})
if silent:
@@ -104,8 +105,15 @@ async def hx_req(
logger.trace(res)
return res
except Exception as e:
- logger.error(f"{type(e).__name__}[{retry + 1}/{max_retry + 1}]: Failed to request {url}, {e}")
- return await hx_req(url, method, headers=headers, cookies=cookies, params=params, post_json=post_json, proxy=proxy, follow_redirects=follow_redirects, check_keys=check_keys, check_kv=check_kv, timeout=timeout, retry=retry + 1, max_retry=max_retry, silent=silent, rformat=rformat) # fmt: off
+ error = f"{type(e).__name__}[{retry + 1}/{max_retry + 1}]: Failed to request {url}, {e}"
+ if "res" in locals():
+ error += f"\n{res}"
+ elif "data" in locals():
+ error += f"\n{data}"
+ elif "response" in locals():
+ error += f"\n{response}"
+ logger.error(error)
+ return await hx_req(url, method, headers=headers, cookies=cookies, params=params, post_json=post_json, proxy=proxy, follow_redirects=follow_redirects, check_keys=check_keys, check_kv=check_kv, timeout=timeout, retry=retry + 1, max_retry=max_retry, silent=silent, rformat=rformat, last_error=error) # fmt: off
async def download_file(