Commit d5b62f6
Changed files (5)
src/asr/cloudflare.py
@@ -0,0 +1,63 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+import base64
+from pathlib import Path
+
+import anyio
+from glom import glom
+from loguru import logger
+
+from config import ASR
+from networking import hx_req
+from utils import seconds_to_time, strings_list
+
+
async def cloudflare_asr(
    path: str | Path,
    model: str = "",
    prompt: str = "",
) -> dict:
    """Transcribe an audio file with a Cloudflare Workers AI model.

    https://developers.cloudflare.com/workers-ai/models/whisper-large-v3-turbo/

    Every ``AccountID:API_TOKEN`` entry of ``ASR.CLOUDFLARE_KEYS`` is tried in
    random order until one of them yields a non-empty transcription.

    Args:
        path (str | Path): Audio file to transcribe.
        model (str, optional): Model identifier placed in the run URL.
            Falls back to ``ASR.CLOUDFLARE_MODEL`` when empty.
        prompt (str, optional): Sent as ``initial_prompt`` when non-empty.

    Returns:
        {"texts": str} on success, {"error": str} otherwise.
    """
    path = Path(path)
    if not ASR.CLOUDFLARE_KEYS:
        return {"error": "未配置Cloudflare相关API"}
    if not model:
        model = ASR.CLOUDFLARE_MODEL

    # The request body does not depend on the credential, so read and encode
    # the audio once instead of once per key.
    try:
        async with await anyio.open_file(path, "rb") as f:
            audio_bytes = await f.read()
    except OSError as e:
        logger.error(e)
        return {"error": f"Cloudflare ASR failed: cannot read {path.name}"}
    payload = {
        "audio": base64.b64encode(audio_bytes).decode("utf-8"),
        "task": "transcribe",
        "vad_filter": True,
    }
    if prompt:
        payload["initial_prompt"] = prompt

    for key in strings_list(ASR.CLOUDFLARE_KEYS, shuffle=True):
        try:
            # Unpacking inside the try: a key without ":" raises ValueError and
            # must only skip this key, not abort the remaining ones.
            cf_id, cf_token = key.split(":", 1)
            url = f"https://api.cloudflare.com/client/v4/accounts/{cf_id}/ai/run/{model}"
            headers = {"Authorization": f"Bearer {cf_token}"}
            resp = await hx_req(url, "POST", headers=headers, json_data=payload, check_kv={"success": True}, check_keys=["result"])
            if texts := generate_transcription(glom(resp, "result.segments", default=[])):
                return {"texts": texts}
        except Exception as e:
            # Best effort: log and fall through to the next configured key.
            logger.error(e)
    # Exception details are only logged, not returned, to avoid echoing
    # request details back to the caller.
    return {"error": "Cloudflare ASR failed: no configured key returned a transcription"}
+
+
def generate_transcription(items: list[dict]) -> str:
    """Render transcription segments as ``[time] text`` lines.

    Args:
        items (list[dict]): Segments; each one is read for its ``text`` and
            ``start`` (seconds) entries.

    Returns:
        One line per non-empty segment, joined with newlines. Empty string
        when there is nothing to render.
    """
    lines = []
    for item in items or []:
        # Strip the text so padded segments do not produce a double space
        # after the timestamp, and so whitespace-only segments are skipped.
        sentence = str(item.get("text") or "").strip()
        if not sentence:
            continue
        # A segment without a start offset is stamped as 0 rather than
        # discarding its text.
        start = seconds_to_time(item.get("start") or 0)
        lines.append(f"[{start}] {sentence}")
    return "\n".join(lines)
src/asr/utils.py
@@ -32,6 +32,8 @@ def get_asr_method(duration: float, file_size: int, force_engine: str = "") -> t
return get_tencent_asr_method(duration, file_size)
if force_engine == "gemini":
return get_gemini_asr_method(duration)
+ if force_engine == "cloudflare":
+ return "cloudflare", ["mp3", "opus"]
if asr_engine == "ali":
return get_ali_asr_method()
@@ -41,6 +43,8 @@ def get_asr_method(duration: float, file_size: int, force_engine: str = "") -> t
return get_tencent_asr_method(duration, file_size)
if asr_engine.lower() == "gemini":
return get_gemini_asr_method(duration)
+ if asr_engine.lower() == "cloudflare":
+ return "cloudflare", ["mp3", "opus"]
return f"ASR Engine: {asr_engine} is not support for duration: {duration}, filesize: {file_size}", []
src/asr/voice_recognition.py
@@ -11,6 +11,7 @@ from pyrogram.enums import ParseMode
from pyrogram.types import Message
from asr.ali_asr import ali_asr
+from asr.cloudflare import cloudflare_asr
from asr.deepgram import deepgram_asr
from asr.gemini_asr import gemini_stream_asr
from asr.tecent_asr import tencent_async_asr, tencent_flash_asr, tencent_single_asr
@@ -121,7 +122,7 @@ async def voice_to_text(
custom_code = custom_code.replace("fy", "zh_dialect")
if f"16k_{custom_code}" in ENGINE_MAP:
asr_language = f"16k_{custom_code}"
- elif custom_code in ["gemini", "tencent", "ali", "deepgram"]:
+ elif custom_code in ["gemini", "tencent", "ali", "deepgram", "cloudflare"]:
force_engine = custom_code
msg = f"[ASR] 收到消息: {trigger_info['mtype']}, 开始识别..."
@@ -193,7 +194,7 @@ async def asr_file(
if duration == 0:
duration = info["duration"]
asr_method, supported_ext = get_asr_method(duration, file_size=Path(path).stat().st_size, force_engine=engine)
- if asr_method not in ["ali", "deepgram", "tencent_single_asr", "tencent_flash_asr", "tencent_async_asr", "gemini"]:
+ if asr_method not in ["ali", "deepgram", "cloudflare", "tencent_single_asr", "tencent_flash_asr", "tencent_async_asr", "gemini"]:
return {"error": asr_method}
voice_format = path.suffix.lstrip(".")
@@ -228,6 +229,8 @@ async def asr_file(
res = await deepgram_asr(path)
elif asr_method == "gemini":
res = await gemini_stream_asr(path=path, voice_format=voice_format, delete_gemini_file=delete_gemini_file, **kwargs)
+ elif asr_method == "cloudflare":
+ res = await cloudflare_asr(path=path, model=kwargs.get("cf_asr_model", ""), prompt=kwargs.get("cf_asr_prompt", ""))
else:
return {"error": "ASR method not supported"}
if res.get("texts"):
src/config.py
@@ -255,6 +255,8 @@ class ASR:
# Change ASR_ALI_FS_ENGINE to alist (configurations in DB class)
ALI_FS_ENGINE = os.getenv("ASR_ALI_FS_ENGINE", "local") # local, uguu or alist.
DEEPGRAM_API = os.getenv("ASR_DEEPGRAM_API", "") # comma separated keys for load balance. e.g. "key1,key2,key3"
+ CLOUDFLARE_MODEL = os.getenv("ASR_CLOUDFLARE_MODEL", "@cf/openai/whisper-large-v3-turbo")
+ CLOUDFLARE_KEYS = os.getenv("ASR_CLOUDFLARE_KEYS", "") # comma separated keys for load balance. e.g. "AccountID:API_TOKEN, AccountID:API_TOKEN, ..."
class PODCAST:
src/utils.py
@@ -218,8 +218,30 @@ def stringfy(d: dict) -> dict:
return d # Return non-dict, non-list values as is
def seconds_to_time(seconds: float) -> str:
    """Format a number of seconds as a zero-padded clock string.

    The value is rounded to the nearest whole second. Hours are only shown
    when non-zero and are not wrapped at 24. Negative values are rendered as
    the positive duration with a leading "-".

    100 -> "01:40"
    1000 -> "16:40"
    10000 -> "02:46:40"
    100000 -> "27:46:40"
    -1 -> "-00:01"
    """
    total = round(float(seconds))
    # divmod floors towards negative infinity, so split the magnitude and
    # re-attach the sign; otherwise -1 would render as "-1:59:59".
    sign = "-" if total < 0 else ""
    minutes, secs = divmod(abs(total), 60)
    hours, minutes = divmod(minutes, 60)
    if hours:
        return f"{sign}{hours:02d}:{minutes:02d}:{secs:02d}"
    return f"{sign}{minutes:02d}:{secs:02d}"
+
+
def readable_time(seconds: str | float) -> str:
- """Human readable time duration."""
+ """Human readable time duration.
+
+ 100 -> "1m40s"
+ 1000 -> "16m40s"
+ 10000 -> "2h46m40s"
+ 100000 -> "1d3h46m40s"
+ """
try:
seconds = float(seconds)
except ValueError:
@@ -294,11 +316,14 @@ def slim_cid(cid: int | str) -> str:
return str(cid).strip().removeprefix("-100")
-def strings_list(value: str | None = None, *, env_key: str = "", separator: str = ",") -> list[str]:
+def strings_list(value: str | None = None, *, env_key: str = "", separator: str = ",", shuffle: bool = False) -> list[str]:
"""Get list from environment variable."""
if value is None:
value = os.getenv(env_key, "")
- return [s.strip() for s in value.split(separator) if s.strip()]
+ results = [s.strip() for s in value.split(separator) if s.strip()]
+ if shuffle:
+ random.shuffle(results)
+ return results
def parse_time(timestr: str) -> dict[str, int]: