main
 1#!/usr/bin/env python
 2# -*- coding: utf-8 -*-
 3import asyncio
 4import random
 5from pathlib import Path
 6
 7import anyio
 8import soundfile as sf
 9from dashscope.audio.tts import SpeechSynthesizer
10from glom import glom
11from loguru import logger
12from pyrogram.enums import ParseMode
13
14from config import DOWNLOAD_DIR, TTS
15from messages.utils import smart_split
16from tts.engines import LIMIT_FOR_MODEL, get_random_one
17from utils import markdown_to_text, rand_string, strings_list
18
19
async def sambert_tts(texts: str, model: str = "", voice_name: str = "") -> dict:
    """Synthesize ``texts`` with Aliyun Sambert, splitting long input into chunks.

    https://help.aliyun.com/zh/model-studio/text-to-speech

    When ``model`` is empty, a random Sambert voice config is picked via
    ``get_random_one`` and a model is chosen at random from
    ``LIMIT_FOR_MODEL`` for that voice. In that case ``voice_name`` is
    replaced by the picked voice.

    If the plain-text length (after ``markdown_to_text``) is below
    ``TTS.SAMBERT_LENGTH_LIMIT``, the text is synthesized in a single request
    and that request's result is returned as-is.

    Longer input is handled in four steps:

    * It is split with ``smart_split``.
    * Every chunk is synthesized concurrently.
    * The resulting WAV files are concatenated into a new file under
      ``DOWNLOAD_DIR``.
    * The per-chunk files are then deleted.

    Chunks that fail are skipped with a warning (best effort), rather than
    aborting the whole synthesis.

    Returns:
        {"voice": Path, "duration": float, "voice_name": str, "model": str}
        with ``duration`` in seconds. For split input, an empty dict is
        returned when no chunk could be synthesized.
    """
    if not model:
        config = get_random_one(engine="sambert")
        voice_name = config["name"]
        # NOTE(review): the fallback "知琪" looks like a voice name rather than a
        # model id; confirm it is a valid model for SpeechSynthesizer.call.
        model = random.choice(LIMIT_FOR_MODEL.get(voice_name, ["知琪"]))

    # Short enough for one request; the limit is measured on the plain text.
    if len(markdown_to_text(texts)) < TTS.SAMBERT_LENGTH_LIMIT:
        return await sambert_tts_real(texts, model, voice_name)

    chunks = await smart_split(texts, chars_per_string=TTS.SAMBERT_LENGTH_LIMIT, mode=ParseMode.DISABLED)
    results = await asyncio.gather(*(sambert_tts_real(chunk, model, voice_name) for chunk in chunks))

    # A failed chunk comes back either as {} or with a voice path that was never
    # written. Feeding it to merge_wav would raise, so keep only real audio files.
    succeeded = [r for r in results if r and Path(r["voice"]).exists()]
    if len(succeeded) < len(results):
        logger.warning(f"Sambert TTS: {len(results) - len(succeeded)}/{len(results)} chunks failed and were skipped")
    if not succeeded:
        return {}

    wav_path = Path(DOWNLOAD_DIR) / f"{rand_string(16)}.wav"
    merge_wav([r["voice"] for r in succeeded], wav_path)
    # The per-chunk files are superseded by the merged one; nothing returned references them.
    for r in succeeded:
        Path(r["voice"]).unlink(missing_ok=True)
    return {
        "voice": wav_path,
        "duration": sum(r["duration"] for r in succeeded),
        "voice_name": voice_name,
        "model": model,
    }
42
43
async def sambert_tts_real(texts: str, model: str, voice_name: str) -> dict:
    """Synthesize one chunk of text with Sambert, failing over across API keys.

    The keys from ``TTS.ALI_API_KEY`` are tried in shuffled order. The text
    is converted with ``markdown_to_text`` before being sent. The first call
    that yields audio wins: its WAV bytes are saved under ``DOWNLOAD_DIR``
    and the result is returned. A key that raises, or that returns no audio,
    is logged and the next key is tried.

    Args:
        texts: Text (possibly markdown) to speak.
        model: Sambert model passed to ``SpeechSynthesizer.call``.
        voice_name: Only echoed back in the result and in log messages.

    Returns:
        {"voice": Path, "duration": float, "voice_name": str, "model": str}.

        * ``duration`` is derived from the last word timestamp's ``end_time``
          divided by 1000, presumably ms -> seconds. It is 0 when no
          timestamps came back.
        * If every key fails, ``voice`` is the sentinel ``Path("/non-exist")``
          and ``duration`` is 0.
        * If no API key is configured, an empty dict is returned.
    """
    tried_any_key = False
    for api_key in strings_list(TTS.ALI_API_KEY, shuffle=True):
        tried_any_key = True
        try:
            logger.debug(f"TTS via {model}, voice: {voice_name}, texts: {texts}")
            response = await asyncio.to_thread(
                SpeechSynthesizer.call,
                model,
                markdown_to_text(texts),
                format="wav",
                word_timestamp_enabled=True,
                api_key=api_key,
            )
            audio = response.get_audio_data()
            if audio is None:
                # No audio with this key: fail over instead of returning a missing file.
                logger.warning(f"Sambert TTS returned no audio via {model}, trying next API key")
                continue

            # Compute the duration before writing, so a failure here cannot leave an orphan file.
            duration = 0
            if timestamps := response.get_timestamps():
                duration = glom(timestamps, "-1.end_time", default=0) / 1000

            save_path = Path(DOWNLOAD_DIR) / f"{rand_string(16)}.wav"
            async with await anyio.open_file(save_path, "wb") as f:
                await f.write(audio)
            return {"voice": save_path, "duration": duration, "voice_name": voice_name, "model": model}
        except Exception as e:
            # Best-effort failover: log and move on to the next key.
            logger.error(e)

    if not tried_any_key:
        return {}
    # Every key failed: keep the same result shape callers previously received on failure.
    return {"voice": Path("/non-exist"), "duration": 0, "voice_name": voice_name, "model": model}
70
71
def merge_wav(wav_paths: list[str], save_path: str | Path):
    """Concatenate several WAV files, in order, into ``save_path``.

    The first file in ``wav_paths`` acts as the format template: its sample
    rate, channel count and subtype are used for the output. A later file
    whose sample rate or channel count differs from the template is skipped
    with a warning rather than written.

    Args:
        wav_paths: Input WAV files in playback order. Must be non-empty,
            because the first entry is opened unconditionally.
        save_path: Destination file, opened in write mode.
    """
    # Use the first input as the format template for the output file.
    with sf.SoundFile(wav_paths[0], "r") as template:
        rate, nchannels, sub = template.samplerate, template.channels, template.subtype

    with sf.SoundFile(save_path, "w", samplerate=rate, channels=nchannels, subtype=sub) as merged:
        for src in wav_paths:
            with sf.SoundFile(src, "r") as part:
                compatible = part.samplerate == rate and part.channels == nchannels
                if not compatible:
                    logger.warning(f"{src}的参数不匹配")
                    continue
                # Copy in 1024-frame chunks instead of reading the whole file at once.
                for chunk in part.blocks(blocksize=1024):
                    merged.write(chunk)