Commit 39e439b

benny-dou <60535774+benny-dou@users.noreply.github.com>
2025-06-30 15:51:15
fix(groq): fix audio file type check
1 parent f2ae750
Changed files (2)
src/asr/groq.py
@@ -1,5 +1,6 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
+import asyncio
 import io
 import re
 import tempfile
@@ -11,7 +12,7 @@ import soundfile as sf
 from glom import glom
 from loguru import logger
 
-from asr.utils import downsampe_audio
+from asr.utils import COMMON_AUDIO_EXT, convert_single_channel, downsampe_audio
 from config import ASR
 from networking import hx_req
 from utils import seconds_to_time, strings_list
@@ -28,7 +29,8 @@ async def groq_asr(path: str | Path, model: str = "", prompt: str = "", temperat
     path = Path(path).expanduser().resolve()
     if not path.is_file():
         return {"texts": "", "error": "File not found."}
-    audio_path = path if path.suffix.lower() not in [".mp3", ".wav", ".ogg"] else await downsampe_audio(path, ext="ogg", codec="libopus")
+    audio_path = path if path.suffix.lower() in COMMON_AUDIO_EXT else await downsampe_audio(path, ext="ogg", codec="libopus")
+    audio_path = await convert_single_channel(audio_path)
     # max allowed file size is 25MB
     if audio_path.stat().st_size < ASR.GROQ_MAX_BYTES:
         return await groq_single_file(audio_path, model=model, prompt=prompt, temperature=temperature, language=language)
@@ -42,6 +44,8 @@ async def groq_single_file(
     prompt: str = "",
     language: str = "",
     start_seconds: float = 0,
+    *,
+    delete_local_file: bool = False,
 ) -> dict:
     """Transcribe a single audio chunk with Groq API.
 
@@ -89,6 +93,8 @@ async def groq_single_file(
     resp["texts"] = "\n".join(f"[{seconds_to_time(x['start'])}] {x['text'].lstrip()}" for x in resp["segments"])  # with timestamp
     if resp.get("hx_error"):
         resp["error"] = resp.pop("hx_error")
+    if delete_local_file:
+        path.unlink(missing_ok=True)
     return resp
 
 
@@ -292,36 +298,31 @@ async def groq_file_chunks(
         # Calculate # of chunks
         total_chunks = (duration_seconds // (chunk_seconds - overlap_seconds)) + 1
         total_chunks = int(total_chunks)
-        results = []
+        tasks = []
         # Loop through each chunk, extract current chunk from audio, transcribe
         for i in range(total_chunks):
             start = int(i * (chunk_seconds - overlap_seconds) * sr)
             end = int(min(start + chunk_seconds * sr, duration_seconds * sr))
-            logger.trace(f"\nProcessing chunk {i + 1}/{total_chunks}, Time range: {start / sr:.1f}s - {end / sr:.1f}s")
+            logger.trace(f"Processing chunk {i + 1}/{total_chunks}, Time range: {start / sr:.0f}s - {end / sr:.0f}s")
             chunk = audio[start:end]
             if chunk.shape[0] == 0:  # empty chunk
                 continue
-            with tempfile.NamedTemporaryFile(suffix=".ogg") as chunk_file:
+            with tempfile.NamedTemporaryFile(suffix=".ogg", delete=False) as chunk_file:
                 sf.write(chunk_file.name, chunk, sr, format="ogg", subtype="OPUS")
-                result = await groq_single_file(
-                    chunk_file.name,
-                    start_seconds=start / sr,
-                    model=model,
-                    temperature=temperature,
-                    prompt=prompt,
-                    language=language,
+                tasks.append(
+                    groq_single_file(
+                        chunk_file.name,
+                        start_seconds=start / sr,
+                        model=model,
+                        temperature=temperature,
+                        prompt=prompt,
+                        language=language,
+                        delete_local_file=True,
+                    )
                 )
-                results.append(result)
-        transcription = merge_transcripts(results)
+        results = await asyncio.gather(*tasks)
+        transcription = merge_transcripts(sorted(results, key=lambda x: x["segments"][0]["start"]))
     except Exception as e:
         logger.error(e)
         return {"error": str(e)}
     return transcription
-
-
-if __name__ == "__main__":
-    import asyncio
-    import shutil
-
-    shutil.copyfile("test.m4a", "testbak.m4a")
-    asyncio.run(groq_file_chunks(Path("testbak.m4a")))
src/asr/utils.py
@@ -4,6 +4,9 @@ import random
 import re
 from pathlib import Path
 
+import soundfile as sf
+from soundfile import LibsndfileError
+
 from config import ASR, GEMINI
 from multimedia import convert_to_audio
 
@@ -12,6 +15,7 @@ GEMINI_AUDIO_EXT = [".aac", ".aiff", ".flac", ".mp3", ".oga", ".ogg", ".opus", "
 DEEPGRAM_AUDIO_EXT = [".mp3", ".aac", ".flac", ".m4a", ".mp2", ".mp4", ".ogg", ".opus", ".ogg-opus", ".pcm", ".wav", ".webm"]
 TENCENT_AUDIO_EXT = [".aac", ".amr", ".m4a", ".mp3", ".oga", ".ogg-opus", ".ogg", ".opus", ".pcm", ".silk", ".speex", ".wav"]
 TENCENT_ASYNC_AUDIO_EXT = [".3gp", ".aac", ".amr", ".flac", ".flv", ".m4a", ".mp3", ".mp4", ".oga", ".ogg-opus", ".ogg", ".opus", ".wav", ".wma"]
+COMMON_AUDIO_EXT = [".mp3", ".opus", ".ogg", ".wav", ".flac", ".aac"]
 
 
 def get_asr_method(duration: float, file_size: int, force_engine: str = "") -> tuple[str, list[str]]:
@@ -33,7 +37,7 @@ def get_asr_method(duration: float, file_size: int, force_engine: str = "") -> t
     if force_engine == "gemini":
         return get_gemini_asr_method(duration)
     if force_engine in ["cloudflare", "groq"]:
-        return force_engine.lower(), ["mp3", "opus", "ogg", "wav", "flac", "aac"]
+        return force_engine.lower(), [x.lstrip(".") for x in COMMON_AUDIO_EXT]
 
     if asr_engine == "ali":
         return get_ali_asr_method()
@@ -44,7 +48,7 @@ def get_asr_method(duration: float, file_size: int, force_engine: str = "") -> t
     if asr_engine.lower() == "gemini":
         return get_gemini_asr_method(duration)
     if asr_engine.lower() in ["cloudflare", "groq"]:
-        return asr_engine.lower(), ["mp3", "opus", "ogg", "wav", "flac", "aac"]
+        return asr_engine.lower(), [x.lstrip(".") for x in COMMON_AUDIO_EXT]
     return f"ASR Engine: {asr_engine} is not support for duration: {duration}, filesize: {file_size}", []
 
 
@@ -89,3 +93,21 @@ async def downsampe_audio(path: str | Path, ext: str = "opus", codec: str = "lib
 
 def is_english_word(text: str) -> bool:
     return bool(re.match(r"^[a-zA-Z]+$", text))
+
+
+async def convert_single_channel(path: str | Path) -> Path:
+    """Return the resolved `path` if it has exactly one channel, else the result of converting it via `downsampe_audio`."""
+    path = Path(path).expanduser().resolve()
+    try:
+        with sf.SoundFile(path, "r") as f:  # read the channel count, then close the handle before any conversion
+            is_mono = f.channels == 1
+    except LibsndfileError:  # libsndfile could not open the file: treat it as needing conversion
+        is_mono = False
+    return path if is_mono else await downsampe_audio(path, ext="ogg", codec="libopus", channel=1)
+
+
+def audio_duration(path: str | Path) -> float:
+    """Return the duration in seconds of the audio file at `path` (any format `soundfile` can open)."""
+    with sf.SoundFile(path, "r") as f:
+        # the frame count is exposed by SoundFile itself, so the samples need not be decoded into memory
+        return f.frames / f.samplerate