Commit faf0b67

benny-dou <60535774+benny-dou@users.noreply.github.com>
2025-07-28 02:34:51
perf(asr): speed up splitting audio chunks for ASR
1 parent 723d854
Changed files (4)
src/asr/cloudflare.py
@@ -2,21 +2,21 @@
 # -*- coding: utf-8 -*-
 import asyncio
 import base64
+import io
 from collections.abc import Coroutine
 from decimal import Decimal
 from pathlib import Path
 from typing import Any
 
-import anyio
 import soundfile as sf
 from glom import glom
 from loguru import logger
 
 from asr.groq import merge_transcripts
-from asr.utils import convert_single_channel, downsampe_audio
-from config import ASR, DOWNLOAD_DIR
+from asr.utils import convert_single_channel, downsampe_audio, get_file_bytes
+from config import ASR
 from networking import hx_req
-from utils import rand_string, seconds_to_time, strings_list, zhcn
+from utils import seconds_to_time, strings_list, zhcn
 
 
 async def cloudflare_asr(
@@ -48,12 +48,11 @@ async def cloudflare_asr(
 
 
 async def cloudflare_single_file(
-    path: Path,
+    path_or_bytes: Path | bytes,
     model: str | None = "",
     prompt: str | None = "",
     *,
     offset_seconds: int = 0,
-    delete_local_file: bool = False,
 ) -> dict:
     """Transcribe a single audio chunk with Groq API.
 
@@ -65,14 +64,17 @@ async def cloudflare_single_file(
         return {"error": "未配置Cloudflare相关API"}
     if not model:
         model = ASR.CLOUDFLARE_MODEL
+    audio_bytes = await get_file_bytes(path_or_bytes)
+    if not audio_bytes:
+        return {"error": f"Audio is empty: {path_or_bytes}"}
+    resp = {}
+
     for key in strings_list(ASR.CLOUDFLARE_KEYS, shuffle=True):
         cf_id, cf_token = key.split(":", 1)
         try:
             url = f"https://api.cloudflare.com/client/v4/accounts/{cf_id}/ai/run/{model}"
             headers = {"Authorization": f"Bearer {cf_token}"}
-            async with await anyio.open_file(path, "rb") as f:
-                audio_bytes = await f.read()
-                audio_base64 = base64.b64encode(audio_bytes).decode("utf-8")
+            audio_base64 = base64.b64encode(audio_bytes).decode("utf-8")
             payload = {"audio": audio_base64, "task": "transcribe", "vad_filter": True}
             if prompt:
                 payload["initial_prompt"] = prompt
@@ -98,8 +100,6 @@ async def cloudflare_single_file(
             resp["texts"] = "\n".join(f"[{seconds_to_time(x['start'])}] {x['text'].lstrip()}" for x in resp["segments"])  # with timestamp
             if resp.get("hx_error"):
                 resp["error"] = resp.pop("hx_error")
-            if delete_local_file:
-                path.unlink(missing_ok=True)
         except Exception as e:
             logger.error(e)
         return resp
@@ -141,9 +141,7 @@ async def cloudflare_file_chunks(
 
     try:
         # Calculate # of chunks
-        total_chunks = (duration // (chunk_seconds - overlap_seconds)) + 1
-        total_chunks = int(total_chunks)
-        chunk_paths = [Path(DOWNLOAD_DIR) / f"{rand_string()}.opus" for _ in range(total_chunks)]
+        total_chunks = int(duration // (chunk_seconds - overlap_seconds)) + 1
         tasks = []
         # Loop through each chunk, extract current chunk from audio, transcribe
         for i in range(total_chunks):
@@ -153,14 +151,16 @@ async def cloudflare_file_chunks(
             chunk = audio[start:end]
             if chunk.shape[0] == 0:  # empty chunk
                 continue
-            chunk_path = chunk_paths[i]
-            await asyncio.to_thread(sf.write, chunk_path.as_posix(), chunk, sr, format="ogg", subtype="OPUS")
-            task = cloudflare_single_file(chunk_path, model, prompt, offset_seconds=int(start / sr), delete_local_file=False)
+            # Write audio chunk to buffer
+            buffer = io.BytesIO()
+            await asyncio.to_thread(sf.write, buffer, chunk, sr, format="ogg", subtype="OPUS")
+            buffer.seek(0)  # move cursor to beginning
+            chunk_bytes = buffer.getvalue()  # get chunk bytes
+            task = cloudflare_single_file(chunk_bytes, model, prompt, offset_seconds=int(start / sr))
             tasks.append(run_with_semaphore(task))
         results = await asyncio.gather(*tasks)
         results = [r for r in results if r.get("segments")]
         transcription = merge_transcripts(sorted(results, key=lambda x: x["segments"][0]["start"]))
-        [p.unlink(missing_ok=True) for p in chunk_paths]
     except Exception as e:
         logger.error(e)
         return {"error": str(e)}
src/asr/groq.py
@@ -3,16 +3,14 @@
 import asyncio
 import io
 import re
-import tempfile
 from decimal import Decimal
 from pathlib import Path
 
-import anyio
 import soundfile as sf
 from glom import glom
 from loguru import logger
 
-from asr.utils import convert_single_channel, downsampe_audio
+from asr.utils import convert_single_channel, downsampe_audio, get_file_bytes
 from config import ASR
 from networking import hx_req
 from utils import seconds_to_time, strings_list, zhcn
@@ -39,28 +37,20 @@ async def groq_asr(path: str | Path, model: str = "", prompt: str = "", temperat
 
 
 async def groq_single_file(
-    path: Path | str,
+    path_or_bytes: Path | bytes,
     model: str = "",
     temperature: float = 0,
     prompt: str = "",
     language: str = "",
     start_seconds: float = 0,
-    *,
-    delete_local_file: bool = False,
 ) -> dict:
     """Transcribe a single audio chunk with Groq API.
 
     Returns:
         {"texts": str, "raw_texts": str, "segments": list[dict], "error": str}
     """
-    path = Path(path).expanduser().resolve()
-    if not path.is_file():
-        return {"texts": "", "error": "File not found."}
     if not model:
         model = strings_list(ASR.GROQ_MODELS, shuffle=True)[0]
-    async with await anyio.open_file(path, "rb") as f:
-        content = await f.read()
-        content_bytes = io.BytesIO(content)
     data = {
         "model": model,
         "temperature": str(temperature),  # must be string
@@ -70,11 +60,12 @@ async def groq_single_file(
         data["prompt"] = prompt
     if language:
         data["language"] = language
+    audio_bytes = await get_file_bytes(path_or_bytes)
     resp = await hx_req(
         "https://api.groq.com/openai/v1/audio/transcriptions",
         method="POST",
         headers={"Authorization": f"Bearer {strings_list(ASR.GROQ_KEYS, shuffle=True)[0]}"},
-        files={"file": ("chunk.ogg", content_bytes, "audio/ogg")},
+        files={"file": ("chunk.ogg", io.BytesIO(audio_bytes), "audio/ogg")},
         data=data,
         timeout=600,
         proxy=ASR.GROQ_PROXY,
@@ -94,8 +85,6 @@ async def groq_single_file(
     resp["texts"] = "\n".join(f"[{seconds_to_time(x['start'])}] {x['text'].lstrip()}" for x in resp["segments"])  # with timestamp
     if resp.get("hx_error"):
         resp["error"] = resp.pop("hx_error")
-    if delete_local_file:
-        path.unlink(missing_ok=True)
     return resp
 
 
@@ -297,8 +286,7 @@ async def groq_file_chunks(
     transcription = {}
     try:
         # Calculate # of chunks
-        total_chunks = (duration_seconds // (chunk_seconds - overlap_seconds)) + 1
-        total_chunks = int(total_chunks)
+        total_chunks = int(duration_seconds // (chunk_seconds - overlap_seconds)) + 1
         tasks = []
         # Loop through each chunk, extract current chunk from audio, transcribe
         for i in range(total_chunks):
@@ -308,20 +296,21 @@ async def groq_file_chunks(
             chunk = audio[start:end]
             if chunk.shape[0] == 0:  # empty chunk
                 continue
-            with tempfile.NamedTemporaryFile(suffix=".ogg", delete=False) as chunk_file:
-                chunk_path = chunk_file.name
-                await asyncio.to_thread(sf.write, chunk_path, chunk, sr, format="ogg", subtype="OPUS")
-                tasks.append(
-                    groq_single_file(
-                        chunk_file.name,
-                        start_seconds=start / sr,
-                        model=model,
-                        temperature=temperature,
-                        prompt=prompt,
-                        language=language,
-                        delete_local_file=True,
-                    )
+            # Write audio chunk to buffer
+            buffer = io.BytesIO()
+            await asyncio.to_thread(sf.write, buffer, chunk, sr, format="ogg", subtype="OPUS")
+            buffer.seek(0)  # move cursor to beginning
+            chunk_bytes = buffer.getvalue()  # get chunk bytes
+            tasks.append(
+                groq_single_file(
+                    chunk_bytes,
+                    start_seconds=start / sr,
+                    model=model,
+                    temperature=temperature,
+                    prompt=prompt,
+                    language=language,
                 )
+            )
         results = await asyncio.gather(*tasks)
         transcription = merge_transcripts(sorted(results, key=lambda x: x["segments"][0]["start"]))
     except Exception as e:
src/asr/tecent.py
@@ -4,6 +4,7 @@ import asyncio
 import base64
 import hashlib
 import hmac
+import io
 from collections.abc import Coroutine
 from decimal import Decimal
 from pathlib import Path
@@ -15,12 +16,12 @@ from glom import Coalesce, flatten, glom
 from loguru import logger
 
 from asr.groq import merge_transcripts
-from asr.utils import audio_duration, convert_single_channel, downsampe_audio, is_english_word
-from config import ASR, DOWNLOAD_DIR, FILE_SERVER
+from asr.utils import audio_duration, convert_single_channel, downsampe_audio, get_file_bytes, is_english_word
+from config import ASR, FILE_SERVER
 from database.alist import delete_alist, upload_alist
 from database.uguu import upload_uguu
 from networking import hx_req
-from utils import nowdt, rand_string, seconds_to_time
+from utils import nowdt, seconds_to_time
 
 
 def sign(key, msg):
@@ -101,7 +102,7 @@ async def tencent_asr(path: str | Path, language: str, duration: float) -> dict:
     return await tencent_file_chunks(audio_path, language=language, duration=duration)
 
 
-async def tencent_single_asr(path: Path, language: str, *, offset_seconds: int = 0, delete_local_file: bool = False) -> dict:
+async def tencent_single_asr(path_or_bytes: Path | bytes, language: str, *, offset_seconds: int = 0) -> dict:
     """Tencent Single Sentence ASR.
 
     一句话识别 (每月免费额度: 5000次)
@@ -118,16 +119,19 @@ async def tencent_single_asr(path: Path, language: str, *, offset_seconds: int =
         }
     """
     final = {"texts": "", "raw_texts": "", "segments": []}
-    file_bytes = path.stat().st_size
-    # max 3 MB
-    audio_path = path if file_bytes < 3 * 1024 * 1024 else await downsampe_audio(path, ext="opus", codec="libopus")
-    voice_format = Path(audio_path).suffix.lower().lstrip(".")
-    if voice_format in ["ogg", "opus", "oga"]:  # tencnet only supports ogg-opus
+    if isinstance(path_or_bytes, Path):
+        # max 3 MB
+        file_size = path_or_bytes.stat().st_size
+        audio_path = path_or_bytes if file_size < 3 * 1024 * 1024 else await downsampe_audio(path_or_bytes, ext="opus", codec="libopus")
+        voice_format = Path(audio_path).suffix.lower().lstrip(".")
+        if voice_format in ["ogg", "opus", "oga"]:  # tencnet only supports ogg-opus
+            voice_format = "ogg-opus"
+        audio_bytes = await get_file_bytes(audio_path)
+    elif isinstance(path_or_bytes, bytes):
         voice_format = "ogg-opus"
-    async with await anyio.open_file(audio_path, "rb") as f:
-        content = await f.read()
-        data = base64.b64encode(content).decode("utf-8")
-    payload = f'{{"EngSerViceType":"{language}","SourceType":1,"WordInfo":2,"VoiceFormat":"{voice_format}","Data":"{data}"}}'
+        audio_bytes = path_or_bytes
+    audio_base64 = base64.b64encode(audio_bytes).decode("utf-8")
+    payload = f'{{"EngSerViceType":"{language}","SourceType":1,"WordInfo":2,"VoiceFormat":"{voice_format}","Data":"{audio_base64}"}}'
     headers = generate_tencent_cloud_headers(action="SentenceRecognition", payload=payload)
     res = await hx_req(
         "https://asr.tencentcloudapi.com",
@@ -170,9 +174,6 @@ async def tencent_single_asr(path: Path, language: str, *, offset_seconds: int =
         segments.append({"start": start, "end": end, "text": text})
     final["texts"] = "\n".join(f"[{seconds_to_time(x['start'])}] {x['text'].lstrip()}" for x in segments)  # with timestamp
     final["segments"] = segments
-    if delete_local_file:
-        path.unlink(missing_ok=True)
-        audio_path.unlink(missing_ok=True)
     return final
 
 
@@ -210,9 +211,7 @@ async def tencent_file_chunks(
 
     try:
         # Calculate # of chunks
-        total_chunks = (duration // (chunk_seconds - overlap_seconds)) + 1
-        total_chunks = int(total_chunks)
-        chunk_paths = [Path(DOWNLOAD_DIR) / f"{rand_string()}.opus" for _ in range(total_chunks)]
+        total_chunks = int(duration // (chunk_seconds - overlap_seconds)) + 1
         tasks = []
         # Loop through each chunk, extract current chunk from audio, transcribe
         for i in range(total_chunks):
@@ -222,14 +221,16 @@ async def tencent_file_chunks(
             chunk = audio[start:end]
             if chunk.shape[0] == 0:  # empty chunk
                 continue
-            chunk_path = chunk_paths[i]
-            await asyncio.to_thread(sf.write, chunk_path.as_posix(), chunk, sr, format="ogg", subtype="OPUS")
-            task = tencent_single_asr(chunk_path, language=language, offset_seconds=int(start / sr), delete_local_file=False)
+            # Write audio chunk to buffer
+            buffer = io.BytesIO()
+            await asyncio.to_thread(sf.write, buffer, chunk, sr, format="ogg", subtype="OPUS")
+            buffer.seek(0)  # move cursor to beginning
+            chunk_bytes = buffer.getvalue()  # get chunk bytes
+            task = tencent_single_asr(chunk_bytes, language=language, offset_seconds=int(start / sr))
             tasks.append(run_with_semaphore(task))
         results = await asyncio.gather(*tasks)
         results = [r for r in results if r.get("segments")]
         transcription = merge_transcripts(sorted(results, key=lambda x: x["segments"][0]["start"]))
-        [p.unlink(missing_ok=True) for p in chunk_paths]
     except Exception as e:
         logger.error(e)
         return {"error": str(e)}
src/asr/utils.py
@@ -6,8 +6,10 @@ import random
 import re
 from pathlib import Path
 
+import anyio
 import soundfile as sf
 from ffmpeg import FFmpeg
+from loguru import logger
 from soundfile import LibsndfileError
 
 from config import ASR, GEMINI
@@ -112,3 +114,16 @@ def audio_duration(path: str | Path) -> float:
         return max(map(float, durations))
 
     return 0.0
+
+
+async def get_file_bytes(path_or_bytes: str | Path | bytes) -> bytes:
+    """Return raw bytes for either an in-memory payload or a file on disk.
+
+    Args:
+        path_or_bytes: The bytes themselves, or a path (``str`` / ``Path``) of a file to read.
+
+    Returns:
+        The bytes passed in unchanged, or the full content of the file.
+        ``b""`` when the path is not an existing file or the argument is of any other type.
+    """
+    # In-memory payloads are already the desired result; hand them back untouched
+    # (returning b"" here would make every bytes caller see an empty audio).
+    if isinstance(path_or_bytes, bytes):
+        return path_or_bytes
+    if not isinstance(path_or_bytes, (str, Path)):
+        return b""
+    if not Path(path_or_bytes).is_file():
+        logger.error(f"{path_or_bytes} is not exist.")
+        return b""
+    async with await anyio.open_file(path_or_bytes, "rb") as f:
+        return await f.read()