main
  1#!/usr/bin/env python
  2# -*- coding: utf-8 -*-
  3import asyncio
  4import base64
  5from collections.abc import Coroutine
  6from decimal import Decimal
  7from pathlib import Path
  8from typing import Any
  9
 10from glom import glom
 11from loguru import logger
 12
 13from asr.groq import merge_transcripts
 14from asr.utils import audio_chunk_to_bytes, convert_single_channel, downsampe_audio, get_file_bytes, load_audio
 15from config import ASR, PROXY
 16from networking import hx_req
 17from utils import seconds_to_time, strings_list, zhcn
 18
 19
 20async def cloudflare_asr(path: str | Path, duration: float, model: str | None = "") -> dict:
 21    """Cloudflare ASR.
 22
 23    https://developers.cloudflare.com/workers-ai/models/whisper-large-v3-turbo/
 24
 25    Args:
 26        silent (bool, optional): If Ture, do not update the status, return all results in the end.
 27
 28    Returns:
 29        {"texts": str, "error": str}
 30    """
 31    path = Path(path).expanduser().resolve()
 32    if not path.is_file():
 33        return {"texts": "", "error": "File not found."}
 34    audio_path = path if path.suffix.lower() == ".wav" else await downsampe_audio(path, ext="wav", codec="pcm_s16le")
 35    audio_path = await convert_single_channel(audio_path, ext="wav", codec="pcm_s16le")
 36    # max allowed file size is 25MB
 37    if duration < ASR.CLOUDFLARE_CHUNK_SECONDS:
 38        return await cloudflare_single_file(audio_path, model=model)
 39    return await cloudflare_file_chunks(audio_path, duration, model=model)
 40
 41
 42async def cloudflare_single_file(path_or_bytes: Path | bytes, model: str | None = "", *, offset_seconds: int = 0) -> dict:
 43    """Transcribe a single audio chunk with Groq API.
 44
 45    Returns:
 46        {"texts": str, "raw_texts": str, "segments": list[dict], "error": str}
 47    """
 48    resp = {"texts": "", "raw_texts": "", "segments": []}
 49    if not ASR.CLOUDFLARE_KEYS:
 50        return {"error": "未配置Cloudflare相关API"}
 51    if not model:
 52        model = ASR.CLOUDFLARE_MODEL
 53    audio_bytes = await get_file_bytes(path_or_bytes)
 54    if not audio_bytes:
 55        return {"error": f"Audio is empty: {path_or_bytes}"}
 56    resp = {}
 57
 58    for key in strings_list(ASR.CLOUDFLARE_KEYS, shuffle=True):
 59        cf_id, cf_token = key.split(":", 1)
 60        try:
 61            url = f"https://api.cloudflare.com/client/v4/accounts/{cf_id}/ai/run/{model}"
 62            headers = {"Authorization": f"Bearer {cf_token}"}
 63            audio_base64 = base64.b64encode(audio_bytes).decode("utf-8")
 64            payload = {"audio": audio_base64, "task": "transcribe", "vad_filter": False}
 65            resp = await hx_req(
 66                url,
 67                "POST",
 68                headers=headers,
 69                json_data=payload,
 70                proxy=PROXY.CLOUDFLARE,
 71                check_kv={"success": True},
 72                check_keys=["result"],
 73            )
 74            offset = Decimal(offset_seconds).quantize(Decimal(".01"))
 75            resp["segments"] = [
 76                {
 77                    "start": offset + Decimal(str(seg["start"])),
 78                    "end": offset + Decimal(str(seg["end"])),
 79                    "text": zhcn(seg.get("text", "")),
 80                }
 81                for seg in glom(resp, "result.segments", default=[])
 82            ]
 83            resp["raw_texts"] = " ".join(str(x["text"]) for x in resp["segments"])
 84            resp["texts"] = "\n".join(f"[{seconds_to_time(float(x['start']))}] {str(x['text']).lstrip()}" for x in resp["segments"])  # with timestamp
 85            if resp.get("hx_error"):
 86                resp["error"] = resp.pop("hx_error")
 87        except Exception as e:
 88            logger.error(e)
 89        return resp
 90    return resp
 91
 92
 93async def cloudflare_file_chunks(
 94    path: Path,
 95    duration: float,
 96    model: str | None = "",
 97    chunk_seconds: float = 600,
 98    overlap_seconds: float = ASR.CLOUDFLARE_OVERLAP_SECONDS,
 99) -> dict:
100    """Transcribe audio in chunks with overlap.
101
102    Most of this code is copied from `gemini_file_chunks` in `asr/gemini.py`
103
104    Args:
105        chunk_seconds: Length of each chunk in seconds
106        overlap_seconds: Overlap between chunks in seconds
107
108    Returns:
109        dict: {"texts": str, "raw_texts": str, "segments": list[dict]}
110    """
111    # only support wav file
112    wav_path = path if path.suffix.lower() == ".wav" else await downsampe_audio(path, ext="wav", codec="pcm_s16le")
113    audio, duration, sr = load_audio(wav_path)
114    if sr == 0:
115        return {"error": "Failed to load audio."}
116    transcription = {}
117    semaphore = asyncio.Semaphore(30)  # max concurrent requests
118
119    async def run_with_semaphore(task: Coroutine[Any, Any, dict]) -> dict:
120        async with semaphore:
121            return await task
122
123    try:
124        # Calculate # of chunks
125        total_chunks = int(duration // (chunk_seconds - overlap_seconds)) + 1
126        tasks = []
127        offset_list = []
128        # Loop through each chunk, extract current chunk from audio
129        for i in range(total_chunks):
130            start = int(i * (chunk_seconds - overlap_seconds) * sr)
131            end = int(min(start + chunk_seconds * sr, duration * sr))
132            logger.trace(f"Processing chunk {i + 1}/{total_chunks}, Time range: {start / sr:.0f}s - {end / sr:.0f}s")
133            chunk = audio[start:end]
134            if chunk.shape[0] == 0:  # empty chunk
135                continue
136            tasks.append(audio_chunk_to_bytes(chunk, sr))
137            offset_list.append(int(start / sr))
138        bytes_list = await asyncio.gather(*tasks)  # convert chunks to bytes
139        # Process each chunk in parallel (DO NOT do this due to OOM for large audio files)
140        # tasks = []
141        # for audio_bytes, offset_seconds in zip(bytes_list, offset_list, strict=True):
142        #     task = cloudflare_single_file(audio_bytes, model, offset_seconds=offset_seconds)
143        #     tasks.append(run_with_semaphore(task))
144        # results = await asyncio.gather(*tasks)
145        # results = [r for r in results if r.get("segments")]
146        results = []
147        for audio_bytes, offset in zip(bytes_list, offset_list, strict=True):
148            res = await cloudflare_single_file(audio_bytes, model=model, offset_seconds=offset)
149            if res.get("segments"):
150                results.append(res)
151        transcription = merge_transcripts(sorted(results, key=lambda x: x["segments"][0]["start"]))
152    except Exception as e:
153        logger.error(e)
154        return {"error": str(e)}
155    return transcription