Commit 88542d4

benny-dou <60535774+benny-dou@users.noreply.github.com>
2025-05-10 04:31:08
feat(asr): support Uguu as file server for Tencent and Ali ASR
1 parent 4bd964b
src/asr/ali_asr.py
@@ -12,9 +12,9 @@ from glom import flatten, glom
 from httpx import AsyncHTTPTransport
 from loguru import logger
 
+from asr.utils import downsampe_audio
 from config import ASR, DB, FILE_SERVER
-from database import delete_alist, upload_alist
-from multimedia import convert_to_audio
+from database import delete_alist, upload_alist, upload_uguu
 from networking import hx_req
 
 
@@ -43,6 +43,10 @@ async def ali_asr(path: str | Path) -> str:
     path = Path(path).expanduser().resolve()
     if ASR.ALI_FS_ENGINE.lower() == "local":
         url = FILE_SERVER.removesuffix("/") + "/" + path.name
+    elif ASR.ALI_FS_ENGINE.lower() == "uguu":
+        if path.stat().st_size > 100 * 1024 * 1024:  # 100 MB
+            path = downsampe_audio(path)
+        url = await upload_uguu(path)  # max 100 MB for Uguu
     else:
         url = await upload_alist(path)
 
@@ -101,8 +105,8 @@ async def query_ali_asr(task_id: str, api_key: str) -> str:
 def ali_realtime_asr(model: str, path: str | Path, api_key: str) -> str:
     # convert audio file
     sample_rate = 8000 if "8k" in model else 16000
-    path = convert_to_audio(path, ext="opus", codec="libopus", ac=1, ar=sample_rate)
     ext = "opus"
+    path = downsampe_audio(path, ext=ext, sample_rate=sample_rate, ac=1)
     recognition = Recognition(model=model, format=ext, sample_rate=sample_rate, callback=RecognitionCallback(), api_key=api_key)
     result = recognition.call(Path(path).as_posix())
     if result.status_code != 200:
src/asr/tecent_asr.py
@@ -10,7 +10,9 @@ import anyio
 from glom import Coalesce, flatten, glom
 from loguru import logger
 
-from config import ASR
+from asr.utils import downsampe_audio
+from config import ASR, FILE_SERVER
+from database import upload_alist, upload_uguu
 from networking import hx_req
 from utils import nowdt
 
@@ -130,15 +132,25 @@ async def tencent_flash_asr(path: str | Path, engine: str, voice_format: str) ->
         return generate_tencent_transcription(sentence_start_ms, words)
 
 
-async def tencent_create_asr(url: str, engine: str) -> dict:
+async def tencent_async_asr(path: str | Path, engine: str) -> str:
     """Create Tencent ASR Task.
 
     录音文件识别请求
     https://cloud.tencent.com/document/api/1093/37823
     """
+    path = Path(path).expanduser().resolve()
+    if ASR.TENCENT_FS_ENGINE.lower() == "local":
+        url = FILE_SERVER.removesuffix("/") + "/" + path.name
+    elif ASR.TENCENT_FS_ENGINE.lower() == "uguu":
+        if path.stat().st_size > 100 * 1024 * 1024:  # 100 MB
+            path = downsampe_audio(path)
+        url = await upload_uguu(path)  # max 100 MB for Uguu
+    else:
+        url = await upload_alist(path)
+
     payload = f'{{"EngineModelType":"{engine}","ChannelNum":1,"ResTextFormat":2,"SourceType":0,"Url":"{url}"}}'
     headers = generate_tencent_cloud_headers(action="CreateRecTask", payload=payload)
-    return await hx_req(
+    resp = await hx_req(
         "https://asr.tencentcloudapi.com",
         method="POST",
         headers=headers,
@@ -147,6 +159,9 @@ async def tencent_create_asr(url: str, engine: str) -> dict:
         proxy=ASR.TENCENT_PROXY,
         check_keys=["Response.Data.TaskId"],
     )
+    task_id = resp["Response"]["Data"]["TaskId"]
+    logger.success(f"ASR任务提交成功, TaskID: {task_id}")
+    return await tencent_query_asr(task_id)
 
 
 async def tencent_query_asr(task_id: int) -> str:
src/asr/utils.py
@@ -1,8 +1,10 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
 import random
+from pathlib import Path
 
-from config import ASR, FILE_SERVER
+from config import ASR
+from multimedia import convert_to_audio
 
 
 def get_asr_method(duration: float, file_size: int, force_engine: str = "") -> tuple[str, list[str]]:
@@ -16,7 +18,7 @@ def get_asr_method(duration: float, file_size: int, force_engine: str = "") -> t
 
     # respect force_engine
     if force_engine == "ali":
-        return get_ali_asr_method(file_size)
+        return get_ali_asr_method()
     if force_engine == "deepgram":
         return "deepgram", ["mp3", "aac", "flac", "m4a", "mp2", "mp4", "ogg", "opus", "ogg-opus", "pcm", "wav", "webm"]
     if force_engine == "tencent":
@@ -25,7 +27,7 @@ def get_asr_method(duration: float, file_size: int, force_engine: str = "") -> t
         return get_gemini_asr_method(duration)
 
     if asr_engine == "ali":
-        return get_ali_asr_method(file_size)
+        return get_ali_asr_method()
     if asr_engine == "deepgram":
         return "deepgram", ["mp3", "aac", "flac", "m4a", "mp2", "mp4", "ogg", "opus", "ogg-opus", "pcm", "wav", "webm"]
     if asr_engine == "tencent":
@@ -35,17 +37,11 @@ def get_asr_method(duration: float, file_size: int, force_engine: str = "") -> t
     return f"ASR Engine: {asr_engine} is not support for duration: {duration}, filesize: {file_size}", []
 
 
-def get_ali_asr_method(file_size: int) -> tuple[str, list[str]]:
+def get_ali_asr_method() -> tuple[str, list[str]]:
     if not all([ASR.ALI_MODEL, ASR.ALI_API_KEY]):
         return "请设置阿里云ASR相关环境变量", []
-
-    asr_method = ""
-    if FILE_SERVER and file_size < 2 * 1024 * 1024 * 1024:  # 2GB
-        asr_method = "ali"
-        supported_ext = ["aac", "amr", "avi", "flac", "flv", "m4a", "mkv", "mov", "mp3", "mp4", "mpeg", "ogg-opus", "ogg", "opus", "wav", "webm", "wma", "wmv"]
-    else:
-        return "请联系管理员配置`FILE_SERVER`变量", []
-    return asr_method, supported_ext
+    supported_ext = ["aac", "amr", "avi", "flac", "flv", "m4a", "mkv", "mov", "mp3", "mp4", "mpeg", "ogg-opus", "ogg", "opus", "wav", "webm", "wma", "wmv"]
+    return "ali", supported_ext
 
 
 def get_tencent_asr_method(duration: float, file_size: int) -> tuple[str, list[str]]:
@@ -59,11 +55,9 @@ def get_tencent_asr_method(duration: float, file_size: int) -> tuple[str, list[s
     elif 60 <= duration <= 300 and file_size < 100 * 1024 * 1024:
         asr_method = "tencent_flash_asr"  # 录音文件识别极速版
         supported_ext = ["aac", "amr", "m4a", "mp3", "oga", "ogg-opus", "ogg", "opus", "pcm", "silk", "speex", "wav"]
-    elif FILE_SERVER:
+    else:
         asr_method = "tencent_async_asr"  # 录音文件识别 (异步请求)
         supported_ext = ["3gp", "aac", "amr", "flac", "flv", "m4a", "mp3", "mp4", "oga", "ogg-opus", "ogg", "opus", "wav", "wma"]
-    elif not FILE_SERVER:
-        return "音频过长, 需使用音频URL格式调用ASR\n请联系管理员配置`FILE_SERVER`变量", []
     return asr_method, supported_ext
 
 
@@ -73,3 +67,10 @@ def get_gemini_asr_method(duration: float) -> tuple[str, list[str]]:
     if not ASR.GEMINI_API_KEY:
         return "请设置`ASR_GEMINI_API_KEY`环境变量", []
     return "gemini", ["aac", "aiff", "flac", "mp3", "oga", "ogg", "opus", "wav"]
+
+
+def downsampe_audio(path: str | Path, ext: str = "opus", codec: str = "libopus", sample_rate: int = 16000, **kwargs) -> Path:  # re-encode `path` through convert_to_audio; extra kwargs are forwarded to it unchanged
+    path = Path(path).expanduser().resolve()  # expand `~` and make the path absolute before checking it
+    if not path.is_file():
+        return path  # not an existing regular file: skip conversion and hand back the resolved path as-is
+    return convert_to_audio(path, ext=ext, codec=codec, ar=sample_rate, **kwargs)  # sample_rate is passed under the keyword `ar`
src/asr/voice_recognition.py
@@ -13,9 +13,9 @@ from pyrogram.types import Message
 from asr.ali_asr import ali_asr
 from asr.deepgram import deepgram_asr
 from asr.gemini_asr import gemini_stream_asr
-from asr.tecent_asr import tencent_create_asr, tencent_flash_asr, tencent_query_asr, tencent_single_asr
+from asr.tecent_asr import tencent_async_asr, tencent_flash_asr, tencent_single_asr
 from asr.utils import get_asr_method
-from config import CAPTION_LENGTH, FILE_SERVER, PREFIX, TEXT_LENGTH
+from config import CAPTION_LENGTH, PREFIX, TEXT_LENGTH
 from messages.parser import parse_msg
 from messages.progress import modify_progress
 from messages.sender import send2tg
@@ -210,10 +210,7 @@ async def asr_file(
         elif asr_method == "tencent_flash_asr":
             texts = await tencent_flash_asr(path, language, voice_format)
         elif asr_method == "tencent_async_asr":
-            resp = await tencent_create_asr(f"{FILE_SERVER}/{path.name}", language)
-            task_id = resp["Response"]["Data"]["TaskId"]
-            logger.success(f"ASR任务提交成功, TaskID: {task_id}")
-            texts = await tencent_query_asr(task_id)
+            texts = await tencent_async_asr(path, language)
         elif asr_method == "ali":
             texts = await ali_asr(path)
         elif asr_method == "deepgram":
@@ -224,9 +221,7 @@ async def asr_file(
         logger.success(f"{texts!r}")
     except Exception as e:
         error = f"Failed to recognize audio: {e}"
-        if "resp" in locals() and resp.get("hx_error"):
-            error += f"\n{resp['hx_error']}"
-        logger.error(f"Failed to recognize audio: {e}")
+        logger.error(error)
         res["error"] = error
     finally:
         path.unlink(missing_ok=True)
src/config.py
@@ -266,12 +266,13 @@ class ASR:
     TENCENT_PROXY = os.getenv("ASR_TENCENT_PROXY", None)  # Banned oversea IP, need a back to China proxy
     TENCENT_SECRET_ID = os.getenv("ASR_TENCENT_SECRET_ID", "")
     TENCENT_SECRET_KEY = os.getenv("ASR_TENCENT_SECRET_KEY", "")
+    TENCENT_FS_ENGINE = os.getenv("ASR_TENCENT_FS_ENGINE", "local")  # local, uguu or alist.
     # WARN: some models do not allow oversea VPS. Can upload to an alist server in China.
     ALI_MODEL = os.getenv("ASR_ALI_MODEL", "paraformer-realtime-v2,paraformer-realtime-v1")  # comma separated keys for load balance. e.g. "model1,model2,model3"
     ALI_API_KEY = os.getenv("ASR_ALI_API_KEY", "")  # comma separated keys for load balance. e.g. "key1,key2,key3"
     # If the bot is running on an oversea VPS, and Ali ASR model doesn't allow oversea fileserver.
     # Change ASR_ALI_FS_ENGINE to alist (configurations in DB class)
-    ALI_FS_ENGINE = os.getenv("ASR_ALI_FS_ENGINE", "local")  # local or alist.
+    ALI_FS_ENGINE = os.getenv("ASR_ALI_FS_ENGINE", "local")  # local, uguu or alist.
     DEEPGRAM_API = os.getenv("ASR_DEEPGRAM_API", "")  # comma separated keys for load balance. e.g. "key1,key2,key3"