main
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3import asyncio
4from pathlib import Path
5
6import soundfile as sf
7from dashscope import get_tokenizer
8from glom import glom
9from loguru import logger
10from pyrogram.enums import ParseMode
11
12from config import DOWNLOAD_DIR, PROXY, TTS
13from messages.utils import smart_split
14from networking import download_file, hx_req
15from utils import markdown_to_text, rand_string, strings_list
16
17
async def qwen_tts(texts: str, model: str = "", voice_name: str = "") -> dict:
    """Qwen TTS entry point: synthesize *texts*, splitting long input.

    https://help.aliyun.com/zh/model-studio/qwen-tts

    Returns:
        {"voice": str, "duration": int, "voice_name": str, "model": str}
    """
    if not model:
        model = strings_list(TTS.QWEN_MODEL, shuffle=True)[0]
    if not voice_name:
        voice_name = TTS.QWEN_VOICE
    # Token-count the markdown-stripped text; short inputs fit one API call.
    plain = markdown_to_text(texts)
    if count_token(plain, model) < TTS.QWEN_INPUT_TOKEN_LIMIT:
        return await qwen_tts_real(texts, model, voice_name)
    # Too long: split into chunks, synthesize them concurrently, then merge.
    chunks = await smart_split(texts, chars_per_string=TTS.QWEN_SPLIT_LENGTH, mode=ParseMode.DISABLED)
    results = await asyncio.gather(*(qwen_tts_real(chunk, model, voice_name) for chunk in chunks))
    merged_path = Path(DOWNLOAD_DIR) / f"{rand_string(16)}.wav"
    merge_wav([r["voice"] for r in results], merged_path)
    total_duration = sum(r["duration"] for r in results)
    return {"voice": merged_path, "duration": total_duration, "voice_name": voice_name, "model": model}
38
39
async def qwen_tts_real(texts: str, model: str, voice_name: str) -> dict:
    """Call the Qwen TTS HTTP API, trying each configured API key until one succeeds.

    Args:
        texts: Input text; markdown is stripped before it is sent to the API.
        model: Qwen TTS model name.
        voice_name: Voice to synthesize with.

    Returns:
        {"voice": Path, "duration": float, "voice_name": str, "model": str}
        on success, or {} when every API key fails.
    """
    for api_key in strings_list(TTS.ALI_API_KEY, shuffle=True):
        try:
            logger.debug(f"TTS via {model}, voice: {voice_name}, texts: {texts}")
            response = await hx_req(
                "https://dashscope.aliyuncs.com/api/v1/services/aigc/multimodal-generation/generation",
                "POST",
                headers={"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"},
                json_data={"model": model, "input": {"text": markdown_to_text(texts), "voice": voice_name}},
                proxy=PROXY.ALI,
                check_keys=["output.audio.url", "usage.output_tokens"],
            )
            url = glom(response, "output.audio.url", default="")
            save_path = await download_file(url, proxy=PROXY.ALI)
            duration = glom(response, "usage.output_tokens", default=0) / 50  # 1s = 50 tokens
        except Exception as e:
            # This key failed (network/auth/quota); log and try the next one.
            # BUG FIX: the original returned unconditionally inside the loop,
            # so failures yielded a fake Path("/non-exist") and keys were never retried.
            logger.error(e)
            continue
        return {"voice": save_path, "duration": duration, "voice_name": voice_name, "model": model}
    return {}
69
70
def merge_wav(wav_paths: list[str], save_path: str | Path) -> None:
    """Concatenate multiple wav files into a single wav file.

    The first file defines the reference format (sample rate, channels,
    subtype); any later file whose sample rate or channel count differs is
    skipped with a warning rather than corrupting the output.

    Args:
        wav_paths: Paths of the wav files to merge, in playback order.
        save_path: Destination path for the merged wav.

    Raises:
        ValueError: If wav_paths is empty (original raised a bare IndexError).
    """
    if not wav_paths:
        raise ValueError("merge_wav: wav_paths must not be empty")

    # Detect sample rate, channels, subtype from the first input.
    with sf.SoundFile(wav_paths[0], "r") as f:
        samplerate = f.samplerate
        channels = f.channels
        subtype = f.subtype

    # Stream each input into the output in fixed-size blocks to bound memory use.
    with sf.SoundFile(save_path, "w", samplerate=samplerate, channels=channels, subtype=subtype) as outfile:
        for file_path in wav_paths:
            with sf.SoundFile(file_path, "r") as infile:
                if infile.samplerate != samplerate or infile.channels != channels:
                    logger.warning(f"{file_path}的参数不匹配")
                    continue
                for block in infile.blocks(blocksize=1024):
                    outfile.write(block)
88
89
def count_token(texts: str, model: str = "") -> int:
    """Return how many tokens *texts* encodes to under *model*'s tokenizer."""
    return len(get_tokenizer(model).encode(texts))