# main
 1#!/usr/bin/env python
 2# -*- coding: utf-8 -*-
 3import asyncio
 4from pathlib import Path
 5
 6import soundfile as sf
 7from dashscope import get_tokenizer
 8from glom import glom
 9from loguru import logger
10from pyrogram.enums import ParseMode
11
12from config import DOWNLOAD_DIR, PROXY, TTS
13from messages.utils import smart_split
14from networking import download_file, hx_req
15from utils import markdown_to_text, rand_string, strings_list
16
17
async def qwen_tts(texts: str, model: str = "", voice_name: str = "") -> dict:
    """Synthesize speech with Qwen TTS, splitting long input into chunks.

    https://help.aliyun.com/zh/model-studio/qwen-tts

    Short texts (below the model's input-token limit) go out as a single
    request; longer texts are split, synthesized concurrently, and the
    resulting wav files are merged into one.

    Returns:
        {"voice": Path, "duration": float, "voice_name": str, "model": str}
    """
    if not model:
        model = strings_list(TTS.QWEN_MODEL, shuffle=True)[0]
    if not voice_name:
        voice_name = TTS.QWEN_VOICE
    plain_text = markdown_to_text(texts)
    if count_token(plain_text, model) < TTS.QWEN_INPUT_TOKEN_LIMIT:
        return await qwen_tts_real(texts, model, voice_name)
    # Too long for one request: split, synthesize each chunk concurrently,
    # then stitch the chunk wavs into a single output file.
    chunks = await smart_split(texts, chars_per_string=TTS.QWEN_SPLIT_LENGTH, mode=ParseMode.DISABLED)
    results = await asyncio.gather(*(qwen_tts_real(chunk, model, voice_name) for chunk in chunks))
    merged_path = Path(DOWNLOAD_DIR) / f"{rand_string(16)}.wav"
    merge_wav([item["voice"] for item in results], merged_path)
    total_duration = sum(item["duration"] for item in results)
    return {"voice": merged_path, "duration": total_duration, "voice_name": voice_name, "model": model}
38
39
async def qwen_tts_real(texts: str, model: str, voice_name: str) -> dict:
    """Make a single Qwen TTS API call, failing over across API keys.

    Tries each configured Ali API key (in random order) until one call
    succeeds; the synthesized audio is downloaded to a local file.

    Args:
        texts: Input text; markdown is stripped before synthesis.
        model: Qwen TTS model name.
        voice_name: Voice to synthesize with.

    Returns:
        {"voice": Path, "duration": float, "voice_name": str, "model": str}
        on success, or {} when every API key fails.
    """
    for api_key in strings_list(TTS.ALI_API_KEY, shuffle=True):
        try:
            logger.debug(f"TTS via {model}, voice: {voice_name}, texts: {texts}")
            response = await hx_req(
                "https://dashscope.aliyuncs.com/api/v1/services/aigc/multimodal-generation/generation",
                "POST",
                headers={"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"},
                json_data={"model": model, "input": {"text": markdown_to_text(texts), "voice": voice_name}},
                proxy=PROXY.ALI,
                check_keys=["output.audio.url", "usage.output_tokens"],
            )
            url = glom(response, "output.audio.url", default="")
            save_path = await download_file(url, proxy=PROXY.ALI)
            duration = glom(response, "usage.output_tokens", default=0) / 50  # 1s = 50 tokens
        except Exception as e:
            # This key failed; log and fall through to the next one.
            logger.error(e)
            continue
        # BUG FIX: the original returned unconditionally inside the loop, so a
        # failed first key yielded {"voice": Path("/non-exist"), "duration": 0, ...}
        # and the remaining keys (and the final `return {}`) were never reached.
        # Return only after a successful request + download.
        return {"voice": save_path, "duration": duration, "voice_name": voice_name, "model": model}
    return {}
69
70
def merge_wav(wav_paths: list[str], save_path: str | Path):
    """Concatenate several wav files into a single output file.

    The first input file determines the sample rate, channel count and
    subtype of the output; any input whose rate or channel count differs
    is skipped with a warning.
    """
    # Probe the first file for the output format.
    with sf.SoundFile(wav_paths[0], "r") as probe:
        rate = probe.samplerate
        nchannels = probe.channels
        fmt = probe.subtype

    # Append each compatible input to the output, block by block.
    with sf.SoundFile(save_path, "w", samplerate=rate, channels=nchannels, subtype=fmt) as out:
        for file_path in wav_paths:
            with sf.SoundFile(file_path, "r") as src:
                if src.samplerate != rate or src.channels != nchannels:
                    logger.warning(f"{file_path}的参数不匹配")
                    continue
                # Stream in fixed-size blocks instead of loading whole files.
                for chunk in src.blocks(blocksize=1024):
                    out.write(chunk)
88
89
def count_token(texts: str, model: str = "") -> int:
    """Return how many tokens *texts* encodes to under *model*'s tokenizer."""
    return len(get_tokenizer(model).encode(texts))