main
  1#!/usr/bin/env python
  2# -*- coding: utf-8 -*-
  3import asyncio
  4import contextlib
  5import io
  6import json
  7import random
  8import re
  9from pathlib import Path
 10
 11import anyio
 12import soundfile as sf
 13from ffmpeg import FFmpeg
 14from loguru import logger
 15from numpy import ndarray
 16from soundfile import LibsndfileError
 17
 18from config import AI, ASR
 19from multimedia import convert_to_audio
 20from utils import strings_list
 21
# Audio file extensions grouped under the Gemini engine's name.
# NOTE(review): not referenced in this chunk; presumably the formats the Gemini
# ASR path can upload without converting first — confirm at the call site.
GEMINI_AUDIO_EXT = [".aac", ".aiff", ".flac", ".mp3", ".oga", ".ogg", ".opus", ".wav"]
 23
 24
 25def auto_choose_asr_engine(duration: float, engine: str) -> str:
 26    """Get ASR engine based on duration or category."""
 27    all_engines = ["ali", "tencent", "cloudflare", "groq", "gemini", "deepgram"]
 28    categries = {
 29        "whisper": ["cloudflare", "groq"],
 30        "china": ["ali", "tencent"],
 31        "uncensored": ["cloudflare", "groq", "gemini"],
 32    }
 33
 34    def get_enabled_engines() -> list[str]:
 35        enabled_engines = []
 36        if all([ASR.ALI_API_KEY, ASR.ALI_MODEL, ASR.ALI_FS_ENGINE]):
 37            enabled_engines.append("ali")
 38        if all([ASR.TENCENT_APPID, ASR.TENCENT_SECRET_ID, ASR.TENCENT_SECRET_KEY, ASR.TENCENT_FS_ENGINE]):
 39            enabled_engines.append("tencent")
 40        if all([ASR.CLOUDFLARE_MODEL, ASR.CLOUDFLARE_KEYS, ASR.CLOUDFLARE_CHUNK_SECONDS]):
 41            enabled_engines.append("cloudflare")
 42        if all([ASR.GEMINI_MODEL, AI.GEMINI_API_KEYS, AI.GEMINI_BASE_URL, ASR.GEMINI_CHUNK_SECONDS]):
 43            enabled_engines.append("gemini")
 44        if all([ASR.GROQ_MODELS, ASR.GROQ_KEYS, ASR.GROQ_MAX_BYTES, ASR.GROQ_CHUNK_SECONDS]):
 45            enabled_engines.append("groq")
 46        if ASR.DEEPGRAM_API:
 47            enabled_engines.append("deepgram")
 48        return enabled_engines
 49
 50    def parse_engines(eng: str) -> list[str]:
 51        res = []
 52        for x in strings_list(eng.lower()):
 53            if x in all_engines:
 54                res.append(x)
 55            elif x in categries:
 56                res.extend(categries[x])
 57        enabled_engines = get_enabled_engines()
 58        return [x for x in res if x in enabled_engines]
 59
 60    fallback_engine = "gemini"  # fallback if no match
 61    if not engine:
 62        return fallback_engine
 63
 64    if engine.lower() == "auto":
 65        if duration < ASR.SHORT_DURATION:
 66            engines = parse_engines(ASR.SHORT_ENGINE)
 67        elif ASR.SHORT_DURATION <= duration <= ASR.MIDDLE_DURATION:
 68            engines = parse_engines(ASR.MIDDLE_ENGINE)
 69        else:
 70            engines = parse_engines(ASR.LONG_ENGINE)
 71        return random.choice(engines) if engines else fallback_engine
 72
 73    engines = parse_engines(engine)
 74    return random.choice(engines) if engines else fallback_engine
 75
 76
 77async def downsampe_audio(path: str | Path, ext: str = "wav", codec: str = "pcm_s16le", sample_rate: int = 16000, channel: int = 1, **kwargs) -> Path:
 78    path = Path(path).expanduser().resolve()
 79    if not path.is_file():
 80        return path
 81    return await convert_to_audio(path, ext=ext, codec=codec, ac=channel, ar=sample_rate, **kwargs)
 82
 83
 84def is_english_word(text: str) -> bool:
 85    return bool(re.match(r"^[a-zA-Z]+$", text))
 86
 87
 88async def get_audio_channel(path: str | Path) -> int:
 89    with contextlib.suppress(Exception), sf.SoundFile(path, "r") as f:
 90        return f.channels
 91    with contextlib.suppress(Exception):
 92        ffprobe = FFmpeg(executable="ffprobe").input(Path(path).as_posix(), print_format="json", show_streams=None)
 93        metadata = json.loads(ffprobe.execute())
 94        streams = metadata.get("streams", [])
 95        return len(streams)
 96    return -1
 97
 98
 99async def convert_single_channel(path: str | Path, **kwargs) -> Path:
100    path = Path(path).expanduser().resolve()
101    num_channel = await get_audio_channel(path)
102    if num_channel != 1:
103        return await downsampe_audio(path, **kwargs)
104    return path
105
106
def audio_duration(path: str | Path) -> float:
    """Return the duration of *path* in seconds, or 0.0 if it cannot be determined.

    soundfile is tried first (frames / samplerate). Otherwise ffprobe is run
    and the largest duration reported by any stream or by the container
    format is used.

    - Some files embed the duration only in a non-audio stream, e.g. a
      subtitle stream.
    - Some containers only report it at the format level, so ``-show_format``
      is requested as well.
    - Missing or unparsable values are skipped individually instead of
      discarding every value.
    """
    with contextlib.suppress(LibsndfileError), sf.SoundFile(path) as f:
        samplerate = f.samplerate
        frames = f.frames
        return frames / samplerate
    with contextlib.suppress(Exception):
        ffprobe = FFmpeg(executable="ffprobe").input(
            Path(path).as_posix(),
            print_format="json",
            show_streams=None,
            show_format=None,
        )
        metadata = json.loads(ffprobe.execute())
        raw_values = [x.get("duration") for x in metadata.get("streams", [])]
        raw_values.append(metadata.get("format", {}).get("duration"))
        durations: list[float] = []
        for value in raw_values:
            # float(None) -> TypeError, float("N/A") -> ValueError: skip just that value.
            with contextlib.suppress(TypeError, ValueError):
                durations.append(float(value))
        return max(durations, default=0.0)

    return 0.0
120
121
122async def get_file_bytes(path_or_bytes: str | Path | bytes) -> bytes:
123    file_bytes = b""
124    if isinstance(path_or_bytes, bytes):
125        return path_or_bytes
126    if isinstance(path_or_bytes, (str, Path)):
127        if not Path(path_or_bytes).is_file():
128            logger.error(f"{path_or_bytes} is not exist.")
129            return b""
130        async with await anyio.open_file(path_or_bytes, "rb") as f:
131            file_bytes = await f.read()
132    return file_bytes
133
134
135def load_audio(path: Path | str) -> tuple[ndarray, float, int]:
136    with contextlib.suppress(Exception), sf.SoundFile(Path(path).as_posix(), "r") as f:
137        sr = f.samplerate
138        audio = f.read(dtype="float32")
139        duration = len(audio) / sr
140        logger.trace(f"音频时长: {duration:.2f}s, 采样率: {sr} Hz")
141        return audio, duration, sr
142    return ndarray([]), 0, 0
143
144
async def audio_chunk_to_bytes(chunk: ndarray, samplerate: int, fmt: str = "WAV", subtype: str = "PCM_16") -> bytes:
    """Encode *chunk* as an in-memory audio file and return the encoded bytes.

    ``fmt`` and ``subtype`` are passed to ``sf.write`` as ``format`` and
    ``subtype``. The encoding runs through ``asyncio.to_thread`` so the event
    loop is not blocked.
    """
    with io.BytesIO() as sink:
        await asyncio.to_thread(sf.write, sink, chunk, samplerate, format=fmt, subtype=subtype)
        # getvalue() returns the whole buffer regardless of the cursor position.
        return sink.getvalue()
150
151
async def audio_chunk_to_path(chunk: ndarray, samplerate: int, path: str | Path, fmt: str = "WAV", subtype: str = "PCM_16"):
    """Write *chunk* as an audio file at *path*, creating parent directories first.

    ``fmt`` and ``subtype`` are passed to ``sf.write`` as ``format`` and
    ``subtype``. The write runs through ``asyncio.to_thread``.
    """
    destination = Path(path).expanduser().resolve()
    destination.parent.mkdir(parents=True, exist_ok=True)
    await asyncio.to_thread(
        sf.write,
        destination.as_posix(),
        chunk,
        samplerate,
        format=fmt,
        subtype=subtype,
    )