Commit bb81a65
Changed files (1)
src
asr
src/asr/gemini.py
@@ -3,7 +3,6 @@
import asyncio
import contextlib
import json
-import tempfile
from pathlib import Path
import soundfile as sf
@@ -17,12 +16,12 @@ from pyrogram.types import Message
from asr.groq import merge_transcripts
from asr.utils import GEMINI_AUDIO_EXT, audio_duration, convert_single_channel, downsampe_audio
-from config import ASR, GEMINI
+from config import ASR, DOWNLOAD_DIR, GEMINI
from llm.gemini import gemini_stream
from llm.hooks import hook_gemini_httpoptions
from llm.utils import shuffle_keys
from messages.progress import modify_progress
-from utils import count_subtitles, guess_mime, seconds_to_time, strings_list, zhcn
+from utils import count_subtitles, guess_mime, rand_string, seconds_to_time, strings_list, zhcn
class Transcription(BaseModel):
@@ -152,7 +151,8 @@ async def gemini_file_chunks(
dict: {"texts": str, "raw_texts": str, "segments": list[dict]}
"""
path = Path(path).expanduser().resolve()
- with sf.SoundFile(path, "r") as f:
+ ogg_path = path if path.suffix in [".oga", ".ogg", ".opus"] else await downsampe_audio(path, ext="opus", codec="libopus")
+ with sf.SoundFile(ogg_path, "r") as f:
sr = f.samplerate
audio = f.read(dtype="float32")
duration_seconds = len(audio) / sr
@@ -163,6 +163,7 @@ async def gemini_file_chunks(
# Calculate # of chunks
total_chunks = (duration_seconds // (chunk_seconds - overlap_seconds)) + 1
total_chunks = int(total_chunks)
+ chunk_paths = [Path(DOWNLOAD_DIR) / f"{rand_string()}.opus" for _ in range(total_chunks)]
tasks = []
# Loop through each chunk, extract current chunk from audio, transcribe
for i in range(total_chunks):
@@ -172,22 +173,22 @@ async def gemini_file_chunks(
chunk = audio[start:end]
if chunk.shape[0] == 0: # empty chunk
continue
- with tempfile.NamedTemporaryFile(suffix=".ogg", delete=False) as chunk_file:
- chunk_path = chunk_file.name
- await asyncio.to_thread(sf.write, chunk_path, chunk, sr, format="ogg", subtype="OPUS")
- tasks.append(
- gemini_single_file(
- message,
- chunk_file.name,
- model_id=model_id,
- prompt=prompt,
- start_seconds=int(start / sr),
- delete_local_file=True,
- delete_gemini_file=delete_gemini_file,
- )
+ chunk_path = chunk_paths[i].as_posix()
+ await asyncio.to_thread(sf.write, chunk_path, chunk, sr, format="ogg", subtype="OPUS")
+ tasks.append(
+ gemini_single_file(
+ message,
+ chunk_path,
+ model_id=model_id,
+ prompt=prompt,
+ start_seconds=int(start / sr),
+ delete_local_file=False,
+ delete_gemini_file=delete_gemini_file,
)
+ )
results = await asyncio.gather(*tasks)
transcription = merge_transcripts(sorted(results, key=lambda x: x["segments"][0]["start"]))
+ [p.unlink(missing_ok=True) for p in chunk_paths]
except Exception as e:
logger.error(e)
return {"error": str(e)}