# main
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3import asyncio
4import random
5from pathlib import Path
6
7import anyio
8import soundfile as sf
9from dashscope.audio.tts import SpeechSynthesizer
10from glom import glom
11from loguru import logger
12from pyrogram.enums import ParseMode
13
14from config import DOWNLOAD_DIR, TTS
15from messages.utils import smart_split
16from tts.engines import LIMIT_FOR_MODEL, get_random_one
17from utils import markdown_to_text, rand_string, strings_list
18
19
async def sambert_tts(texts: str, model: str = "", voice_name: str = "") -> dict:
    """Synthesize speech via Alibaba Sambert TTS, splitting long texts.

    https://help.aliyun.com/zh/model-studio/text-to-speech

    Args:
        texts: Markdown-formatted text to synthesize.
        model: Sambert model name; a random voice/model pair is chosen when empty.
        voice_name: Human-readable voice name matching *model*.

    Returns:
        {"voice": Path, "duration": float, "voice_name": str, "model": str},
        or {} when synthesis failed for every chunk.
    """
    if not model:
        # Pick a random sambert voice, then one of the models it supports.
        config = get_random_one(engine="sambert")
        voice_name = config["name"]
        model = random.choice(LIMIT_FOR_MODEL.get(voice_name, ["知琪"]))

    raw_texts = markdown_to_text(texts)
    if len(raw_texts) < TTS.SAMBERT_LENGTH_LIMIT:
        # Short enough for a single request.
        return await sambert_tts_real(texts, model, voice_name)

    # Too long for one request: split, synthesize chunks concurrently, then merge.
    text_list = await smart_split(texts, chars_per_string=TTS.SAMBERT_LENGTH_LIMIT, mode=ParseMode.DISABLED)
    resp = await asyncio.gather(*[sambert_tts_real(text, model, voice_name) for text in text_list])
    # Drop chunks that failed completely ({}); they would otherwise KeyError below.
    resp = [r for r in resp if r]
    if not resp:
        return {}
    wav_path = Path(DOWNLOAD_DIR) / f"{rand_string(16)}.wav"
    merge_wav([r["voice"] for r in resp], wav_path)
    return {"voice": wav_path, "duration": sum(r["duration"] for r in resp), "voice_name": voice_name, "model": model}
42
43
async def sambert_tts_real(texts: str, model: str, voice_name: str) -> dict:
    """Call Sambert TTS once, trying each configured API key until one succeeds.

    Args:
        texts: Markdown text; converted to plain text before synthesis.
        model: Sambert model name to call.
        voice_name: Human-readable voice name, echoed back in the result.

    Returns:
        {"voice": Path, "duration": float, "voice_name": str, "model": str}
        on success, or {} when every API key failed.
    """
    for api_key in strings_list(TTS.ALI_API_KEY, shuffle=True):
        try:
            logger.debug(f"TTS via {model}, voice: {voice_name}, texts: {texts}")
            # SpeechSynthesizer.call is blocking; run it off the event loop.
            response = await asyncio.to_thread(SpeechSynthesizer.call, model, markdown_to_text(texts), format="wav", word_timestamp_enabled=True, api_key=api_key)
            if response.get_audio_data() is None:
                # This key produced no audio; fall through to the next key.
                continue
            save_path = Path(DOWNLOAD_DIR) / f"{rand_string(16)}.wav"
            async with await anyio.open_file(save_path, "wb") as f:
                await f.write(response.get_audio_data())
            duration = 0
            if timestamps := response.get_timestamps():
                # Duration in seconds = end time of the last timestamp entry (ms / 1000).
                duration = glom(timestamps, "-1.end_time", default=0) / 1000
            return {"voice": save_path, "duration": duration, "voice_name": voice_name, "model": model}
        except Exception as e:
            # Best effort: log and try the next API key instead of bailing out.
            logger.error(e)
    return {}
70
71
def merge_wav(wav_paths: list[str], save_path: str | Path):
    """Concatenate several wav files into one output file.

    The output inherits samplerate, channel count and subtype from the
    first input; any input whose samplerate or channels differ is skipped
    with a warning.
    """
    # Use the first file as the reference format for the merged output.
    with sf.SoundFile(wav_paths[0], "r") as ref:
        rate, chans, sub = ref.samplerate, ref.channels, ref.subtype

    with sf.SoundFile(save_path, "w", samplerate=rate, channels=chans, subtype=sub) as out:
        for file_path in wav_paths:
            with sf.SoundFile(file_path, "r") as src:
                if src.samplerate != rate or src.channels != chans:
                    logger.warning(f"{file_path}的参数不匹配")
                    continue
                # Stream in fixed-size chunks so whole files are never held in memory.
                for chunk in src.blocks(blocksize=1024):
                    out.write(chunk)