main
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3import asyncio
4import base64
5from collections.abc import Coroutine
6from decimal import Decimal
7from pathlib import Path
8from typing import Any
9
10from glom import glom
11from loguru import logger
12
13from asr.groq import merge_transcripts
14from asr.utils import audio_chunk_to_bytes, convert_single_channel, downsampe_audio, get_file_bytes, load_audio
15from config import ASR, PROXY
16from networking import hx_req
17from utils import seconds_to_time, strings_list, zhcn
18
19
async def cloudflare_asr(path: str | Path, duration: float, model: str | None = "") -> dict:
    """Transcribe an audio file with Cloudflare Workers AI (Whisper).

    https://developers.cloudflare.com/workers-ai/models/whisper-large-v3-turbo/

    Args:
        path: Audio file to transcribe.
        duration: Audio duration in seconds; decides single-shot vs. chunked mode.
        model: Cloudflare model id; empty/None falls back to ASR.CLOUDFLARE_MODEL downstream.

    Returns:
        {"texts": str, "error": str} (chunked mode may also include "raw_texts" and "segments")
    """
    path = Path(path).expanduser().resolve()
    if not path.is_file():
        return {"texts": "", "error": "File not found."}
    # Cloudflare expects WAV input; downsample to 16-bit PCM mono to keep the payload small.
    audio_path = path if path.suffix.lower() == ".wav" else await downsampe_audio(path, ext="wav", codec="pcm_s16le")
    audio_path = await convert_single_channel(audio_path, ext="wav", codec="pcm_s16le")
    # max allowed file size is 25MB, so long audio is split into chunks
    if duration < ASR.CLOUDFLARE_CHUNK_SECONDS:
        return await cloudflare_single_file(audio_path, model=model)
    return await cloudflare_file_chunks(audio_path, duration, model=model)
40
41
async def cloudflare_single_file(path_or_bytes: Path | bytes, model: str | None = "", *, offset_seconds: int = 0) -> dict:
    """Transcribe a single audio file/chunk with the Cloudflare Workers AI API.

    Tries each configured Cloudflare key (shuffled) and returns on the first success;
    a failed request falls through to the next key.

    Args:
        path_or_bytes: Audio file path or raw audio bytes.
        model: Cloudflare model id; empty/None falls back to ASR.CLOUDFLARE_MODEL.
        offset_seconds: Added to every segment timestamp (used by chunked transcription).

    Returns:
        {"texts": str, "raw_texts": str, "segments": list[dict]} on success,
        otherwise a dict containing "error".
    """
    if not ASR.CLOUDFLARE_KEYS:
        return {"error": "未配置Cloudflare相关API"}
    if not model:
        model = ASR.CLOUDFLARE_MODEL
    audio_bytes = await get_file_bytes(path_or_bytes)
    if not audio_bytes:
        return {"error": f"Audio is empty: {path_or_bytes}"}
    # The payload is key-independent; encode once instead of per attempt.
    audio_base64 = base64.b64encode(audio_bytes).decode("utf-8")
    resp: dict = {}

    for key in strings_list(ASR.CLOUDFLARE_KEYS, shuffle=True):
        cf_id, cf_token = key.split(":", 1)  # each key is "account_id:api_token"
        try:
            url = f"https://api.cloudflare.com/client/v4/accounts/{cf_id}/ai/run/{model}"
            headers = {"Authorization": f"Bearer {cf_token}"}
            payload = {"audio": audio_base64, "task": "transcribe", "vad_filter": False}
            resp = await hx_req(
                url,
                "POST",
                headers=headers,
                json_data=payload,
                proxy=PROXY.CLOUDFLARE,
                check_kv={"success": True},
                check_keys=["result"],
            )
            if resp.get("hx_error"):
                # Request failed with this key; record the error and try the next key.
                resp["error"] = resp.pop("hx_error")
                continue
            offset = Decimal(offset_seconds).quantize(Decimal(".01"))
            resp["segments"] = [
                {
                    "start": offset + Decimal(str(seg["start"])),
                    "end": offset + Decimal(str(seg["end"])),
                    "text": zhcn(seg.get("text", "")),
                }
                for seg in glom(resp, "result.segments", default=[])
            ]
            resp["raw_texts"] = " ".join(str(x["text"]) for x in resp["segments"])
            resp["texts"] = "\n".join(f"[{seconds_to_time(float(x['start']))}] {str(x['text']).lstrip()}" for x in resp["segments"])  # with timestamp
            return resp  # success: stop trying further keys
        except Exception as e:
            logger.error(e)
            resp["error"] = str(e)
    # All keys exhausted; resp holds the last error seen.
    return resp
91
92
async def cloudflare_file_chunks(
    path: Path,
    duration: float,
    model: str | None = "",
    chunk_seconds: float = 600,
    overlap_seconds: float = ASR.CLOUDFLARE_OVERLAP_SECONDS,
) -> dict:
    """Transcribe long audio in overlapping chunks and merge the results.

    Most of this code is copied from `gemini_file_chunks` in `asr/gemini.py`

    Args:
        path: Audio file path (converted to wav first; only wav is supported).
        duration: Audio duration in seconds (re-measured from the decoded audio).
        model: Cloudflare model id; empty/None falls back to ASR.CLOUDFLARE_MODEL downstream.
        chunk_seconds: Length of each chunk in seconds.
        overlap_seconds: Overlap between consecutive chunks in seconds.

    Returns:
        dict: {"texts": str, "raw_texts": str, "segments": list[dict]} or {"error": str}
    """
    # only support wav file
    wav_path = path if path.suffix.lower() == ".wav" else await downsampe_audio(path, ext="wav", codec="pcm_s16le")
    audio, duration, sr = load_audio(wav_path)  # duration re-measured from the decoded audio
    if sr == 0:
        return {"error": "Failed to load audio."}
    transcription = {}
    try:
        step_seconds = chunk_seconds - overlap_seconds  # hop between chunk start times
        total_chunks = int(duration // step_seconds) + 1
        tasks = []
        offset_list = []
        # Extract each (overlapping) chunk from the decoded audio.
        for i in range(total_chunks):
            start = int(i * step_seconds * sr)
            end = int(min(start + chunk_seconds * sr, duration * sr))
            logger.trace(f"Processing chunk {i + 1}/{total_chunks}, Time range: {start / sr:.0f}s - {end / sr:.0f}s")
            chunk = audio[start:end]
            if chunk.shape[0] == 0:  # empty chunk past the end of the audio
                continue
            tasks.append(audio_chunk_to_bytes(chunk, sr))
            offset_list.append(int(start / sr))
        bytes_list = await asyncio.gather(*tasks)  # convert chunks to bytes concurrently
        # NOTE: chunks are transcribed sequentially on purpose — firing all requests in
        # parallel caused OOM for large audio files.
        results = []
        for audio_bytes, offset in zip(bytes_list, offset_list, strict=True):
            res = await cloudflare_single_file(audio_bytes, model=model, offset_seconds=offset)
            if res.get("segments"):
                results.append(res)
        transcription = merge_transcripts(sorted(results, key=lambda x: x["segments"][0]["start"]))
    except Exception as e:
        logger.error(e)
        return {"error": str(e)}
    return transcription