main
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3import asyncio
4import contextlib
5import io
6import json
7import random
8import re
9from pathlib import Path
10
11import anyio
12import soundfile as sf
13from ffmpeg import FFmpeg
14from loguru import logger
15from numpy import ndarray
16from soundfile import LibsndfileError
17
18from config import AI, ASR
19from multimedia import convert_to_audio
20from schema import Sentence
21from utils import strings_list
22
# Audio file extensions associated with the Gemini engine.
# NOTE(review): presumably files with these suffixes can be sent to Gemini as-is
# (without conversion); the consumer is outside this view — confirm.
GEMINI_AUDIO_EXT = [
    ".aac",
    ".aiff",
    ".flac",
    ".mp3",
    ".oga",
    ".ogg",
    ".opus",
    ".wav",
]
24
25
def auto_choose_asr_engine(duration: float, engine: str) -> str:
    """Get ASR engine based on duration or category.

    Args:
        duration: Length of the audio, compared against ``ASR.SHORT_DURATION`` and
            ``ASR.MIDDLE_DURATION`` (same unit as those settings).
        engine: Either ``"auto"`` (pick the engine spec from ``ASR.SHORT_ENGINE`` /
            ``ASR.MIDDLE_ENGINE`` / ``ASR.LONG_ENGINE`` according to *duration*) or an
            explicit spec. A spec is split by ``strings_list``; each item is an engine
            name (e.g. ``"groq"``) or a category name (``"whisper"``, ``"china"``,
            ``"uncensored"``).

    Returns:
        A random pick among the requested engines that are also enabled, i.e. whose
        required ``ASR``/``AI`` settings are all truthy. Each distinct engine has equal
        weight. Returns ``"gemini"`` when *engine* is empty or nothing requested is enabled.
    """
    all_engines = ["ali", "tencent", "cloudflare", "groq", "gemini", "deepgram"]
    categories = {
        "whisper": ["cloudflare", "groq"],
        "china": ["ali", "tencent"],
        "uncensored": ["cloudflare", "groq", "gemini"],
    }
    fallback_engine = "gemini"  # fallback if no match

    def get_enabled_engines() -> list[str]:
        # An engine counts as enabled only when every setting it needs is truthy.
        enabled_engines = []
        if all([ASR.ALI_API_KEY, ASR.ALI_MODEL, ASR.ALI_FS_ENGINE]):
            enabled_engines.append("ali")
        if all([ASR.TENCENT_APPID, ASR.TENCENT_SECRET_ID, ASR.TENCENT_SECRET_KEY, ASR.TENCENT_FS_ENGINE]):
            enabled_engines.append("tencent")
        if all([ASR.CLOUDFLARE_MODEL, ASR.CLOUDFLARE_KEYS, ASR.CLOUDFLARE_CHUNK_SECONDS]):
            enabled_engines.append("cloudflare")
        if all([ASR.GEMINI_MODEL, AI.GEMINI_API_KEYS, AI.GEMINI_BASE_URL, ASR.GEMINI_CHUNK_SECONDS]):
            enabled_engines.append("gemini")
        if all([ASR.GROQ_MODELS, ASR.GROQ_KEYS, ASR.GROQ_MAX_BYTES, ASR.GROQ_CHUNK_SECONDS]):
            enabled_engines.append("groq")
        if ASR.DEEPGRAM_API:
            enabled_engines.append("deepgram")
        return enabled_engines

    def parse_engines(eng: str) -> list[str]:
        # Expand category names into their engines; unknown names are ignored.
        requested: list[str] = []
        for name in strings_list(eng.lower()):
            if name in all_engines:
                requested.append(name)
            elif name in categories:
                requested.extend(categories[name])
        enabled = set(get_enabled_engines())
        # De-duplicate while keeping order. Without this, overlapping categories
        # (e.g. "whisper,uncensored") give some engines extra weight in random.choice.
        return [name for name in dict.fromkeys(requested) if name in enabled]

    if not engine:
        return fallback_engine

    if engine.strip().lower() == "auto":
        # The first branch failing means duration >= SHORT_DURATION (or NaN, which falls to LONG).
        if duration < ASR.SHORT_DURATION:
            spec = ASR.SHORT_ENGINE
        elif duration <= ASR.MIDDLE_DURATION:
            spec = ASR.MIDDLE_ENGINE
        else:
            spec = ASR.LONG_ENGINE
    else:
        spec = engine

    engines = parse_engines(spec)
    return random.choice(engines) if engines else fallback_engine
76
77
78async def downsampe_audio(path: str | Path, ext: str = "wav", codec: str = "pcm_s16le", sample_rate: int = 16000, channel: int = 1, **kwargs) -> Path:
79 path = Path(path).expanduser().resolve()
80 if not path.is_file():
81 return path
82 return await convert_to_audio(path, ext=ext, codec=codec, ac=channel, ar=sample_rate, **kwargs)
83
84
def is_english_word(text: str) -> bool:
    """Return True if *text* consists solely of one or more ASCII letters (A-Z, a-z).

    ``re.fullmatch`` is used instead of ``re.match`` with ``^...$``: in Python, ``$`` also
    matches just before a trailing newline, so the anchored form accepted a word followed
    by a newline character.
    """
    return re.fullmatch(r"[a-zA-Z]+", text) is not None
87
88
89async def get_audio_channel(path: str | Path) -> int:
90 with contextlib.suppress(Exception), sf.SoundFile(path, "r") as f:
91 return f.channels
92 with contextlib.suppress(Exception):
93 ffprobe = FFmpeg(executable="ffprobe").input(Path(path).as_posix(), print_format="json", show_streams=None)
94 metadata = json.loads(ffprobe.execute())
95 streams = metadata.get("streams", [])
96 return len(streams)
97 return -1
98
99
async def convert_single_channel(path: str | Path, **kwargs) -> Path:
    """Return a mono version of the audio at *path*.

    If the file already reports exactly one channel, its resolved path is returned as-is.
    Otherwise, including when the channel count is unknown (-1), the file is re-encoded
    via ``downsampe_audio``, with *kwargs* forwarded to it.
    """
    resolved = Path(path).expanduser().resolve()
    if await get_audio_channel(resolved) == 1:
        return resolved
    return await downsampe_audio(resolved, **kwargs)
106
107
108def audio_duration(path: str | Path) -> float:
109 with contextlib.suppress(LibsndfileError), sf.SoundFile(path) as f:
110 samplerate = f.samplerate
111 frames = f.frames
112 return frames / samplerate
113 with contextlib.suppress(Exception):
114 ffprobe = FFmpeg(executable="ffprobe").input(Path(path).as_posix(), print_format="json", show_streams=None)
115 metadata = json.loads(ffprobe.execute())
116 streams = metadata.get("streams", [])
117 durations = [x.get("duration", 0) for x in streams] # all channels duration (some file embed the duration in subtitle stream)
118 return max(map(float, durations))
119
120 return 0.0
121
122
123async def get_file_bytes(path_or_bytes: str | Path | bytes) -> bytes:
124 file_bytes = b""
125 if isinstance(path_or_bytes, bytes):
126 return path_or_bytes
127 if isinstance(path_or_bytes, (str, Path)):
128 if not Path(path_or_bytes).is_file():
129 logger.error(f"{path_or_bytes} is not exist.")
130 return b""
131 async with await anyio.open_file(path_or_bytes, "rb") as f:
132 file_bytes = await f.read()
133 return file_bytes
134
135
136def load_audio(path: Path | str) -> tuple[ndarray, float, int]:
137 with contextlib.suppress(Exception), sf.SoundFile(Path(path).as_posix(), "r") as f:
138 sr = f.samplerate
139 audio = f.read(dtype="float32")
140 duration = len(audio) / sr
141 logger.trace(f"音频时长: {duration:.2f}s, 采样率: {sr} Hz")
142 return audio, duration, sr
143 return ndarray([]), 0, 0
144
145
async def audio_chunk_to_bytes(chunk: ndarray, samplerate: int, fmt: str = "WAV", subtype: str = "PCM_16") -> bytes:
    """Encode an in-memory sample array as an audio file and return its bytes.

    The blocking ``sf.write`` call runs in a worker thread via ``asyncio.to_thread``.
    *fmt* and *subtype* are passed to soundfile as ``format`` / ``subtype``.
    """
    sink = io.BytesIO()
    await asyncio.to_thread(sf.write, sink, chunk, samplerate, format=fmt, subtype=subtype)
    # getvalue() returns the whole buffer regardless of the cursor position.
    return sink.getvalue()
151
152
async def audio_chunk_to_path(chunk: ndarray, samplerate: int, path: str | Path, fmt: str = "WAV", subtype: str = "PCM_16"):
    """Write a sample array to *path* as an audio file.

    Missing parent directories are created first. The blocking ``sf.write`` runs in a
    worker thread via ``asyncio.to_thread``. Returns None.
    """
    target = Path(path).expanduser().resolve()
    target.parent.mkdir(parents=True, exist_ok=True)
    await asyncio.to_thread(
        sf.write,
        target.as_posix(),
        chunk,
        samplerate,
        format=fmt,
        subtype=subtype,
    )
157
158
def split_transcripts(text: str | None) -> list[Sentence]:
    """Split timestamped transcript text into a list of Sentence objects.

    Every line that begins with ``[MM:SS]`` or ``[HH:MM:SS]`` starts a new sentence. The
    timestamp string becomes ``start``, and the following text (up to the next timestamp,
    stripped of surrounding whitespace) becomes ``content``. Any text before the first
    timestamp is discarded. Empty or None input yields an empty list.
    """
    if not text:
        return []
    # The capture group extracts "MM:SS" or "HH:MM:SS"; re.MULTILINE lets ^ anchor at the
    # start of every line rather than only at the start of the whole string.
    pattern = r"^\[((?:\d{2}:)?\d{2}:\d{2})\]"
    pieces = re.split(pattern, text.strip(), flags=re.MULTILINE)
    # pieces == [preamble, stamp1, body1, stamp2, body2, ...], e.g.
    # ['', '00:00', ' first line\n', '00:05', ' second line\n', ...].
    # Drop the preamble and pair each timestamp with the body that follows it.
    return [
        Sentence(start=stamp, content=body.strip())
        for stamp, body in zip(pieces[1::2], pieces[2::2])
    ]
183 return results