Commit 39e439b
src/asr/groq.py
@@ -1,5 +1,6 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
+import asyncio
import io
import re
import tempfile
@@ -11,7 +12,7 @@ import soundfile as sf
from glom import glom
from loguru import logger
-from asr.utils import downsampe_audio
+from asr.utils import COMMON_AUDIO_EXT, convert_single_channel, downsampe_audio
from config import ASR
from networking import hx_req
from utils import seconds_to_time, strings_list
@@ -28,7 +29,8 @@ async def groq_asr(path: str | Path, model: str = "", prompt: str = "", temperat
path = Path(path).expanduser().resolve()
if not path.is_file():
return {"texts": "", "error": "File not found."}
- audio_path = path if path.suffix.lower() not in [".mp3", ".wav", ".ogg"] else await downsampe_audio(path, ext="ogg", codec="libopus")
+ audio_path = path if path.suffix.lower() in COMMON_AUDIO_EXT else await downsampe_audio(path, ext="ogg", codec="libopus")
+ audio_path = await convert_single_channel(audio_path)
# max allowed file size is 25MB
if audio_path.stat().st_size < ASR.GROQ_MAX_BYTES:
return await groq_single_file(audio_path, model=model, prompt=prompt, temperature=temperature, language=language)
@@ -42,6 +44,8 @@ async def groq_single_file(
prompt: str = "",
language: str = "",
start_seconds: float = 0,
+ *,
+ delete_local_file: bool = False,
) -> dict:
"""Transcribe a single audio chunk with Groq API.
@@ -89,6 +93,8 @@ async def groq_single_file(
resp["texts"] = "\n".join(f"[{seconds_to_time(x['start'])}] {x['text'].lstrip()}" for x in resp["segments"]) # with timestamp
if resp.get("hx_error"):
resp["error"] = resp.pop("hx_error")
+ if delete_local_file:
+ path.unlink(missing_ok=True)
return resp
@@ -292,36 +298,31 @@ async def groq_file_chunks(
# Calculate # of chunks
total_chunks = (duration_seconds // (chunk_seconds - overlap_seconds)) + 1
total_chunks = int(total_chunks)
- results = []
+ tasks = []
# Loop through each chunk, extract current chunk from audio, transcribe
for i in range(total_chunks):
start = int(i * (chunk_seconds - overlap_seconds) * sr)
end = int(min(start + chunk_seconds * sr, duration_seconds * sr))
- logger.trace(f"\nProcessing chunk {i + 1}/{total_chunks}, Time range: {start / sr:.1f}s - {end / sr:.1f}s")
+ logger.trace(f"Processing chunk {i + 1}/{total_chunks}, Time range: {start / sr:.0f}s - {end / sr:.0f}s")
chunk = audio[start:end]
if chunk.shape[0] == 0: # empty chunk
continue
- with tempfile.NamedTemporaryFile(suffix=".ogg") as chunk_file:
+ with tempfile.NamedTemporaryFile(suffix=".ogg", delete=False) as chunk_file:
sf.write(chunk_file.name, chunk, sr, format="ogg", subtype="OPUS")
- result = await groq_single_file(
- chunk_file.name,
- start_seconds=start / sr,
- model=model,
- temperature=temperature,
- prompt=prompt,
- language=language,
+ tasks.append(
+ groq_single_file(
+ chunk_file.name,
+ start_seconds=start / sr,
+ model=model,
+ temperature=temperature,
+ prompt=prompt,
+ language=language,
+ delete_local_file=True,
+ )
)
- results.append(result)
- transcription = merge_transcripts(results)
+ results = await asyncio.gather(*tasks)
+ transcription = merge_transcripts(sorted(results, key=lambda x: x["segments"][0]["start"]))
except Exception as e:
logger.error(e)
return {"error": str(e)}
return transcription
-
-
-if __name__ == "__main__":
- import asyncio
- import shutil
-
- shutil.copyfile("test.m4a", "testbak.m4a")
- asyncio.run(groq_file_chunks(Path("testbak.m4a")))
src/asr/utils.py
@@ -4,6 +4,9 @@ import random
import re
from pathlib import Path
+import soundfile as sf
+from soundfile import LibsndfileError
+
from config import ASR, GEMINI
from multimedia import convert_to_audio
@@ -12,6 +15,7 @@ GEMINI_AUDIO_EXT = [".aac", ".aiff", ".flac", ".mp3", ".oga", ".ogg", ".opus", "
DEEPGRAM_AUDIO_EXT = [".mp3", ".aac", ".flac", ".m4a", ".mp2", ".mp4", ".ogg", ".opus", ".ogg-opus", ".pcm", ".wav", ".webm"]
TENCENT_AUDIO_EXT = [".aac", ".amr", ".m4a", ".mp3", ".oga", ".ogg-opus", ".ogg", ".opus", ".pcm", ".silk", ".speex", ".wav"]
TENCENT_ASYNC_AUDIO_EXT = [".3gp", ".aac", ".amr", ".flac", ".flv", ".m4a", ".mp3", ".mp4", ".oga", ".ogg-opus", ".ogg", ".opus", ".wav", ".wma"]
+COMMON_AUDIO_EXT = [".mp3", ".opus", ".ogg", ".wav", ".flac", ".aac"]
def get_asr_method(duration: float, file_size: int, force_engine: str = "") -> tuple[str, list[str]]:
@@ -33,7 +37,7 @@ def get_asr_method(duration: float, file_size: int, force_engine: str = "") -> t
if force_engine == "gemini":
return get_gemini_asr_method(duration)
if force_engine in ["cloudflare", "groq"]:
- return force_engine.lower(), ["mp3", "opus", "ogg", "wav", "flac", "aac"]
+ return force_engine.lower(), [x.lstrip(".") for x in COMMON_AUDIO_EXT]
if asr_engine == "ali":
return get_ali_asr_method()
@@ -44,7 +48,7 @@ def get_asr_method(duration: float, file_size: int, force_engine: str = "") -> t
if asr_engine.lower() == "gemini":
return get_gemini_asr_method(duration)
if asr_engine.lower() in ["cloudflare", "groq"]:
- return asr_engine.lower(), ["mp3", "opus", "ogg", "wav", "flac", "aac"]
+ return asr_engine.lower(), [x.lstrip(".") for x in COMMON_AUDIO_EXT]
return f"ASR Engine: {asr_engine} is not support for duration: {duration}, filesize: {file_size}", []
@@ -89,3 +93,21 @@ async def downsampe_audio(path: str | Path, ext: str = "opus", codec: str = "lib
def is_english_word(text: str) -> bool:
return bool(re.match(r"^[a-zA-Z]+$", text))
+
+
+async def convert_single_channel(path: str | Path) -> Path:
+ path = Path(path).expanduser().resolve()
+ try:
+ with sf.SoundFile(path, "r") as f:
+ if f.channels != 1:
+ return await downsampe_audio(path, ext="ogg", codec="libopus", channel=1)
+ except LibsndfileError:
+ return await downsampe_audio(path, ext="ogg", codec="libopus", channel=1)
+ return path
+
+
+def audio_duration(path: str | Path) -> float:
+ with sf.SoundFile(path, "r") as f:
+ sr = f.samplerate
+ audio = f.read(dtype="float32")
+ return len(audio) / sr