Commit d5b62f6

benny-dou <60535774+benny-dou@users.noreply.github.com>
2025-06-20 13:57:02
feat(asr): support `whisper` via Cloudflare Workers AI
1 parent a2d4a0b
src/asr/cloudflare.py
@@ -0,0 +1,63 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+import base64
+from pathlib import Path
+
+import anyio
+from glom import glom
+from loguru import logger
+
+from config import ASR
+from networking import hx_req
+from utils import seconds_to_time, strings_list
+
+
async def cloudflare_asr(
    path: str | Path,
    model: str = "",
    prompt: str = "",
) -> dict:
    """Transcribe an audio file with Whisper via Cloudflare Workers AI.

    https://developers.cloudflare.com/workers-ai/models/whisper-large-v3-turbo/

    Args:
        path: Audio file to transcribe.
        model: Workers AI model id. Defaults to ``ASR.CLOUDFLARE_MODEL``.
        prompt: Optional initial prompt forwarded to Whisper.

    Returns:
        ``{"texts": str}`` on success, ``{"error": str}`` on failure.
    """
    path = Path(path)
    if not ASR.CLOUDFLARE_KEYS:
        return {"error": "未配置Cloudflare相关API"}
    if not model:
        model = ASR.CLOUDFLARE_MODEL
    # Keys are shuffled for simple load balancing across accounts;
    # the first key that yields a non-empty transcription wins.
    for key in strings_list(ASR.CLOUDFLARE_KEYS, shuffle=True):
        try:
            # Each key is "AccountID:API_TOKEN". Unpacking inside the try
            # block means a malformed key (no colon) is logged and skipped
            # instead of aborting the whole rotation loop.
            cf_id, cf_token = key.split(":", 1)
            url = f"https://api.cloudflare.com/client/v4/accounts/{cf_id}/ai/run/{model}"
            headers = {"Authorization": f"Bearer {cf_token}"}
            async with await anyio.open_file(path, "rb") as f:
                audio_bytes = await f.read()
            # Workers AI expects the audio as a base64 string in the JSON body.
            audio_base64 = base64.b64encode(audio_bytes).decode("utf-8")
            payload = {"audio": audio_base64, "task": "transcribe", "vad_filter": True}
            if prompt:
                payload["initial_prompt"] = prompt
            resp = await hx_req(url, "POST", headers=headers, json_data=payload, check_kv={"success": True}, check_keys=["result"])
            if texts := generate_transcription(glom(resp, "result.segments", default=[])):
                return {"texts": texts}
        except Exception as e:
            logger.error(e)
    # Every key failed or produced an empty transcription; report an error
    # so callers checking res.get("error") get a message instead of {}.
    return {"error": "Cloudflare ASR failed for all configured keys"}
+
+
def generate_transcription(items: list[dict]) -> str:
    """Render Whisper segments as one "[MM:SS] text" line per segment.

    Segments whose ``text`` field is empty are skipped. Returns an empty
    string when no segment carries text.
    """
    lines = []
    for segment in items:
        text: str = segment["text"]
        if not text:
            continue
        stamp = seconds_to_time(segment["start"])
        lines.append(f"[{stamp}] {text}")
    return "\n".join(lines).strip()
src/asr/utils.py
@@ -32,6 +32,8 @@ def get_asr_method(duration: float, file_size: int, force_engine: str = "") -> t
         return get_tencent_asr_method(duration, file_size)
     if force_engine == "gemini":
         return get_gemini_asr_method(duration)
+    if force_engine == "cloudflare":
+        return "cloudflare", ["mp3", "opus"]
 
     if asr_engine == "ali":
         return get_ali_asr_method()
@@ -41,6 +43,8 @@ def get_asr_method(duration: float, file_size: int, force_engine: str = "") -> t
         return get_tencent_asr_method(duration, file_size)
     if asr_engine.lower() == "gemini":
         return get_gemini_asr_method(duration)
+    if asr_engine.lower() == "cloudflare":
+        return "cloudflare", ["mp3", "opus"]
     return f"ASR Engine: {asr_engine} is not support for duration: {duration}, filesize: {file_size}", []
 
 
src/asr/voice_recognition.py
@@ -11,6 +11,7 @@ from pyrogram.enums import ParseMode
 from pyrogram.types import Message
 
 from asr.ali_asr import ali_asr
+from asr.cloudflare import cloudflare_asr
 from asr.deepgram import deepgram_asr
 from asr.gemini_asr import gemini_stream_asr
 from asr.tecent_asr import tencent_async_asr, tencent_flash_asr, tencent_single_asr
@@ -121,7 +122,7 @@ async def voice_to_text(
         custom_code = custom_code.replace("fy", "zh_dialect")
         if f"16k_{custom_code}" in ENGINE_MAP:
             asr_language = f"16k_{custom_code}"
-        elif custom_code in ["gemini", "tencent", "ali", "deepgram"]:
+        elif custom_code in ["gemini", "tencent", "ali", "deepgram", "cloudflare"]:
             force_engine = custom_code
 
     msg = f"[ASR] 收到消息: {trigger_info['mtype']}, 开始识别..."
@@ -193,7 +194,7 @@ async def asr_file(
     if duration == 0:
         duration = info["duration"]
     asr_method, supported_ext = get_asr_method(duration, file_size=Path(path).stat().st_size, force_engine=engine)
-    if asr_method not in ["ali", "deepgram", "tencent_single_asr", "tencent_flash_asr", "tencent_async_asr", "gemini"]:
+    if asr_method not in ["ali", "deepgram", "cloudflare", "tencent_single_asr", "tencent_flash_asr", "tencent_async_asr", "gemini"]:
         return {"error": asr_method}
 
     voice_format = path.suffix.lstrip(".")
@@ -228,6 +229,8 @@ async def asr_file(
             res = await deepgram_asr(path)
         elif asr_method == "gemini":
             res = await gemini_stream_asr(path=path, voice_format=voice_format, delete_gemini_file=delete_gemini_file, **kwargs)
+        elif asr_method == "cloudflare":
+            res = await cloudflare_asr(path=path, model=kwargs.get("cf_asr_model", ""), prompt=kwargs.get("cf_asr_prompt", ""))
         else:
             return {"error": "ASR method not supported"}
         if res.get("texts"):
src/config.py
@@ -255,6 +255,8 @@ class ASR:
     # Change ASR_ALI_FS_ENGINE to alist (configurations in DB class)
     ALI_FS_ENGINE = os.getenv("ASR_ALI_FS_ENGINE", "local")  # local, uguu or alist.
     DEEPGRAM_API = os.getenv("ASR_DEEPGRAM_API", "")  # comma separated keys for load balance. e.g. "key1,key2,key3"
+    CLOUDFLARE_MODEL = os.getenv("ASR_CLOUDFLARE_MODEL", "@cf/openai/whisper-large-v3-turbo")
+    CLOUDFLARE_KEYS = os.getenv("ASR_CLOUDFLARE_KEYS", "")  # comma separated keys for load balance. e.g. "AccountID:API_TOKEN, AccountID:API_TOKEN, ..."
 
 
 class PODCAST:
src/utils.py
@@ -218,8 +218,30 @@ def stringfy(d: dict) -> dict:
     return d  # Return non-dict, non-list values as is
 
 
def seconds_to_time(seconds: float) -> str:
    """Format a duration in seconds as a zero-padded clock string.

    The hours field is omitted when the duration is under one hour.

    100 -> "01:40"
    1000 -> "16:40"
    10000 -> "02:46:40"
    100000 -> "27:46:40"
    """
    # round() first so fractional seconds map to the nearest whole second.
    total = round(float(seconds))
    hours, remainder = divmod(total, 3600)
    minutes, secs = divmod(remainder, 60)
    if hours:
        return f"{hours:02d}:{minutes:02d}:{secs:02d}"
    return f"{minutes:02d}:{secs:02d}"
+
+
 def readable_time(seconds: str | float) -> str:
-    """Human readable time duration."""
+    """Human readable time duration.
+
+    100 -> "1m40s"
+    1000 -> "16m40s"
+    10000 -> "2h46m40s"
+    100000 -> "1d3h46m40s"
+    """
     try:
         seconds = float(seconds)
     except ValueError:
@@ -294,11 +316,14 @@ def slim_cid(cid: int | str) -> str:
     return str(cid).strip().removeprefix("-100")
 
 
-def strings_list(value: str | None = None, *, env_key: str = "", separator: str = ",") -> list[str]:
+def strings_list(value: str | None = None, *, env_key: str = "", separator: str = ",", shuffle: bool = False) -> list[str]:
     """Get list from environment variable."""
     if value is None:
         value = os.getenv(env_key, "")
-    return [s.strip() for s in value.split(separator) if s.strip()]
+    results = [s.strip() for s in value.split(separator) if s.strip()]
+    if shuffle:
+        random.shuffle(results)
+    return results
 
 
 def parse_time(timestr: str) -> dict[str, int]: