Commit 0ccb789
src/asr/cloudflare.py
@@ -2,18 +2,16 @@
# -*- coding: utf-8 -*-
import asyncio
import base64
-import io
from collections.abc import Coroutine
from decimal import Decimal
from pathlib import Path
from typing import Any
-import soundfile as sf
from glom import glom
from loguru import logger
from asr.groq import merge_transcripts
-from asr.utils import convert_single_channel, downsampe_audio, get_file_bytes
+from asr.utils import audio_chunk_to_bytes, convert_single_channel, downsampe_audio, get_file_bytes, load_audio
from config import ASR
from networking import hx_req
from utils import seconds_to_time, strings_list, zhcn
@@ -111,7 +109,7 @@ async def cloudflare_file_chunks(
duration: float,
model: str | None = "",
prompt: str | None = "",
- chunk_seconds: float = ASR.CLOUDFLARE_CHUNK_SECONDS,
+    chunk_seconds: float = ASR.CLOUDFLARE_CHUNK_SECONDS,
overlap_seconds: float = ASR.CLOUDFLARE_OVERLAP_SECONDS,
) -> dict:
"""Transcribe audio in chunks with overlap.
@@ -127,11 +125,9 @@ async def cloudflare_file_chunks(
"""
# only support opus file
ogg_path = path if path.suffix in [".oga", ".ogg", ".opus"] else await downsampe_audio(path, ext="opus", codec="libopus")
- with sf.SoundFile(ogg_path, "r") as f:
- sr = f.samplerate
- audio = f.read(dtype="float32")
- logger.trace(f"音频时长: {duration:.2f}s, 采样率: {sr} Hz")
-
+ audio, duration, sr = load_audio(ogg_path)
+ if sr == 0:
+ return {"error": "Failed to load audio."}
transcription = {}
semaphore = asyncio.Semaphore(30) # max concurrent requests
@@ -143,7 +139,8 @@ async def cloudflare_file_chunks(
# Calculate # of chunks
total_chunks = int(duration // (chunk_seconds - overlap_seconds)) + 1
tasks = []
- # Loop through each chunk, extract current chunk from audio, transcribe
+ offset_list = []
+ # Loop through each chunk, extract current chunk from audio
for i in range(total_chunks):
start = int(i * (chunk_seconds - overlap_seconds) * sr)
end = int(min(start + chunk_seconds * sr, duration * sr))
@@ -151,12 +148,12 @@ async def cloudflare_file_chunks(
chunk = audio[start:end]
if chunk.shape[0] == 0: # empty chunk
continue
- # Write audio chunk to buffer
- buffer = io.BytesIO()
- await asyncio.to_thread(sf.write, buffer, chunk, sr, format="ogg", subtype="OPUS")
- buffer.seek(0) # move cursor to beginning
- chunk_bytes = buffer.getvalue() # get chunk bytes
- task = cloudflare_single_file(chunk_bytes, model, prompt, offset_seconds=int(start / sr))
+ tasks.append(audio_chunk_to_bytes(chunk, sr))
+ offset_list.append(int(start / sr))
+ bytes_list = await asyncio.gather(*tasks) # convert chunks to bytes
+ tasks = []
+ for audio_bytes, offset_seconds in zip(bytes_list, offset_list, strict=True):
+ task = cloudflare_single_file(audio_bytes, model, prompt, offset_seconds=offset_seconds)
tasks.append(run_with_semaphore(task))
results = await asyncio.gather(*tasks)
results = [r for r in results if r.get("segments")]
src/asr/gemini.py
@@ -5,7 +5,6 @@ import contextlib
import json
from pathlib import Path
-import soundfile as sf
from glom import glom
from google import genai
from google.genai.types import File, GenerateContentConfig, HttpOptions, ThinkingConfig, UploadFileConfig
@@ -14,7 +13,7 @@ from pydantic import BaseModel, Field
from pyrogram.types import Message
from asr.groq import merge_transcripts
-from asr.utils import GEMINI_AUDIO_EXT, audio_duration, convert_single_channel, downsampe_audio
+from asr.utils import GEMINI_AUDIO_EXT, audio_chunk_to_path, audio_duration, convert_single_channel, downsampe_audio, load_audio
from config import ASR, DOWNLOAD_DIR, GEMINI
from llm.hooks import hook_gemini_httpoptions
from utils import guess_mime, rand_string, seconds_to_time, strings_list, zhcn
@@ -130,7 +129,7 @@ async def gemini_single_file(
async def gemini_file_chunks(
message: Message,
path: str | Path,
- chunk_seconds: float = ASR.GEMINI_CHUNK_SECONDS,
+    chunk_seconds: float = ASR.GEMINI_CHUNK_SECONDS,
overlap_seconds: float = ASR.GEMINI_OVERLAP_SECONDS,
model_id: str = "",
prompt: str = "",
@@ -149,41 +148,42 @@ async def gemini_file_chunks(
"""
path = Path(path).expanduser().resolve()
ogg_path = path if path.suffix in [".oga", ".ogg", ".opus"] else await downsampe_audio(path, ext="opus", codec="libopus")
- with sf.SoundFile(ogg_path, "r") as f:
- sr = f.samplerate
- audio = f.read(dtype="float32")
- duration_seconds = len(audio) / sr
- logger.trace(f"音频时长: {duration_seconds:.2f}s, 采样率: {sr} Hz")
-
+ audio, duration, sr = load_audio(ogg_path)
+ if sr == 0:
+ return {"error": "Failed to load audio."}
transcription = {}
try:
# Calculate # of chunks
- total_chunks = (duration_seconds // (chunk_seconds - overlap_seconds)) + 1
+ total_chunks = (duration // (chunk_seconds - overlap_seconds)) + 1
total_chunks = int(total_chunks)
chunk_paths = [Path(DOWNLOAD_DIR) / f"{rand_string()}.opus" for _ in range(total_chunks)]
tasks = []
- # Loop through each chunk, extract current chunk from audio, transcribe
+ offset_list = []
+ # Loop through each chunk, extract current chunk from audio
for i in range(total_chunks):
start = int(i * (chunk_seconds - overlap_seconds) * sr)
- end = int(min(start + chunk_seconds * sr, duration_seconds * sr))
+ end = int(min(start + chunk_seconds * sr, duration * sr))
logger.trace(f"Processing chunk {i + 1}/{total_chunks}, Time range: {start / sr:.0f}s - {end / sr:.0f}s")
chunk = audio[start:end]
if chunk.shape[0] == 0: # empty chunk
continue
- chunk_path = chunk_paths[i].as_posix()
- await asyncio.to_thread(sf.write, chunk_path, chunk, sr, format="ogg", subtype="OPUS")
- tasks.append(
- gemini_single_file(
- message,
- chunk_path,
- model_id=model_id,
- prompt=prompt,
- start_seconds=int(start / sr),
- delete_local_file=False,
- delete_gemini_file=delete_gemini_file,
- )
+ tasks.append(audio_chunk_to_path(chunk, sr, chunk_paths[i]))
+ offset_list.append(int(start / sr))
+ await asyncio.gather(*tasks) # convert chunks to paths
+ tasks = [
+ gemini_single_file(
+ message,
+ audio_path,
+ model_id=model_id,
+ prompt=prompt,
+ start_seconds=offset,
+ delete_local_file=False,
+ delete_gemini_file=delete_gemini_file,
)
+            for audio_path, offset in zip(chunk_paths, offset_list)  # no strict: trailing empty chunks make offset_list shorter than chunk_paths
+ ]
results = await asyncio.gather(*tasks)
+ results = [r for r in results if r.get("segments")]
transcription = merge_transcripts(sorted(results, key=lambda x: x["segments"][0]["start"]))
[p.unlink(missing_ok=True) for p in chunk_paths]
except Exception as e:
src/asr/groq.py
@@ -6,11 +6,10 @@ import re
from decimal import Decimal
from pathlib import Path
-import soundfile as sf
from glom import glom
from loguru import logger
-from asr.utils import convert_single_channel, downsampe_audio, get_file_bytes
+from asr.utils import audio_chunk_to_bytes, convert_single_channel, downsampe_audio, get_file_bytes, load_audio
from config import ASR
from networking import hx_req
from utils import seconds_to_time, strings_list, zhcn
@@ -270,48 +269,39 @@ async def groq_file_chunks(
return {"texts": "", "error": "File not found."}
if path.suffix.lower() not in [".opus", ".ogg", ".oga"]:
path = await downsampe_audio(path, ext="opus", codec="libopus")
- try:
- with sf.SoundFile(path, "r") as f:
- if f.channels != 1:
- path = await downsampe_audio(path, ext="opus", codec="libopus")
- return await groq_file_chunks(path, chunk_seconds, overlap_seconds, model, temperature, prompt, language)
- sr = f.samplerate
- audio = f.read(dtype="float32")
- duration_seconds = len(audio) / sr
- logger.trace(f"音频时长: {duration_seconds:.2f}s, 采样率: {sr} Hz")
- except Exception as e:
- msg = f"Failed to load audio: {e!s}"
- raise RuntimeError(msg) from e
-
+ audio, duration, sr = load_audio(path)
+ if sr == 0:
+        return {"texts": "", "error": "Failed to load audio."}
transcription = {}
try:
# Calculate # of chunks
- total_chunks = int(duration_seconds // (chunk_seconds - overlap_seconds)) + 1
+ total_chunks = int(duration // (chunk_seconds - overlap_seconds)) + 1
tasks = []
- # Loop through each chunk, extract current chunk from audio, transcribe
+ offset_list = []
+ # Loop through each chunk, extract current chunk from audio
for i in range(total_chunks):
start = int(i * (chunk_seconds - overlap_seconds) * sr)
- end = int(min(start + chunk_seconds * sr, duration_seconds * sr))
+ end = int(min(start + chunk_seconds * sr, duration * sr))
logger.trace(f"Processing chunk {i + 1}/{total_chunks}, Time range: {start / sr:.0f}s - {end / sr:.0f}s")
chunk = audio[start:end]
if chunk.shape[0] == 0: # empty chunk
continue
- # Write audio chunk to buffer
- buffer = io.BytesIO()
- await asyncio.to_thread(sf.write, buffer, chunk, sr, format="ogg", subtype="OPUS")
- buffer.seek(0) # move cursor to beginning
- chunk_bytes = buffer.getvalue() # get chunk bytes
- tasks.append(
- groq_single_file(
- chunk_bytes,
- start_seconds=start / sr,
- model=model,
- temperature=temperature,
- prompt=prompt,
- language=language,
- )
+ tasks.append(audio_chunk_to_bytes(chunk, sr))
+ offset_list.append(int(start / sr))
+ bytes_list = await asyncio.gather(*tasks) # convert chunks to bytes
+ tasks = [
+ groq_single_file(
+ audio_bytes,
+ start_seconds=offset,
+ model=model,
+ temperature=temperature,
+ prompt=prompt,
+ language=language,
)
+ for audio_bytes, offset in zip(bytes_list, offset_list, strict=True)
+ ]
results = await asyncio.gather(*tasks)
+ results = [r for r in results if r.get("segments")]
transcription = merge_transcripts(sorted(results, key=lambda x: x["segments"][0]["start"]))
except Exception as e:
logger.error(e)
src/asr/tecent.py
@@ -4,19 +4,17 @@ import asyncio
import base64
import hashlib
import hmac
-import io
from collections.abc import Coroutine
from decimal import Decimal
from pathlib import Path
from typing import Any
import anyio
-import soundfile as sf
from glom import Coalesce, flatten, glom
from loguru import logger
from asr.groq import merge_transcripts
-from asr.utils import audio_duration, convert_single_channel, downsampe_audio, get_file_bytes, is_english_word
+from asr.utils import audio_chunk_to_bytes, audio_duration, convert_single_channel, downsampe_audio, get_file_bytes, is_english_word, load_audio
from config import ASR, FILE_SERVER
from database.alist import delete_alist, upload_alist
from database.uguu import upload_uguu
@@ -197,10 +195,9 @@ async def tencent_file_chunks(
"""
# only support opus file
ogg_path = path if path.suffix in [".oga", ".ogg", ".opus"] else await downsampe_audio(path, ext="opus", codec="libopus")
- with sf.SoundFile(ogg_path, "r") as f:
- sr = f.samplerate
- audio = f.read(dtype="float32")
- logger.trace(f"音频时长: {duration:.2f}s, 采样率: {sr} Hz")
+ audio, _, sr = load_audio(ogg_path)
+ if sr == 0:
+ return {"error": "Failed to load audio."}
transcription = {}
semaphore = asyncio.Semaphore(30) # max concurrent requests
@@ -213,7 +210,9 @@ async def tencent_file_chunks(
# Calculate # of chunks
total_chunks = int(duration // (chunk_seconds - overlap_seconds)) + 1
tasks = []
- # Loop through each chunk, extract current chunk from audio, transcribe
+ offset_list = []
+
+ # Loop through each chunk, extract current chunk from audio
for i in range(total_chunks):
start = int(i * (chunk_seconds - overlap_seconds) * sr)
end = int(min(start + chunk_seconds * sr, duration * sr))
@@ -221,12 +220,12 @@ async def tencent_file_chunks(
chunk = audio[start:end]
if chunk.shape[0] == 0: # empty chunk
continue
- # Write audio chunk to buffer
- buffer = io.BytesIO()
- await asyncio.to_thread(sf.write, buffer, chunk, sr, format="ogg", subtype="OPUS")
- buffer.seek(0) # move cursor to beginning
- chunk_bytes = buffer.getvalue() # get chunk bytes
- task = tencent_single_asr(chunk_bytes, language=language, offset_seconds=int(start / sr))
+ tasks.append(audio_chunk_to_bytes(chunk, sr))
+ offset_list.append(int(start / sr))
+ bytes_list = await asyncio.gather(*tasks) # convert chunks to bytes
+ tasks = []
+ for audio_bytes, offset_seconds in zip(bytes_list, offset_list, strict=True):
+ task = tencent_single_asr(audio_bytes, language=language, offset_seconds=offset_seconds)
tasks.append(run_with_semaphore(task))
results = await asyncio.gather(*tasks)
results = [r for r in results if r.get("segments")]
src/asr/utils.py
@@ -1,6 +1,8 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
+import asyncio
import contextlib
+import io
import json
import random
import re
@@ -10,6 +12,7 @@ import anyio
import soundfile as sf
from ffmpeg import FFmpeg
from loguru import logger
+from numpy import ndarray
from soundfile import LibsndfileError
from config import ASR, GEMINI
@@ -127,3 +130,26 @@ async def get_file_bytes(path_or_bytes: str | Path | bytes) -> bytes:
async with await anyio.open_file(path_or_bytes, "rb") as f:
file_bytes = await f.read()
return file_bytes
+
+
+def load_audio(path: Path | str) -> tuple[ndarray, float, int]:
+ with contextlib.suppress(Exception), sf.SoundFile(Path(path).as_posix(), "r") as f:
+ sr = f.samplerate
+ audio = f.read(dtype="float32")
+ duration = len(audio) / sr
+ logger.trace(f"音频时长: {duration:.2f}s, 采样率: {sr} Hz")
+ return audio, duration, sr
+    return ndarray(0), 0, 0  # empty 1-D array; sr == 0 signals failure to callers
+
+
+async def audio_chunk_to_bytes(chunk: ndarray, samplerate: int, fmt: str = "ogg", subtype: str = "OPUS") -> bytes:
+ buffer = io.BytesIO()
+ await asyncio.to_thread(sf.write, buffer, chunk, samplerate, format=fmt, subtype=subtype)
+ buffer.seek(0) # move cursor to beginning
+ return buffer.getvalue()
+
+
+async def audio_chunk_to_path(chunk: ndarray, samplerate: int, path: str | Path, fmt: str = "ogg", subtype: str = "OPUS") -> None:
+ out_path = Path(path).expanduser().resolve()
+ out_path.parent.mkdir(exist_ok=True, parents=True)
+ await asyncio.to_thread(sf.write, out_path.as_posix(), chunk, samplerate, format=fmt, subtype=subtype)
src/config.py
@@ -281,16 +281,16 @@ class ASR:
DEEPGRAM_API = os.getenv("ASR_DEEPGRAM_API", "") # comma separated keys for load balance. e.g. "key1,key2,key3"
CLOUDFLARE_MODEL = os.getenv("ASR_CLOUDFLARE_MODEL", "@cf/openai/whisper-large-v3-turbo")
CLOUDFLARE_MAX_BYTES = int(os.getenv("ASR_CLOUDFLARE_MAX_BYTES", "26214400")) # 25MB (max file bytes for single file)
- CLOUDFLARE_CHUNK_SECONDS = float(os.getenv("ASR_CLOUDFLARE_CHUNK_SECONDS", "1800")) # split long audio file into chunks
+ CLOUDFLARE_CHUNK_SECONDS = float(os.getenv("ASR_CLOUDFLARE_CHUNK_SECONDS", "600")) # split long audio file into chunks
CLOUDFLARE_OVERLAP_SECONDS = float(os.getenv("ASR_CLOUDFLARE_OVERLAP_SECONDS", "5")) # overlap seconds between chunks
CLOUDFLARE_KEYS = os.getenv("ASR_CLOUDFLARE_KEYS", "") # comma separated keys for load balance. e.g. "AccountID:API_TOKEN, AccountID:API_TOKEN, ..."
CLOUDFLARE_PROXY = os.getenv("ASR_CLOUDFLARE_PROXY", None)
- GEMINI_CHUNK_SECONDS = float(os.getenv("ASR_GEMINI_CHUNK_SECONDS", "900")) # split long audio file into chunks
+ GEMINI_CHUNK_SECONDS = float(os.getenv("ASR_GEMINI_CHUNK_SECONDS", "600")) # split long audio file into chunks
GEMINI_OVERLAP_SECONDS = float(os.getenv("ASR_GEMINI_OVERLAP_SECONDS", "5")) # overlap seconds between chunks
GROQ_PROXY = os.getenv("ASR_GROQ_PROXY", None) # Ban CN & HK IP
GROQ_MAX_BYTES = int(os.getenv("ASR_GROQ_MAX_BYTES", "26214400")) # 25MB (max file bytes for single file)
- GROQ_CHUNK_SECONDS = float(os.getenv("ASR_GROQ_CHUNK_SECONDS", "1800")) # split long audio file into chunks
+ GROQ_CHUNK_SECONDS = float(os.getenv("ASR_GROQ_CHUNK_SECONDS", "600")) # split long audio file into chunks
GROQ_OVERLAP_SECONDS = float(os.getenv("ASR_GROQ_OVERLAP_SECONDS", "5")) # overlap seconds between chunks
GROQ_KEYS = os.getenv("ASR_GROQ_KEYS", "") # comma separated keys for load balance.
GROQ_MODELS = os.getenv("ASR_GROQ_MODELS", "whisper-large-v3") # comma separated model names.