Commit 88542d4
Changed files (5)
src/asr/ali_asr.py
@@ -12,9 +12,9 @@ from glom import flatten, glom
from httpx import AsyncHTTPTransport
from loguru import logger
+from asr.utils import downsampe_audio
from config import ASR, DB, FILE_SERVER
-from database import delete_alist, upload_alist
-from multimedia import convert_to_audio
+from database import delete_alist, upload_alist, upload_uguu
from networking import hx_req
@@ -43,6 +43,10 @@ async def ali_asr(path: str | Path) -> str:
path = Path(path).expanduser().resolve()
if ASR.ALI_FS_ENGINE.lower() == "local":
url = FILE_SERVER.removesuffix("/") + "/" + path.name
+ elif ASR.ALI_FS_ENGINE.lower() == "uguu":
+ if path.stat().st_size > 100 * 1024 * 1024: # 100 MB
+ path = downsampe_audio(path)
+ url = await upload_uguu(path) # max 100 MB for Uguu
else:
url = await upload_alist(path)
@@ -101,8 +105,8 @@ async def query_ali_asr(task_id: str, api_key: str) -> str:
def ali_realtime_asr(model: str, path: str | Path, api_key: str) -> str:
# convert audio file
sample_rate = 8000 if "8k" in model else 16000
- path = convert_to_audio(path, ext="opus", codec="libopus", ac=1, ar=sample_rate)
ext = "opus"
+ path = downsampe_audio(path, ext=ext, sample_rate=sample_rate, ac=1)
recognition = Recognition(model=model, format=ext, sample_rate=sample_rate, callback=RecognitionCallback(), api_key=api_key)
result = recognition.call(Path(path).as_posix())
if result.status_code != 200:
src/asr/tecent_asr.py
@@ -10,7 +10,9 @@ import anyio
from glom import Coalesce, flatten, glom
from loguru import logger
-from config import ASR
+from asr.utils import downsampe_audio
+from config import ASR, FILE_SERVER
+from database import upload_alist, upload_uguu
from networking import hx_req
from utils import nowdt
@@ -130,15 +132,25 @@ async def tencent_flash_asr(path: str | Path, engine: str, voice_format: str) ->
return generate_tencent_transcription(sentence_start_ms, words)
-async def tencent_create_asr(url: str, engine: str) -> dict:
+async def tencent_async_asr(path: str | Path, engine: str) -> str:
"""Create Tencent ASR Task.
录音文件识别请求
https://cloud.tencent.com/document/api/1093/37823
"""
+ path = Path(path).expanduser().resolve()
+ if ASR.TENCENT_FS_ENGINE.lower() == "local":
+ url = FILE_SERVER.removesuffix("/") + "/" + path.name
+ elif ASR.TENCENT_FS_ENGINE.lower() == "uguu":
+ if path.stat().st_size > 100 * 1024 * 1024: # 100 MB
+ path = downsampe_audio(path)
+ url = await upload_uguu(path) # max 100 MB for Uguu
+ else:
+ url = await upload_alist(path)
+
payload = f'{{"EngineModelType":"{engine}","ChannelNum":1,"ResTextFormat":2,"SourceType":0,"Url":"{url}"}}'
headers = generate_tencent_cloud_headers(action="CreateRecTask", payload=payload)
- return await hx_req(
+ resp = await hx_req(
"https://asr.tencentcloudapi.com",
method="POST",
headers=headers,
@@ -147,6 +159,9 @@ async def tencent_create_asr(url: str, engine: str) -> dict:
proxy=ASR.TENCENT_PROXY,
check_keys=["Response.Data.TaskId"],
)
+ task_id = resp["Response"]["Data"]["TaskId"]
+ logger.success(f"ASR任务提交成功, TaskID: {task_id}")
+ return await tencent_query_asr(task_id)
async def tencent_query_asr(task_id: int) -> str:
src/asr/utils.py
@@ -1,8 +1,10 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import random
+from pathlib import Path
-from config import ASR, FILE_SERVER
+from config import ASR
+from multimedia import convert_to_audio
def get_asr_method(duration: float, file_size: int, force_engine: str = "") -> tuple[str, list[str]]:
@@ -16,7 +18,7 @@ def get_asr_method(duration: float, file_size: int, force_engine: str = "") -> t
# respect force_engine
if force_engine == "ali":
- return get_ali_asr_method(file_size)
+ return get_ali_asr_method()
if force_engine == "deepgram":
return "deepgram", ["mp3", "aac", "flac", "m4a", "mp2", "mp4", "ogg", "opus", "ogg-opus", "pcm", "wav", "webm"]
if force_engine == "tencent":
@@ -25,7 +27,7 @@ def get_asr_method(duration: float, file_size: int, force_engine: str = "") -> t
return get_gemini_asr_method(duration)
if asr_engine == "ali":
- return get_ali_asr_method(file_size)
+ return get_ali_asr_method()
if asr_engine == "deepgram":
return "deepgram", ["mp3", "aac", "flac", "m4a", "mp2", "mp4", "ogg", "opus", "ogg-opus", "pcm", "wav", "webm"]
if asr_engine == "tencent":
@@ -35,17 +37,11 @@ def get_asr_method(duration: float, file_size: int, force_engine: str = "") -> t
return f"ASR Engine: {asr_engine} is not support for duration: {duration}, filesize: {file_size}", []
-def get_ali_asr_method(file_size: int) -> tuple[str, list[str]]:
+def get_ali_asr_method() -> tuple[str, list[str]]:
if not all([ASR.ALI_MODEL, ASR.ALI_API_KEY]):
return "请设置阿里云ASR相关环境变量", []
-
- asr_method = ""
- if FILE_SERVER and file_size < 2 * 1024 * 1024 * 1024: # 2GB
- asr_method = "ali"
- supported_ext = ["aac", "amr", "avi", "flac", "flv", "m4a", "mkv", "mov", "mp3", "mp4", "mpeg", "ogg-opus", "ogg", "opus", "wav", "webm", "wma", "wmv"]
- else:
- return "请联系管理员配置`FILE_SERVER`变量", []
- return asr_method, supported_ext
+ supported_ext = ["aac", "amr", "avi", "flac", "flv", "m4a", "mkv", "mov", "mp3", "mp4", "mpeg", "ogg-opus", "ogg", "opus", "wav", "webm", "wma", "wmv"]
+ return "ali", supported_ext
def get_tencent_asr_method(duration: float, file_size: int) -> tuple[str, list[str]]:
@@ -59,11 +55,9 @@ def get_tencent_asr_method(duration: float, file_size: int) -> tuple[str, list[s
elif 60 <= duration <= 300 and file_size < 100 * 1024 * 1024:
asr_method = "tencent_flash_asr" # 录音文件识别极速版
supported_ext = ["aac", "amr", "m4a", "mp3", "oga", "ogg-opus", "ogg", "opus", "pcm", "silk", "speex", "wav"]
- elif FILE_SERVER:
+ else:
asr_method = "tencent_async_asr" # 录音文件识别 (异步请求)
supported_ext = ["3gp", "aac", "amr", "flac", "flv", "m4a", "mp3", "mp4", "oga", "ogg-opus", "ogg", "opus", "wav", "wma"]
- elif not FILE_SERVER:
- return "音频过长, 需使用音频URL格式调用ASR\n请联系管理员配置`FILE_SERVER`变量", []
return asr_method, supported_ext
@@ -73,3 +67,10 @@ def get_gemini_asr_method(duration: float) -> tuple[str, list[str]]:
if not ASR.GEMINI_API_KEY:
return "请设置`ASR_GEMINI_API_KEY`环境变量", []
return "gemini", ["aac", "aiff", "flac", "mp3", "oga", "ogg", "opus", "wav"]
+
+
+def downsampe_audio(path: str | Path, ext: str = "opus", codec: str = "libopus", sample_rate: int = 16000, **kwargs) -> Path:
+ path = Path(path).expanduser().resolve()
+ if not path.is_file():
+ return path
+ return convert_to_audio(path, ext=ext, codec=codec, ar=sample_rate, **kwargs)
src/asr/voice_recognition.py
@@ -13,9 +13,9 @@ from pyrogram.types import Message
from asr.ali_asr import ali_asr
from asr.deepgram import deepgram_asr
from asr.gemini_asr import gemini_stream_asr
-from asr.tecent_asr import tencent_create_asr, tencent_flash_asr, tencent_query_asr, tencent_single_asr
+from asr.tecent_asr import tencent_async_asr, tencent_flash_asr, tencent_single_asr
from asr.utils import get_asr_method
-from config import CAPTION_LENGTH, FILE_SERVER, PREFIX, TEXT_LENGTH
+from config import CAPTION_LENGTH, PREFIX, TEXT_LENGTH
from messages.parser import parse_msg
from messages.progress import modify_progress
from messages.sender import send2tg
@@ -210,10 +210,7 @@ async def asr_file(
elif asr_method == "tencent_flash_asr":
texts = await tencent_flash_asr(path, language, voice_format)
elif asr_method == "tencent_async_asr":
- resp = await tencent_create_asr(f"{FILE_SERVER}/{path.name}", language)
- task_id = resp["Response"]["Data"]["TaskId"]
- logger.success(f"ASR任务提交成功, TaskID: {task_id}")
- texts = await tencent_query_asr(task_id)
+ texts = await tencent_async_asr(path, language)
elif asr_method == "ali":
texts = await ali_asr(path)
elif asr_method == "deepgram":
@@ -224,9 +221,7 @@ async def asr_file(
logger.success(f"{texts!r}")
except Exception as e:
error = f"Failed to recognize audio: {e}"
- if "resp" in locals() and resp.get("hx_error"):
- error += f"\n{resp['hx_error']}"
- logger.error(f"Failed to recognize audio: {e}")
+ logger.error(error)
res["error"] = error
finally:
path.unlink(missing_ok=True)
src/config.py
@@ -266,12 +266,13 @@ class ASR:
TENCENT_PROXY = os.getenv("ASR_TENCENT_PROXY", None) # Banned oversea IP, need a back to China proxy
TENCENT_SECRET_ID = os.getenv("ASR_TENCENT_SECRET_ID", "")
TENCENT_SECRET_KEY = os.getenv("ASR_TENCENT_SECRET_KEY", "")
+ TENCENT_FS_ENGINE = os.getenv("ASR_TENCENT_FS_ENGINE", "local") # local, uguu or alist.
# WARN: some models do not allow overseas VPS hosts. In that case, files can be uploaded to an alist server in China.
ALI_MODEL = os.getenv("ASR_ALI_MODEL", "paraformer-realtime-v2,paraformer-realtime-v1") # comma separated keys for load balance. e.g. "model1,model2,model3"
ALI_API_KEY = os.getenv("ASR_ALI_API_KEY", "") # comma separated keys for load balance. e.g. "key1,key2,key3"
# If the bot is running on an overseas VPS and the Ali ASR model does not allow an overseas file server,
# change ASR_ALI_FS_ENGINE to alist (see the alist configuration in the DB class).
- ALI_FS_ENGINE = os.getenv("ASR_ALI_FS_ENGINE", "local") # local or alist.
+ ALI_FS_ENGINE = os.getenv("ASR_ALI_FS_ENGINE", "local") # local, uguu or alist.
DEEPGRAM_API = os.getenv("ASR_DEEPGRAM_API", "") # comma separated keys for load balance. e.g. "key1,key2,key3"