Commit 0ccb789

benny-dou <60535774+benny-dou@users.noreply.github.com>
2025-08-10 16:44:47
perf(asr): optimize audio chunk processing
1 parent 531784d
src/asr/cloudflare.py
@@ -2,18 +2,16 @@
 # -*- coding: utf-8 -*-
 import asyncio
 import base64
-import io
 from collections.abc import Coroutine
 from decimal import Decimal
 from pathlib import Path
 from typing import Any
 
-import soundfile as sf
 from glom import glom
 from loguru import logger
 
 from asr.groq import merge_transcripts
-from asr.utils import convert_single_channel, downsampe_audio, get_file_bytes
+from asr.utils import audio_chunk_to_bytes, convert_single_channel, downsampe_audio, get_file_bytes, load_audio
 from config import ASR
 from networking import hx_req
 from utils import seconds_to_time, strings_list, zhcn
@@ -111,7 +109,7 @@ async def cloudflare_file_chunks(
     duration: float,
     model: str | None = "",
     prompt: str | None = "",
-    chunk_seconds: float = ASR.CLOUDFLARE_CHUNK_SECONDS,
+    chunk_seconds: float = ASR.CLOUDFLARE_CHUNK_SECONDS,  # keep env-configurable (default already lowered to 600 in config)
     overlap_seconds: float = ASR.CLOUDFLARE_OVERLAP_SECONDS,
 ) -> dict:
     """Transcribe audio in chunks with overlap.
@@ -127,11 +125,9 @@ async def cloudflare_file_chunks(
     """
     # only support opus file
     ogg_path = path if path.suffix in [".oga", ".ogg", ".opus"] else await downsampe_audio(path, ext="opus", codec="libopus")
-    with sf.SoundFile(ogg_path, "r") as f:
-        sr = f.samplerate
-        audio = f.read(dtype="float32")
-        logger.trace(f"音频时长: {duration:.2f}s, 采样率: {sr} Hz")
-
+    audio, duration, sr = load_audio(ogg_path)
+    if sr == 0:
+        return {"error": "Failed to load audio."}
     transcription = {}
     semaphore = asyncio.Semaphore(30)  # max concurrent requests
 
@@ -143,7 +139,8 @@ async def cloudflare_file_chunks(
         # Calculate # of chunks
         total_chunks = int(duration // (chunk_seconds - overlap_seconds)) + 1
         tasks = []
-        # Loop through each chunk, extract current chunk from audio, transcribe
+        offset_list = []
+        # Loop through each chunk, extract current chunk from audio
         for i in range(total_chunks):
             start = int(i * (chunk_seconds - overlap_seconds) * sr)
             end = int(min(start + chunk_seconds * sr, duration * sr))
@@ -151,12 +148,12 @@ async def cloudflare_file_chunks(
             chunk = audio[start:end]
             if chunk.shape[0] == 0:  # empty chunk
                 continue
-            # Write audio chunk to buffer
-            buffer = io.BytesIO()
-            await asyncio.to_thread(sf.write, buffer, chunk, sr, format="ogg", subtype="OPUS")
-            buffer.seek(0)  # move cursor to beginning
-            chunk_bytes = buffer.getvalue()  # get chunk bytes
-            task = cloudflare_single_file(chunk_bytes, model, prompt, offset_seconds=int(start / sr))
+            tasks.append(audio_chunk_to_bytes(chunk, sr))
+            offset_list.append(int(start / sr))
+        bytes_list = await asyncio.gather(*tasks)  # convert chunks to bytes
+        tasks = []
+        for audio_bytes, offset_seconds in zip(bytes_list, offset_list, strict=True):
+            task = cloudflare_single_file(audio_bytes, model, prompt, offset_seconds=offset_seconds)
             tasks.append(run_with_semaphore(task))
         results = await asyncio.gather(*tasks)
         results = [r for r in results if r.get("segments")]
src/asr/gemini.py
@@ -5,7 +5,6 @@ import contextlib
 import json
 from pathlib import Path
 
-import soundfile as sf
 from glom import glom
 from google import genai
 from google.genai.types import File, GenerateContentConfig, HttpOptions, ThinkingConfig, UploadFileConfig
@@ -14,7 +13,7 @@ from pydantic import BaseModel, Field
 from pyrogram.types import Message
 
 from asr.groq import merge_transcripts
-from asr.utils import GEMINI_AUDIO_EXT, audio_duration, convert_single_channel, downsampe_audio
+from asr.utils import GEMINI_AUDIO_EXT, audio_chunk_to_path, audio_duration, convert_single_channel, downsampe_audio, load_audio
 from config import ASR, DOWNLOAD_DIR, GEMINI
 from llm.hooks import hook_gemini_httpoptions
 from utils import guess_mime, rand_string, seconds_to_time, strings_list, zhcn
@@ -130,7 +129,7 @@ async def gemini_single_file(
 async def gemini_file_chunks(
     message: Message,
     path: str | Path,
-    chunk_seconds: float = ASR.GEMINI_CHUNK_SECONDS,
+    chunk_seconds: float = ASR.GEMINI_CHUNK_SECONDS,  # keep env-configurable (default already lowered to 600 in config)
     overlap_seconds: float = ASR.GEMINI_OVERLAP_SECONDS,
     model_id: str = "",
     prompt: str = "",
@@ -149,41 +148,42 @@ async def gemini_file_chunks(
     """
     path = Path(path).expanduser().resolve()
     ogg_path = path if path.suffix in [".oga", ".ogg", ".opus"] else await downsampe_audio(path, ext="opus", codec="libopus")
-    with sf.SoundFile(ogg_path, "r") as f:
-        sr = f.samplerate
-        audio = f.read(dtype="float32")
-        duration_seconds = len(audio) / sr
-        logger.trace(f"音频时长: {duration_seconds:.2f}s, 采样率: {sr} Hz")
-
+    audio, duration, sr = load_audio(ogg_path)
+    if sr == 0:
+        return {"error": "Failed to load audio."}
     transcription = {}
     try:
         # Calculate # of chunks
-        total_chunks = (duration_seconds // (chunk_seconds - overlap_seconds)) + 1
+        total_chunks = (duration // (chunk_seconds - overlap_seconds)) + 1
         total_chunks = int(total_chunks)
         chunk_paths = [Path(DOWNLOAD_DIR) / f"{rand_string()}.opus" for _ in range(total_chunks)]
         tasks = []
-        # Loop through each chunk, extract current chunk from audio, transcribe
+        offset_list = []
+        # Loop through each chunk, extract current chunk from audio
         for i in range(total_chunks):
             start = int(i * (chunk_seconds - overlap_seconds) * sr)
-            end = int(min(start + chunk_seconds * sr, duration_seconds * sr))
+            end = int(min(start + chunk_seconds * sr, duration * sr))
             logger.trace(f"Processing chunk {i + 1}/{total_chunks}, Time range: {start / sr:.0f}s - {end / sr:.0f}s")
             chunk = audio[start:end]
             if chunk.shape[0] == 0:  # empty chunk
                 continue
-            chunk_path = chunk_paths[i].as_posix()
-            await asyncio.to_thread(sf.write, chunk_path, chunk, sr, format="ogg", subtype="OPUS")
-            tasks.append(
-                gemini_single_file(
-                    message,
-                    chunk_path,
-                    model_id=model_id,
-                    prompt=prompt,
-                    start_seconds=int(start / sr),
-                    delete_local_file=False,
-                    delete_gemini_file=delete_gemini_file,
-                )
+            tasks.append(audio_chunk_to_path(chunk, sr, chunk_paths[i]))
+            offset_list.append(int(start / sr))
+        await asyncio.gather(*tasks)  # convert chunks to paths
+        tasks = [
+            gemini_single_file(
+                message,
+                audio_path,
+                model_id=model_id,
+                prompt=prompt,
+                start_seconds=offset,
+                delete_local_file=False,
+                delete_gemini_file=delete_gemini_file,
             )
+            for audio_path, offset in zip(chunk_paths, offset_list, strict=False)  # offset_list may be shorter: trailing empty chunks are skipped, so strict=True would raise
+        ]
         results = await asyncio.gather(*tasks)
+        results = [r for r in results if r.get("segments")]
         transcription = merge_transcripts(sorted(results, key=lambda x: x["segments"][0]["start"]))
         [p.unlink(missing_ok=True) for p in chunk_paths]
     except Exception as e:
src/asr/groq.py
@@ -6,11 +6,10 @@ import re
 from decimal import Decimal
 from pathlib import Path
 
-import soundfile as sf
 from glom import glom
 from loguru import logger
 
-from asr.utils import convert_single_channel, downsampe_audio, get_file_bytes
+from asr.utils import audio_chunk_to_bytes, convert_single_channel, downsampe_audio, get_file_bytes, load_audio
 from config import ASR
 from networking import hx_req
 from utils import seconds_to_time, strings_list, zhcn
@@ -270,48 +269,39 @@ async def groq_file_chunks(
         return {"texts": "", "error": "File not found."}
     if path.suffix.lower() not in [".opus", ".ogg", ".oga"]:
         path = await downsampe_audio(path, ext="opus", codec="libopus")
-    try:
-        with sf.SoundFile(path, "r") as f:
-            if f.channels != 1:
-                path = await downsampe_audio(path, ext="opus", codec="libopus")
-                return await groq_file_chunks(path, chunk_seconds, overlap_seconds, model, temperature, prompt, language)
-            sr = f.samplerate
-            audio = f.read(dtype="float32")
-            duration_seconds = len(audio) / sr
-            logger.trace(f"音频时长: {duration_seconds:.2f}s, 采样率: {sr} Hz")
-    except Exception as e:
-        msg = f"Failed to load audio: {e!s}"
-        raise RuntimeError(msg) from e
-
+    audio, duration, sr = load_audio(path)
+    if sr == 0:
+        return {"error": "Failed to load audio."}
     transcription = {}
     try:
         # Calculate # of chunks
-        total_chunks = int(duration_seconds // (chunk_seconds - overlap_seconds)) + 1
+        total_chunks = int(duration // (chunk_seconds - overlap_seconds)) + 1
         tasks = []
-        # Loop through each chunk, extract current chunk from audio, transcribe
+        offset_list = []
+        # Loop through each chunk, extract current chunk from audio
         for i in range(total_chunks):
             start = int(i * (chunk_seconds - overlap_seconds) * sr)
-            end = int(min(start + chunk_seconds * sr, duration_seconds * sr))
+            end = int(min(start + chunk_seconds * sr, duration * sr))
             logger.trace(f"Processing chunk {i + 1}/{total_chunks}, Time range: {start / sr:.0f}s - {end / sr:.0f}s")
             chunk = audio[start:end]
             if chunk.shape[0] == 0:  # empty chunk
                 continue
-            # Write audio chunk to buffer
-            buffer = io.BytesIO()
-            await asyncio.to_thread(sf.write, buffer, chunk, sr, format="ogg", subtype="OPUS")
-            buffer.seek(0)  # move cursor to beginning
-            chunk_bytes = buffer.getvalue()  # get chunk bytes
-            tasks.append(
-                groq_single_file(
-                    chunk_bytes,
-                    start_seconds=start / sr,
-                    model=model,
-                    temperature=temperature,
-                    prompt=prompt,
-                    language=language,
-                )
+            tasks.append(audio_chunk_to_bytes(chunk, sr))
+            offset_list.append(int(start / sr))
+        bytes_list = await asyncio.gather(*tasks)  # convert chunks to bytes
+        tasks = [
+            groq_single_file(
+                audio_bytes,
+                start_seconds=offset,
+                model=model,
+                temperature=temperature,
+                prompt=prompt,
+                language=language,
             )
+            for audio_bytes, offset in zip(bytes_list, offset_list, strict=True)
+        ]
         results = await asyncio.gather(*tasks)
+        results = [r for r in results if r.get("segments")]
         transcription = merge_transcripts(sorted(results, key=lambda x: x["segments"][0]["start"]))
     except Exception as e:
         logger.error(e)
src/asr/tecent.py
@@ -4,19 +4,17 @@ import asyncio
 import base64
 import hashlib
 import hmac
-import io
 from collections.abc import Coroutine
 from decimal import Decimal
 from pathlib import Path
 from typing import Any
 
 import anyio
-import soundfile as sf
 from glom import Coalesce, flatten, glom
 from loguru import logger
 
 from asr.groq import merge_transcripts
-from asr.utils import audio_duration, convert_single_channel, downsampe_audio, get_file_bytes, is_english_word
+from asr.utils import audio_chunk_to_bytes, audio_duration, convert_single_channel, downsampe_audio, get_file_bytes, is_english_word, load_audio
 from config import ASR, FILE_SERVER
 from database.alist import delete_alist, upload_alist
 from database.uguu import upload_uguu
@@ -197,10 +195,9 @@ async def tencent_file_chunks(
     """
     # only support opus file
     ogg_path = path if path.suffix in [".oga", ".ogg", ".opus"] else await downsampe_audio(path, ext="opus", codec="libopus")
-    with sf.SoundFile(ogg_path, "r") as f:
-        sr = f.samplerate
-        audio = f.read(dtype="float32")
-        logger.trace(f"音频时长: {duration:.2f}s, 采样率: {sr} Hz")
+    audio, _, sr = load_audio(ogg_path)
+    if sr == 0:
+        return {"error": "Failed to load audio."}
 
     transcription = {}
     semaphore = asyncio.Semaphore(30)  # max concurrent requests
@@ -213,7 +210,9 @@ async def tencent_file_chunks(
         # Calculate # of chunks
         total_chunks = int(duration // (chunk_seconds - overlap_seconds)) + 1
         tasks = []
-        # Loop through each chunk, extract current chunk from audio, transcribe
+        offset_list = []
+
+        # Loop through each chunk, extract current chunk from audio
         for i in range(total_chunks):
             start = int(i * (chunk_seconds - overlap_seconds) * sr)
             end = int(min(start + chunk_seconds * sr, duration * sr))
@@ -221,12 +220,12 @@ async def tencent_file_chunks(
             chunk = audio[start:end]
             if chunk.shape[0] == 0:  # empty chunk
                 continue
-            # Write audio chunk to buffer
-            buffer = io.BytesIO()
-            await asyncio.to_thread(sf.write, buffer, chunk, sr, format="ogg", subtype="OPUS")
-            buffer.seek(0)  # move cursor to beginning
-            chunk_bytes = buffer.getvalue()  # get chunk bytes
-            task = tencent_single_asr(chunk_bytes, language=language, offset_seconds=int(start / sr))
+            tasks.append(audio_chunk_to_bytes(chunk, sr))
+            offset_list.append(int(start / sr))
+        bytes_list = await asyncio.gather(*tasks)  # convert chunks to bytes
+        tasks = []
+        for audio_bytes, offset_seconds in zip(bytes_list, offset_list, strict=True):
+            task = tencent_single_asr(audio_bytes, language=language, offset_seconds=offset_seconds)
             tasks.append(run_with_semaphore(task))
         results = await asyncio.gather(*tasks)
         results = [r for r in results if r.get("segments")]
src/asr/utils.py
@@ -1,6 +1,8 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
+import asyncio
 import contextlib
+import io
 import json
 import random
 import re
@@ -10,6 +12,7 @@ import anyio
 import soundfile as sf
 from ffmpeg import FFmpeg
 from loguru import logger
+from numpy import ndarray
 from soundfile import LibsndfileError
 
 from config import ASR, GEMINI
@@ -127,3 +130,26 @@ async def get_file_bytes(path_or_bytes: str | Path | bytes) -> bytes:
         async with await anyio.open_file(path_or_bytes, "rb") as f:
             file_bytes = await f.read()
     return file_bytes
+
+
+def load_audio(path: Path | str) -> tuple[ndarray, float, int]:
+    with contextlib.suppress(Exception), sf.SoundFile(Path(path).as_posix(), "r") as f:
+        sr = f.samplerate
+        audio = f.read(dtype="float32")
+        duration = len(audio) / sr
+        logger.trace(f"音频时长: {duration:.2f}s, 采样率: {sr} Hz")
+        return audio, duration, sr
+    return ndarray([]), 0, 0
+
+
+async def audio_chunk_to_bytes(chunk: ndarray, samplerate: int, fmt: str = "ogg", subtype: str = "OPUS") -> bytes:
+    buffer = io.BytesIO()
+    await asyncio.to_thread(sf.write, buffer, chunk, samplerate, format=fmt, subtype=subtype)
+    buffer.seek(0)  # move cursor to beginning
+    return buffer.getvalue()
+
+
+async def audio_chunk_to_path(chunk: ndarray, samplerate: int, path: str | Path, fmt: str = "ogg", subtype: str = "OPUS"):
+    out_path = Path(path).expanduser().resolve()
+    out_path.parent.mkdir(exist_ok=True, parents=True)
+    await asyncio.to_thread(sf.write, out_path.as_posix(), chunk, samplerate, format=fmt, subtype=subtype)
src/config.py
@@ -281,16 +281,16 @@ class ASR:
     DEEPGRAM_API = os.getenv("ASR_DEEPGRAM_API", "")  # comma separated keys for load balance. e.g. "key1,key2,key3"
     CLOUDFLARE_MODEL = os.getenv("ASR_CLOUDFLARE_MODEL", "@cf/openai/whisper-large-v3-turbo")
     CLOUDFLARE_MAX_BYTES = int(os.getenv("ASR_CLOUDFLARE_MAX_BYTES", "26214400"))  # 25MB (max file bytes for single file)
-    CLOUDFLARE_CHUNK_SECONDS = float(os.getenv("ASR_CLOUDFLARE_CHUNK_SECONDS", "1800"))  # split long audio file into chunks
+    CLOUDFLARE_CHUNK_SECONDS = float(os.getenv("ASR_CLOUDFLARE_CHUNK_SECONDS", "600"))  # split long audio file into chunks
     CLOUDFLARE_OVERLAP_SECONDS = float(os.getenv("ASR_CLOUDFLARE_OVERLAP_SECONDS", "5"))  # overlap seconds between chunks
     CLOUDFLARE_KEYS = os.getenv("ASR_CLOUDFLARE_KEYS", "")  # comma separated keys for load balance. e.g. "AccountID:API_TOKEN, AccountID:API_TOKEN, ..."
     CLOUDFLARE_PROXY = os.getenv("ASR_CLOUDFLARE_PROXY", None)
 
-    GEMINI_CHUNK_SECONDS = float(os.getenv("ASR_GEMINI_CHUNK_SECONDS", "900"))  # split long audio file into chunks
+    GEMINI_CHUNK_SECONDS = float(os.getenv("ASR_GEMINI_CHUNK_SECONDS", "600"))  # split long audio file into chunks
     GEMINI_OVERLAP_SECONDS = float(os.getenv("ASR_GEMINI_OVERLAP_SECONDS", "5"))  # overlap seconds between chunks
     GROQ_PROXY = os.getenv("ASR_GROQ_PROXY", None)  # Ban CN & HK IP
     GROQ_MAX_BYTES = int(os.getenv("ASR_GROQ_MAX_BYTES", "26214400"))  # 25MB (max file bytes for single file)
-    GROQ_CHUNK_SECONDS = float(os.getenv("ASR_GROQ_CHUNK_SECONDS", "1800"))  # split long audio file into chunks
+    GROQ_CHUNK_SECONDS = float(os.getenv("ASR_GROQ_CHUNK_SECONDS", "600"))  # split long audio file into chunks
     GROQ_OVERLAP_SECONDS = float(os.getenv("ASR_GROQ_OVERLAP_SECONDS", "5"))  # overlap seconds between chunks
     GROQ_KEYS = os.getenv("ASR_GROQ_KEYS", "")  # comma separated keys for load balance.
     GROQ_MODELS = os.getenv("ASR_GROQ_MODELS", "whisper-large-v3")  # comma separated model names.