main
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3import asyncio
4import contextlib
5import io
6import json
7import random
8import re
9from pathlib import Path
10
11import anyio
12import soundfile as sf
13from ffmpeg import FFmpeg
14from loguru import logger
15from numpy import ndarray
16from soundfile import LibsndfileError
17
18from config import AI, ASR
19from multimedia import convert_to_audio
20from utils import strings_list
21
# Audio file suffixes listed as Gemini-compatible. Where this list is consumed is
# not visible here — presumably files with these extensions can be sent to Gemini
# without re-encoding (verify against callers).
GEMINI_AUDIO_EXT = [
    ".aac",
    ".aiff",
    ".flac",
    ".mp3",
    ".oga",
    ".ogg",
    ".opus",
    ".wav",
]
23
24
def auto_choose_asr_engine(duration: float, engine: str) -> str:
    """Resolve an ASR engine spec to a single configured engine name.

    ``engine`` is lower-cased and split with ``strings_list`` (project helper;
    presumably a comma/space separated list — verify). Each token may be an
    engine name ("ali", "tencent", "cloudflare", "groq", "gemini", "deepgram")
    or a category ("whisper", "china", "uncensored") that expands to several
    engines. The special value "auto" instead takes the spec from
    ``ASR.SHORT_ENGINE`` / ``ASR.MIDDLE_ENGINE`` / ``ASR.LONG_ENGINE``,
    chosen by comparing ``duration`` with ``ASR.SHORT_DURATION`` and
    ``ASR.MIDDLE_DURATION``.

    Only engines whose required settings are all truthy are kept. One is then
    picked with ``random.choice``. Duplicates produced by overlapping
    categories stay in the list, so they weight the pick.

    Returns "gemini" when ``engine`` is empty or nothing usable remains. This
    fallback is not itself checked against the configuration.
    """
    default_engine = "gemini"
    if not engine:
        return default_engine

    known_engines = ["ali", "tencent", "cloudflare", "groq", "gemini", "deepgram"]
    engine_groups = {
        "whisper": ["cloudflare", "groq"],
        "china": ["ali", "tencent"],
        "uncensored": ["cloudflare", "groq", "gemini"],
    }

    def configured_engines() -> list[str]:
        # An engine is usable only when every setting it depends on is truthy.
        requirements = {
            "ali": [ASR.ALI_API_KEY, ASR.ALI_MODEL, ASR.ALI_FS_ENGINE],
            "tencent": [ASR.TENCENT_APPID, ASR.TENCENT_SECRET_ID, ASR.TENCENT_SECRET_KEY, ASR.TENCENT_FS_ENGINE],
            "cloudflare": [ASR.CLOUDFLARE_MODEL, ASR.CLOUDFLARE_KEYS, ASR.CLOUDFLARE_CHUNK_SECONDS],
            "gemini": [ASR.GEMINI_MODEL, AI.GEMINI_API_KEYS, AI.GEMINI_BASE_URL, ASR.GEMINI_CHUNK_SECONDS],
            "groq": [ASR.GROQ_MODELS, ASR.GROQ_KEYS, ASR.GROQ_MAX_BYTES, ASR.GROQ_CHUNK_SECONDS],
            "deepgram": [ASR.DEEPGRAM_API],
        }
        return [name for name, settings in requirements.items() if all(settings)]

    def expand(spec: str) -> list[str]:
        # Turn tokens into engine names in the order given (duplicates kept),
        # then drop anything that is not configured.
        wanted: list[str] = []
        for token in strings_list(spec.lower()):
            if token in known_engines:
                wanted.append(token)
            elif token in engine_groups:
                wanted.extend(engine_groups[token])
        usable = configured_engines()
        return [name for name in wanted if name in usable]

    if engine.lower() == "auto":
        if duration < ASR.SHORT_DURATION:
            spec = ASR.SHORT_ENGINE
        elif ASR.SHORT_DURATION <= duration <= ASR.MIDDLE_DURATION:
            spec = ASR.MIDDLE_ENGINE
        else:
            spec = ASR.LONG_ENGINE
    else:
        spec = engine

    candidates = expand(spec)
    return random.choice(candidates) if candidates else default_engine
75
76
77async def downsampe_audio(path: str | Path, ext: str = "wav", codec: str = "pcm_s16le", sample_rate: int = 16000, channel: int = 1, **kwargs) -> Path:
78 path = Path(path).expanduser().resolve()
79 if not path.is_file():
80 return path
81 return await convert_to_audio(path, ext=ext, codec=codec, ac=channel, ar=sample_rate, **kwargs)
82
83
def is_english_word(text: str) -> bool:
    """Return True if ``text`` is non-empty and made only of ASCII letters (A-Z, a-z)."""
    # re.fullmatch anchors at the true end of the string. The previous
    # re.match(r"^...$") also accepted a string ending in a newline, because
    # "$" matches just before a final "\n".
    return re.fullmatch(r"[a-zA-Z]+", text) is not None
86
87
88async def get_audio_channel(path: str | Path) -> int:
89 with contextlib.suppress(Exception), sf.SoundFile(path, "r") as f:
90 return f.channels
91 with contextlib.suppress(Exception):
92 ffprobe = FFmpeg(executable="ffprobe").input(Path(path).as_posix(), print_format="json", show_streams=None)
93 metadata = json.loads(ffprobe.execute())
94 streams = metadata.get("streams", [])
95 return len(streams)
96 return -1
97
98
async def convert_single_channel(path: str | Path, **kwargs) -> Path:
    """Return a single-channel version of ``path``, converting only when needed.

    The path has ``~`` expanded and is resolved first. If ``get_audio_channel``
    reports exactly one channel, that resolved path is returned unchanged.
    Any other count, including -1 (unknown), sends the file through
    ``downsampe_audio`` with ``kwargs``; that function defaults to mono output.
    """
    resolved = Path(path).expanduser().resolve()
    channels = await get_audio_channel(resolved)
    if channels == 1:
        return resolved
    return await downsampe_audio(resolved, **kwargs)
105
106
def _ffprobe_duration(path: str | Path) -> float:
    """Return the longest duration, in seconds, that ffprobe reports for ``path``.

    Both per-stream durations and the container (format) duration are
    considered. Missing or non-numeric entries (e.g. "N/A") are skipped rather
    than aborting the whole probe. Raises ValueError when no usable duration
    exists.
    """
    ffprobe = FFmpeg(executable="ffprobe").input(
        Path(path).as_posix(), print_format="json", show_streams=None, show_format=None
    )
    metadata = json.loads(ffprobe.execute())
    # Some files carry the duration on only one stream (e.g. a subtitle stream).
    # Some containers (Matroska/WebM are the usual example) report it only at
    # format level.
    raw_values = [stream.get("duration") for stream in metadata.get("streams", [])]
    raw_values.append(metadata.get("format", {}).get("duration"))
    durations: list[float] = []
    for value in raw_values:
        try:
            durations.append(float(value))
        except (TypeError, ValueError):
            continue
    return max(durations)  # ValueError if empty; the caller maps that to 0.0


def audio_duration(path: str | Path) -> float:
    """Return the duration of a media file in seconds, or 0.0 if it cannot be determined.

    ``soundfile`` is tried first (frames / samplerate). Files it cannot open
    fall back to ffprobe.
    """
    with contextlib.suppress(LibsndfileError), sf.SoundFile(path) as f:
        return f.frames / f.samplerate
    with contextlib.suppress(Exception):
        return _ffprobe_duration(path)
    return 0.0
120
121
122async def get_file_bytes(path_or_bytes: str | Path | bytes) -> bytes:
123 file_bytes = b""
124 if isinstance(path_or_bytes, bytes):
125 return path_or_bytes
126 if isinstance(path_or_bytes, (str, Path)):
127 if not Path(path_or_bytes).is_file():
128 logger.error(f"{path_or_bytes} is not exist.")
129 return b""
130 async with await anyio.open_file(path_or_bytes, "rb") as f:
131 file_bytes = await f.read()
132 return file_bytes
133
134
135def load_audio(path: Path | str) -> tuple[ndarray, float, int]:
136 with contextlib.suppress(Exception), sf.SoundFile(Path(path).as_posix(), "r") as f:
137 sr = f.samplerate
138 audio = f.read(dtype="float32")
139 duration = len(audio) / sr
140 logger.trace(f"音频时长: {duration:.2f}s, 采样率: {sr} Hz")
141 return audio, duration, sr
142 return ndarray([]), 0, 0
143
144
async def audio_chunk_to_bytes(chunk: ndarray, samplerate: int, fmt: str = "WAV", subtype: str = "PCM_16") -> bytes:
    """Encode a sample array into an in-memory audio file (default 16-bit PCM WAV).

    The blocking ``sf.write`` runs in a worker thread. Returns the complete
    encoded file contents.
    """
    with io.BytesIO() as sink:
        await asyncio.to_thread(sf.write, sink, chunk, samplerate, format=fmt, subtype=subtype)
        # getvalue() returns the whole buffer regardless of cursor position,
        # so no rewind is needed.
        return sink.getvalue()
150
151
async def audio_chunk_to_path(chunk: ndarray, samplerate: int, path: str | Path, fmt: str = "WAV", subtype: str = "PCM_16"):
    """Write a sample array to ``path`` (default 16-bit PCM WAV).

    ``~`` is expanded, the path is resolved, and missing parent directories
    are created. The blocking ``sf.write`` runs in a worker thread. Nothing is
    returned.
    """
    target = Path(path).expanduser().resolve()
    target.parent.mkdir(parents=True, exist_ok=True)
    await asyncio.to_thread(sf.write, target.as_posix(), chunk, samplerate, format=fmt, subtype=subtype)
155 await asyncio.to_thread(sf.write, out_path.as_posix(), chunk, samplerate, format=fmt, subtype=subtype)