Commit 6f0160a

benny-dou <60535774+benny-dou@users.noreply.github.com>
2025-05-09 16:13:26
feat(asr): add DeepGram ASR
1 parent 356905a
src/asr/deepgram.py
@@ -0,0 +1,52 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+import random
+from pathlib import Path
+
+import anyio
+from glom import flatten, glom
+from loguru import logger
+
+from config import ASR
+from networking import hx_req
+from utils import zhcn
+
+
+async def deepgram_asr(path: str | Path) -> str:
+    """Deepgram ASR.
+
+    https://developers.deepgram.com/docs/pre-recorded-audio
+    """
+    api_keys = [x.strip() for x in ASR.DEEPGRAM_API.split(",") if x.strip()]
+    if not api_keys:
+        return "请配置DeepGram语音识别的API Key"
+    logger.debug(f"DeepGram ASR {path}")
+    headers = {"Authorization": f"Token {random.choice(api_keys)}"}
+    path = Path(path).expanduser().resolve()
+    url = "https://api.deepgram.com/v1/listen"
+    params = {"model": "nova-3-general", "detect_language": True, "punctuate": True, "smart_format": True}
+    async with await anyio.open_file(path, "rb") as f:
+        res = await hx_req(
+            url,
+            method="POST",
+            headers=headers,
+            post_content=await f.read(),
+            params=params,
+            timeout=600,
+            check_keys=["results.channels.0.alternatives.0.words"],
+        )
+    start_seconds = flatten(glom(res, "results.channels.*.alternatives.0.words.*.start"))
+    sentences = flatten(glom(res, "results.channels.*.alternatives.0.words.*.punctuated_word"))
+    res = ""
+    indexs = list(range(len(sentences)))
+    for idx, start_time, sentence in zip(indexs, start_seconds, sentences, strict=True):
+        if not sentence:
+            continue
+        if idx == 0 or res.endswith((".", "。", "?", "?")):  # noqa: RUF001
+            start_seconds = float(start_time)
+            minutes = int(start_seconds // 60)
+            seconds = int(start_seconds % 60)
+            res += f"\n[{minutes:02d}:{seconds:02d}] {sentence}"
+        else:
+            res += sentence
+    return zhcn(res.strip())
src/asr/utils.py
@@ -17,6 +17,8 @@ def get_asr_method(duration: float, file_size: int, force_engine: str = "") -> t
     # respect force_engine
     if force_engine == "ali":
         return get_ali_asr_method(file_size)
+    if force_engine == "deepgram":
+        return "deepgram", ["mp3", "aac", "flac", "m4a", "mp2", "mp4", "ogg", "opus", "ogg-opus", "pcm", "wav", "webm"]
     if force_engine == "tencent":
         return get_tencent_asr_method(duration, file_size)
     if force_engine == "gemini":
@@ -24,6 +26,8 @@ def get_asr_method(duration: float, file_size: int, force_engine: str = "") -> t
 
     if asr_engine == "ali":
         return get_ali_asr_method(file_size)
+    if asr_engine == "deepgram":
+        return "deepgram", ["mp3", "aac", "flac", "m4a", "mp2", "mp4", "ogg", "opus", "ogg-opus", "pcm", "wav", "webm"]
     if asr_engine == "tencent":
         return get_tencent_asr_method(duration, file_size)
     if asr_engine.lower() == "gemini":
src/asr/voice_recognition.py
@@ -11,6 +11,7 @@ from pyrogram.enums import ParseMode
 from pyrogram.types import Message
 
 from asr.ali_asr import ali_asr
+from asr.deepgram import deepgram_asr
 from asr.gemini_asr import gemini_stream_asr
 from asr.tecent_asr import tencent_create_asr, tencent_flash_asr, tencent_query_asr, tencent_single_asr
 from asr.utils import get_asr_method
@@ -112,7 +113,7 @@ async def voice_to_text(
         custom_code = custom_code.replace("fy", "zh_dialect")
         if f"16k_{custom_code}" in ENGINE_MAP:
             asr_language = f"16k_{custom_code}"
-        elif custom_code in ["gemini", "tencent", "ali"]:
+        elif custom_code in ["gemini", "tencent", "ali", "deepgram"]:
             force_engine = custom_code
 
     msg = f"[ASR] 收到消息: {trigger_info['mtype']}, 开始识别..."
@@ -182,7 +183,7 @@ async def asr_file(
     if duration == 0:
         duration = info["duration"]
     asr_method, supported_ext = get_asr_method(duration, file_size=Path(path).stat().st_size, force_engine=engine)
-    if asr_method not in ["ali", "tencent_single_asr", "tencent_flash_asr", "tencent_async_asr", "gemini"]:
+    if asr_method not in ["ali", "deepgram", "tencent_single_asr", "tencent_flash_asr", "tencent_async_asr", "gemini"]:
         return {"error": asr_method}
 
     voice_format = path.suffix.lstrip(".")
@@ -190,8 +191,8 @@ async def asr_file(
         if info["audio_codec"].split("/")[-1] in supported_ext and not info["video_codec"]:
             voice_format = info["audio_codec"].split("/")[-1]
         else:
-            path = convert_to_audio(path, ext="aac", codec="aac")
-            voice_format = "aac"
+            path = convert_to_audio(path, ext="opus", codec="libopus")
+            voice_format = "opus"
     # match again based on converted file
     asr_method, supported_ext = get_asr_method(duration, file_size=Path(path).stat().st_size, force_engine=engine)
 
@@ -215,6 +216,8 @@ async def asr_file(
             texts = await tencent_query_asr(task_id)
         elif asr_method == "ali":
             texts = await ali_asr(path)
+        elif asr_method == "deepgram":
+            texts = await deepgram_asr(path)
         elif asr_method == "gemini":
             return await gemini_stream_asr(path=path, voice_format=voice_format, **kwargs)
         res["texts"] = texts
src/config.py
@@ -270,6 +270,7 @@ class ASR:
     # If the bot is running on an oversea VPS, and Ali ASR model doesn't allow oversea fileserver.
     # Change ASR_ALI_FS_ENGINE to alist (configurations in DB class)
     ALI_FS_ENGINE = os.getenv("ASR_ALI_FS_ENGINE", "local")  # local or alist.
+    DEEPGRAM_API = os.getenv("ASR_DEEPGRAM_API", "")  # comma separated keys for load balance. e.g. "key1,key2,key3"
 
 
 class GEMINI:  # Official Gemini