main
  1#!/usr/bin/env python
  2# -*- coding: utf-8 -*-
  3import asyncio
  4import contextlib
  5import io
  6import json
  7import random
  8import re
  9from pathlib import Path
 10
 11import anyio
 12import soundfile as sf
 13from ffmpeg import FFmpeg
 14from loguru import logger
 15from numpy import ndarray
 16from soundfile import LibsndfileError
 17
 18from config import AI, ASR
 19from multimedia import convert_to_audio
 20from schema import Sentence
 21from utils import strings_list
 22
 23GEMINI_AUDIO_EXT = [".aac", ".aiff", ".flac", ".mp3", ".oga", ".ogg", ".opus", ".wav"]
 24
 25
 26def auto_choose_asr_engine(duration: float, engine: str) -> str:
 27    """Get ASR engine based on duration or category."""
 28    all_engines = ["ali", "tencent", "cloudflare", "groq", "gemini", "deepgram"]
 29    categries = {
 30        "whisper": ["cloudflare", "groq"],
 31        "china": ["ali", "tencent"],
 32        "uncensored": ["cloudflare", "groq", "gemini"],
 33    }
 34
 35    def get_enabled_engines() -> list[str]:
 36        enabled_engines = []
 37        if all([ASR.ALI_API_KEY, ASR.ALI_MODEL, ASR.ALI_FS_ENGINE]):
 38            enabled_engines.append("ali")
 39        if all([ASR.TENCENT_APPID, ASR.TENCENT_SECRET_ID, ASR.TENCENT_SECRET_KEY, ASR.TENCENT_FS_ENGINE]):
 40            enabled_engines.append("tencent")
 41        if all([ASR.CLOUDFLARE_MODEL, ASR.CLOUDFLARE_KEYS, ASR.CLOUDFLARE_CHUNK_SECONDS]):
 42            enabled_engines.append("cloudflare")
 43        if all([ASR.GEMINI_MODEL, AI.GEMINI_API_KEYS, AI.GEMINI_BASE_URL, ASR.GEMINI_CHUNK_SECONDS]):
 44            enabled_engines.append("gemini")
 45        if all([ASR.GROQ_MODELS, ASR.GROQ_KEYS, ASR.GROQ_MAX_BYTES, ASR.GROQ_CHUNK_SECONDS]):
 46            enabled_engines.append("groq")
 47        if ASR.DEEPGRAM_API:
 48            enabled_engines.append("deepgram")
 49        return enabled_engines
 50
 51    def parse_engines(eng: str) -> list[str]:
 52        res = []
 53        for x in strings_list(eng.lower()):
 54            if x in all_engines:
 55                res.append(x)
 56            elif x in categries:
 57                res.extend(categries[x])
 58        enabled_engines = get_enabled_engines()
 59        return [x for x in res if x in enabled_engines]
 60
 61    fallback_engine = "gemini"  # fallback if no match
 62    if not engine:
 63        return fallback_engine
 64
 65    if engine.lower() == "auto":
 66        if duration < ASR.SHORT_DURATION:
 67            engines = parse_engines(ASR.SHORT_ENGINE)
 68        elif ASR.SHORT_DURATION <= duration <= ASR.MIDDLE_DURATION:
 69            engines = parse_engines(ASR.MIDDLE_ENGINE)
 70        else:
 71            engines = parse_engines(ASR.LONG_ENGINE)
 72        return random.choice(engines) if engines else fallback_engine
 73
 74    engines = parse_engines(engine)
 75    return random.choice(engines) if engines else fallback_engine
 76
 77
 78async def downsampe_audio(path: str | Path, ext: str = "wav", codec: str = "pcm_s16le", sample_rate: int = 16000, channel: int = 1, **kwargs) -> Path:
 79    path = Path(path).expanduser().resolve()
 80    if not path.is_file():
 81        return path
 82    return await convert_to_audio(path, ext=ext, codec=codec, ac=channel, ar=sample_rate, **kwargs)
 83
 84
 85def is_english_word(text: str) -> bool:
 86    return bool(re.match(r"^[a-zA-Z]+$", text))
 87
 88
 89async def get_audio_channel(path: str | Path) -> int:
 90    with contextlib.suppress(Exception), sf.SoundFile(path, "r") as f:
 91        return f.channels
 92    with contextlib.suppress(Exception):
 93        ffprobe = FFmpeg(executable="ffprobe").input(Path(path).as_posix(), print_format="json", show_streams=None)
 94        metadata = json.loads(ffprobe.execute())
 95        streams = metadata.get("streams", [])
 96        return len(streams)
 97    return -1
 98
 99
async def convert_single_channel(path: str | Path, **kwargs) -> Path:
    """Make sure the audio at *path* is mono, converting only when needed.

    When ``get_audio_channel`` reports exactly one channel, the resolved input path is
    returned unchanged. For any other result, including -1 (count unknown), the file
    goes through ``downsampe_audio`` with *kwargs* forwarded, and that result is
    returned.
    """
    resolved = Path(path).expanduser().resolve()
    if await get_audio_channel(resolved) == 1:
        return resolved
    return await downsampe_audio(resolved, **kwargs)
106
107
def audio_duration(path: str | Path) -> float:
    """Return the duration of the media file at *path* in seconds, or 0.0 if unknown.

    soundfile is tried first (frames / samplerate). Otherwise ``ffprobe`` is queried
    for both the container duration (``format.duration``) and every stream's
    ``duration``, and the largest parseable value is returned.

    The container value matters for formats such as Matroska/WebM, whose streams
    typically carry no ``duration`` field. Per-stream values that are missing or not
    numeric are skipped rather than aborting the fallback.
    """
    with contextlib.suppress(LibsndfileError), sf.SoundFile(path) as f:
        return f.frames / f.samplerate
    with contextlib.suppress(Exception):
        ffprobe = FFmpeg(executable="ffprobe").input(
            Path(path).as_posix(), print_format="json", show_streams=None, show_format=None
        )
        metadata = json.loads(ffprobe.execute())
        # Some files embed the duration in only one stream (e.g. a subtitle stream),
        # so consider every stream alongside the container-level duration.
        raw_values = [metadata.get("format", {}).get("duration")]
        raw_values.extend(stream.get("duration") for stream in metadata.get("streams", []))
        durations: list[float] = []
        for raw in raw_values:
            with contextlib.suppress(TypeError, ValueError):  # missing (None) or non-numeric
                durations.append(float(raw))
        if durations:
            return max(durations)
    return 0.0
121
122
123async def get_file_bytes(path_or_bytes: str | Path | bytes) -> bytes:
124    file_bytes = b""
125    if isinstance(path_or_bytes, bytes):
126        return path_or_bytes
127    if isinstance(path_or_bytes, (str, Path)):
128        if not Path(path_or_bytes).is_file():
129            logger.error(f"{path_or_bytes} is not exist.")
130            return b""
131        async with await anyio.open_file(path_or_bytes, "rb") as f:
132            file_bytes = await f.read()
133    return file_bytes
134
135
136def load_audio(path: Path | str) -> tuple[ndarray, float, int]:
137    with contextlib.suppress(Exception), sf.SoundFile(Path(path).as_posix(), "r") as f:
138        sr = f.samplerate
139        audio = f.read(dtype="float32")
140        duration = len(audio) / sr
141        logger.trace(f"音频时长: {duration:.2f}s, 采样率: {sr} Hz")
142        return audio, duration, sr
143    return ndarray([]), 0, 0
144
145
async def audio_chunk_to_bytes(chunk: ndarray, samplerate: int, fmt: str = "WAV", subtype: str = "PCM_16") -> bytes:
    """Encode *chunk* into an in-memory audio file and return the encoded bytes.

    ``soundfile.write`` runs in a worker thread via ``asyncio.to_thread``. *fmt* and
    *subtype* are passed through as soundfile's ``format`` and ``subtype``.
    """
    sink = io.BytesIO()
    await asyncio.to_thread(sf.write, sink, chunk, samplerate, format=fmt, subtype=subtype)
    # getvalue() returns the entire buffer regardless of cursor position,
    # so no rewind is needed first.
    return sink.getvalue()
151
152
async def audio_chunk_to_path(chunk: ndarray, samplerate: int, path: str | Path, fmt: str = "WAV", subtype: str = "PCM_16"):
    """Write *chunk* to *path* as an audio file, creating missing parent directories.

    The path is ``~``-expanded and resolved first. ``soundfile.write`` then runs in a
    worker thread, with *fmt* / *subtype* passed as its ``format`` / ``subtype``.
    """
    target = Path(path).expanduser().resolve()
    target.parent.mkdir(parents=True, exist_ok=True)
    await asyncio.to_thread(
        sf.write,
        target.as_posix(),
        chunk,
        samplerate,
        format=fmt,
        subtype=subtype,
    )
157
158
def split_transcripts(text: str | None) -> list[Sentence]:
    """Split timestamped transcript text into a list of ``Sentence`` objects.

    A segment begins at the start of a line with ``[MM:SS]`` or ``[HH:MM:SS]`` (two
    digits per field). The timestamp string (without brackets) becomes ``start``, and
    the text up to the next timestamp, stripped, becomes ``content``. Any text before
    the first timestamp is discarded. Empty or ``None`` input yields ``[]``.

    Example input::

        [00:00] first line
        [00:05] second line
    """
    if not text:
        return []
    # With re.MULTILINE, ``^`` anchors at every line start. The single capture group
    # keeps each timestamp in the split output.
    chunks = re.split(r"^\[((?:\d{2}:)?\d{2}:\d{2})\]", text.strip(), flags=re.MULTILINE)
    # Layout: [preamble, ts1, body1, ts2, body2, ...]; pair each timestamp with its body.
    stamps = chunks[1::2]
    bodies = chunks[2::2]
    return [Sentence(start=stamp, content=body.strip()) for stamp, body in zip(stamps, bodies)]