Commit 6f0160a
Changed files (4)
src/asr/deepgram.py
@@ -0,0 +1,52 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+import random
+from pathlib import Path
+
+import anyio
+from glom import flatten, glom
+from loguru import logger
+
+from config import ASR
+from networking import hx_req
+from utils import zhcn
+
+
+async def deepgram_asr(path: str | Path) -> str:
+ """Deepgram ASR.
+
+ https://developers.deepgram.com/docs/pre-recorded-audio
+ """
+ api_keys = [x.strip() for x in ASR.DEEPGRAM_API.split(",") if x.strip()]
+ if not api_keys:
+ return "请配置DeepGram语音识别的API Key"
+ logger.debug(f"DeepGram ASR {path}")
+ headers = {"Authorization": f"Token {random.choice(api_keys)}"}
+ path = Path(path).expanduser().resolve()
+ url = "https://api.deepgram.com/v1/listen"
+ params = {"model": "nova-3-general", "detect_language": True, "punctuate": True, "smart_format": True}
+ async with await anyio.open_file(path, "rb") as f:
+ res = await hx_req(
+ url,
+ method="POST",
+ headers=headers,
+ post_content=await f.read(),
+ params=params,
+ timeout=600,
+ check_keys=["results.channels.0.alternatives.0.words"],
+ )
+ start_seconds = flatten(glom(res, "results.channels.*.alternatives.0.words.*.start"))
+ sentences = flatten(glom(res, "results.channels.*.alternatives.0.words.*.punctuated_word"))
+ res = ""
+ indexs = list(range(len(sentences)))
+ for idx, start_time, sentence in zip(indexs, start_seconds, sentences, strict=True):
+ if not sentence:
+ continue
+ if idx == 0 or res.endswith((".", "。", "?", "?")): # noqa: RUF001
+ start_seconds = float(start_time)
+ minutes = int(start_seconds // 60)
+ seconds = int(start_seconds % 60)
+ res += f"\n[{minutes:02d}:{seconds:02d}] {sentence}"
+ else:
+ res += sentence
+ return zhcn(res.strip())
src/asr/utils.py
@@ -17,6 +17,8 @@ def get_asr_method(duration: float, file_size: int, force_engine: str = "") -> t
# respect force_engine
if force_engine == "ali":
return get_ali_asr_method(file_size)
+ if force_engine == "deepgram":
+ return "deepgram", ["mp3", "aac", "flac", "m4a", "mp2", "mp4", "ogg", "opus", "ogg-opus", "pcm", "wav", "webm"]
if force_engine == "tencent":
return get_tencent_asr_method(duration, file_size)
if force_engine == "gemini":
@@ -24,6 +26,8 @@ def get_asr_method(duration: float, file_size: int, force_engine: str = "") -> t
if asr_engine == "ali":
return get_ali_asr_method(file_size)
+ if asr_engine == "deepgram":
+ return "deepgram", ["mp3", "aac", "flac", "m4a", "mp2", "mp4", "ogg", "opus", "ogg-opus", "pcm", "wav", "webm"]
if asr_engine == "tencent":
return get_tencent_asr_method(duration, file_size)
if asr_engine.lower() == "gemini":
src/asr/voice_recognition.py
@@ -11,6 +11,7 @@ from pyrogram.enums import ParseMode
from pyrogram.types import Message
from asr.ali_asr import ali_asr
+from asr.deepgram import deepgram_asr
from asr.gemini_asr import gemini_stream_asr
from asr.tecent_asr import tencent_create_asr, tencent_flash_asr, tencent_query_asr, tencent_single_asr
from asr.utils import get_asr_method
@@ -112,7 +113,7 @@ async def voice_to_text(
custom_code = custom_code.replace("fy", "zh_dialect")
if f"16k_{custom_code}" in ENGINE_MAP:
asr_language = f"16k_{custom_code}"
- elif custom_code in ["gemini", "tencent", "ali"]:
+ elif custom_code in ["gemini", "tencent", "ali", "deepgram"]:
force_engine = custom_code
msg = f"[ASR] 收到消息: {trigger_info['mtype']}, 开始识别..."
@@ -182,7 +183,7 @@ async def asr_file(
if duration == 0:
duration = info["duration"]
asr_method, supported_ext = get_asr_method(duration, file_size=Path(path).stat().st_size, force_engine=engine)
- if asr_method not in ["ali", "tencent_single_asr", "tencent_flash_asr", "tencent_async_asr", "gemini"]:
+ if asr_method not in ["ali", "deepgram", "tencent_single_asr", "tencent_flash_asr", "tencent_async_asr", "gemini"]:
return {"error": asr_method}
voice_format = path.suffix.lstrip(".")
@@ -190,8 +191,8 @@ async def asr_file(
if info["audio_codec"].split("/")[-1] in supported_ext and not info["video_codec"]:
voice_format = info["audio_codec"].split("/")[-1]
else:
- path = convert_to_audio(path, ext="aac", codec="aac")
- voice_format = "aac"
+ path = convert_to_audio(path, ext="opus", codec="libopus")
+ voice_format = "opus"
# match again based on converted file
asr_method, supported_ext = get_asr_method(duration, file_size=Path(path).stat().st_size, force_engine=engine)
@@ -215,6 +216,8 @@ async def asr_file(
texts = await tencent_query_asr(task_id)
elif asr_method == "ali":
texts = await ali_asr(path)
+ elif asr_method == "deepgram":
+ texts = await deepgram_asr(path)
elif asr_method == "gemini":
return await gemini_stream_asr(path=path, voice_format=voice_format, **kwargs)
res["texts"] = texts
src/config.py
@@ -270,6 +270,7 @@ class ASR:
# If the bot is running on an oversea VPS, and Ali ASR model doesn't allow oversea fileserver.
# Change ASR_ALI_FS_ENGINE to alist (configurations in DB class)
ALI_FS_ENGINE = os.getenv("ASR_ALI_FS_ENGINE", "local") # local or alist.
+ DEEPGRAM_API = os.getenv("ASR_DEEPGRAM_API", "") # comma separated keys for load balance. e.g. "key1,key2,key3"
class GEMINI: # Official Gemini