Commit bb81a65

benny-dou <60535774+benny-dou@users.noreply.github.com>
2025-07-14 01:53:00
fix(asr): convert long audio file to opus format for Gemini ASR
1 parent d2daeb1
Changed files (1)
src
src/asr/gemini.py
@@ -3,7 +3,6 @@
 import asyncio
 import contextlib
 import json
-import tempfile
 from pathlib import Path
 
 import soundfile as sf
@@ -17,12 +16,12 @@ from pyrogram.types import Message
 
 from asr.groq import merge_transcripts
 from asr.utils import GEMINI_AUDIO_EXT, audio_duration, convert_single_channel, downsampe_audio
-from config import ASR, GEMINI
+from config import ASR, DOWNLOAD_DIR, GEMINI
 from llm.gemini import gemini_stream
 from llm.hooks import hook_gemini_httpoptions
 from llm.utils import shuffle_keys
 from messages.progress import modify_progress
-from utils import count_subtitles, guess_mime, seconds_to_time, strings_list, zhcn
+from utils import count_subtitles, guess_mime, rand_string, seconds_to_time, strings_list, zhcn
 
 
 class Transcription(BaseModel):
@@ -152,7 +151,8 @@ async def gemini_file_chunks(
         dict: {"texts": str, "raw_texts": str, "segments": list[dict]}
     """
     path = Path(path).expanduser().resolve()
-    with sf.SoundFile(path, "r") as f:
+    ogg_path = path if path.suffix in [".oga", ".ogg", ".opus"] else await downsampe_audio(path, ext="opus", codec="libopus")
+    with sf.SoundFile(ogg_path, "r") as f:
         sr = f.samplerate
         audio = f.read(dtype="float32")
         duration_seconds = len(audio) / sr
@@ -163,6 +163,7 @@ async def gemini_file_chunks(
         # Calculate # of chunks
         total_chunks = (duration_seconds // (chunk_seconds - overlap_seconds)) + 1
         total_chunks = int(total_chunks)
+        chunk_paths = [Path(DOWNLOAD_DIR) / f"{rand_string()}.opus" for _ in range(total_chunks)]
         tasks = []
         # Loop through each chunk, extract current chunk from audio, transcribe
         for i in range(total_chunks):
@@ -172,22 +173,22 @@ async def gemini_file_chunks(
             chunk = audio[start:end]
             if chunk.shape[0] == 0:  # empty chunk
                 continue
-            with tempfile.NamedTemporaryFile(suffix=".ogg", delete=False) as chunk_file:
-                chunk_path = chunk_file.name
-                await asyncio.to_thread(sf.write, chunk_path, chunk, sr, format="ogg", subtype="OPUS")
-                tasks.append(
-                    gemini_single_file(
-                        message,
-                        chunk_file.name,
-                        model_id=model_id,
-                        prompt=prompt,
-                        start_seconds=int(start / sr),
-                        delete_local_file=True,
-                        delete_gemini_file=delete_gemini_file,
-                    )
+            chunk_path = chunk_paths[i].as_posix()
+            await asyncio.to_thread(sf.write, chunk_path, chunk, sr, format="ogg", subtype="OPUS")
+            tasks.append(
+                gemini_single_file(
+                    message,
+                    chunk_path,
+                    model_id=model_id,
+                    prompt=prompt,
+                    start_seconds=int(start / sr),
+                    delete_local_file=False,
+                    delete_gemini_file=delete_gemini_file,
                 )
+            )
         results = await asyncio.gather(*tasks)
         transcription = merge_transcripts(sorted(results, key=lambda x: x["segments"][0]["start"]))
+        [p.unlink(missing_ok=True) for p in chunk_paths]
     except Exception as e:
         logger.error(e)
         return {"error": str(e)}