Commit b15f5fb

benny-dou <60535774+benny-dou@users.noreply.github.com>
2025-10-23 08:55:46
feat(tts): support configurable default TTS engine via env variable
1 parent 5e492f0
src/tts/edge.py
@@ -6,8 +6,7 @@ import anyio
 from loguru import logger
 
 from asr.utils import audio_duration
-from config import CAPTION_LENGTH, DOWNLOAD_DIR, TTS
-from messages.utils import blockquote
+from config import DOWNLOAD_DIR, TTS
 from networking import hx_req
 from utils import markdown_to_text, rand_string
 
@@ -18,7 +17,7 @@ async def edge_tts(texts: str, model: str = "", voice_name: str = "") -> dict:
     https://github.com/wangwangit/tts
 
     Returns:
-        {"voice": str, "duration": int, "caption": str}
+        {"voice": str, "duration": int, "voice_name": str, "model": str}
     """
     model = model or TTS.EDGE_MODEL
     voice_name = voice_name or TTS.EDGE_VOICE
@@ -37,5 +36,4 @@ async def edge_tts(texts: str, model: str = "", voice_name: str = "") -> dict:
     save_path = Path(DOWNLOAD_DIR) / f"{rand_string(8)}.mp3"
     async with await anyio.open_file(save_path, "wb") as f:
         await f.write(response["content"])
-    caption = f"🗣音色: {voice_name}\n🤖引擎: {model}\n{blockquote(texts[: CAPTION_LENGTH - 20])}"
-    return {"voice": save_path, "duration": audio_duration(save_path), "caption": caption}
+    return {"voice": save_path, "duration": audio_duration(save_path), "voice_name": voice_name, "model": model}
src/tts/engines.py
@@ -3,6 +3,8 @@
 import random
 from collections import defaultdict
 
+from config import TTS
+
 ENGINES = [
     # Gemini
     {"name": "Achernar", "desc": "Soft", "engine": "gemini", "sex": "male"},
@@ -200,8 +202,7 @@ def get_tts_config(texts: str) -> tuple[str, str, str, str]:
     Returns:
         (voice_name, engine, model, texts)
     """
-    # use gemini by default
-    engine = "gemini"
+    engine = TTS.DEFAULT_ENGINE.lower()
     if not texts.startswith("@"):
         return "", engine, "", texts
     if texts.startswith(("@男", "@male")):
src/tts/gemini.py
@@ -12,9 +12,9 @@ from loguru import logger
 from pyrogram.enums import ParseMode
 from pyrogram.types import Message
 
-from config import CAPTION_LENGTH, DOWNLOAD_DIR, GEMINI, TTS
+from config import DOWNLOAD_DIR, GEMINI, TTS
 from llm.hooks import hook_gemini_httpoptions
-from messages.utils import blockquote, smart_split
+from messages.utils import smart_split
 from utils import markdown_to_text, rand_string, strings_list
 
 
@@ -24,7 +24,7 @@ async def gemini_tts(message: Message, texts: str, model: str = "", voice_name:
     https://ai.google.dev/gemini-api/docs/speech-generation
 
     Returns:
-        {"voice": str, "duration": int, "caption": str}
+        {"voice": str, "duration": int, "voice_name": str, "model": str}
     """
     model = model or TTS.GEMINI_MODEL
     voice_name = voice_name or TTS.GEMINI_VOICE
@@ -38,8 +38,7 @@ async def gemini_tts(message: Message, texts: str, model: str = "", voice_name:
     wav_path = Path(DOWNLOAD_DIR) / f"{rand_string(16)}.wav"
     combined_data = b"".join([r["voice"] for r in resp])
     save_wave_file(wav_path, combined_data)
-    caption = f"🗣音色: {voice_name}\n🤖引擎: {model}\n{blockquote(texts[: CAPTION_LENGTH - 20])}"
-    return {"voice": wav_path, "duration": calculate_duration(combined_data), "caption": caption}
+    return {"voice": wav_path, "duration": calculate_duration(combined_data), "voice_name": voice_name, "model": model}
 
 
 async def gemini_tts_real(message: Message, texts: str, model: str, voice_name: str, *, return_bytes: bool = True) -> dict:
@@ -49,7 +48,7 @@ async def gemini_tts_real(message: Message, texts: str, model: str, voice_name:
         return_bytes (bool, optional): If True, return audio bytes. Defaults to False.
 
     Returns:
-        {"voice": str or bytes, "duration": int, "caption": str}
+        {"voice": str or bytes, "duration": int, "voice_name": str, "model": str}
     """
     for api_key in strings_list(GEMINI.API_KEY, shuffle=True):
         try:
@@ -71,12 +70,11 @@ async def gemini_tts_real(message: Message, texts: str, model: str, voice_name:
             )
             await app.aio.aclose()
             if data := glom(response, "candidates.0.content.parts.0.inline_data.data", default=None):
-                caption = f"🗣音色: {voice_name}\n🤖引擎: {model}\n{blockquote(texts[: CAPTION_LENGTH - 20])}"
                 if return_bytes:
-                    return {"voice": data, "duration": calculate_duration(data), "caption": caption}
+                    return {"voice": data, "duration": calculate_duration(data), "voice_name": voice_name, "model": model}
                 wav_path = Path(DOWNLOAD_DIR) / f"{rand_string(16)}.wav"
                 save_wave_file(wav_path, data)
-                return {"voice": wav_path, "duration": calculate_duration(data), "caption": caption}
+                return {"voice": wav_path, "duration": calculate_duration(data), "voice_name": voice_name, "model": model}
         except Exception as e:
             logger.error(e)
     return {}
src/tts/qwen.py
@@ -9,8 +9,8 @@ from glom import glom
 from loguru import logger
 from pyrogram.enums import ParseMode
 
-from config import CAPTION_LENGTH, DOWNLOAD_DIR, TTS
-from messages.utils import blockquote, smart_split
+from config import DOWNLOAD_DIR, TTS
+from messages.utils import smart_split
 from networking import download_file, hx_req
 from utils import markdown_to_text, rand_string, strings_list
 
@@ -21,7 +21,7 @@ async def qwen_tts(texts: str, model: str = "", voice_name: str = "") -> dict:
     https://help.aliyun.com/zh/model-studio/qwen-tts
 
     Returns:
-        {"voice": str, "duration": int, "caption": str}
+        {"voice": str, "duration": int, "voice_name": str, "model": str}
     """
     model = model or strings_list(TTS.QWEN_MODEL, shuffle=True)[0]
     voice_name = voice_name or TTS.QWEN_VOICE
@@ -34,8 +34,7 @@ async def qwen_tts(texts: str, model: str = "", voice_name: str = "") -> dict:
     resp = await asyncio.gather(*[qwen_tts_real(text, model, voice_name) for text in text_list])
     wav_path = Path(DOWNLOAD_DIR) / f"{rand_string(16)}.wav"
     merge_wav([r["voice"] for r in resp], wav_path)
-    caption = f"🗣音色: {voice_name}\n🤖引擎: {model}\n{blockquote(texts[: CAPTION_LENGTH - 20])}"
-    return {"voice": wav_path, "duration": sum([r["duration"] for r in resp]), "caption": caption}
+    return {"voice": wav_path, "duration": sum([r["duration"] for r in resp]), "voice_name": voice_name, "model": model}
 
 
 async def qwen_tts_real(texts: str, model: str, voice_name: str) -> dict:
@@ -45,11 +44,10 @@ async def qwen_tts_real(texts: str, model: str, voice_name: str) -> dict:
         return_bytes (bool, optional): If True, return audio bytes. Defaults to False.
 
     Returns:
-        {"url": str, "duration": int, "caption": str}
+        {"voice": str, "duration": int, "voice_name": str, "model": str}
     """
     save_path = Path("/non-exist")
     duration = 0
-    caption = ""
     for api_key in strings_list(TTS.ALI_API_KEY, shuffle=True):
         try:
             logger.debug(f"TTS via {model}, voice: {voice_name}, texts: {texts}")
@@ -64,10 +62,9 @@ async def qwen_tts_real(texts: str, model: str, voice_name: str) -> dict:
             url = glom(response, "output.audio.url", default="")
             save_path = await download_file(url, proxy=TTS.ALI_PROXY)
             duration = glom(response, "usage.output_tokens", default=0) / 50  # 1s = 50 tokens
-            caption = f"🗣音色: {voice_name}\n🤖引擎: {model}\n{blockquote(texts[: CAPTION_LENGTH - 20])}"
         except Exception as e:
             logger.error(e)
-        return {"voice": save_path, "duration": duration, "caption": caption}
+        return {"voice": save_path, "duration": duration, "voice_name": voice_name, "model": model}
     return {}
 
 
src/tts/sambert.py
@@ -11,8 +11,8 @@ from glom import glom
 from loguru import logger
 from pyrogram.enums import ParseMode
 
-from config import CAPTION_LENGTH, DOWNLOAD_DIR, TTS
-from messages.utils import blockquote, smart_split
+from config import DOWNLOAD_DIR, TTS
+from messages.utils import smart_split
 from tts.engines import LIMIT_FOR_MODEL, get_random_one
 from utils import markdown_to_text, rand_string, strings_list
 
@@ -23,7 +23,7 @@ async def sambert_tts(texts: str, model: str = "", voice_name: str = "") -> dict
     https://help.aliyun.com/zh/model-studio/text-to-speech
 
     Returns:
-        {"voice": str, "duration": int, "caption": str}
+        {"voice": str, "duration": int, "voice_name": str, "model": str}
     """
     if not model:
         config = get_random_one(engine="sambert")
@@ -38,8 +38,7 @@ async def sambert_tts(texts: str, model: str = "", voice_name: str = "") -> dict
     resp = await asyncio.gather(*[sambert_tts_real(text, model, voice_name) for text in text_list])
     wav_path = Path(DOWNLOAD_DIR) / f"{rand_string(16)}.wav"
     merge_wav([r["voice"] for r in resp], wav_path)
-    caption = f"🗣音色: {voice_name}\n🤖引擎: {model}\n{blockquote(texts[: CAPTION_LENGTH - 20])}"
-    return {"voice": wav_path, "duration": sum([r["duration"] for r in resp]), "caption": caption}
+    return {"voice": wav_path, "duration": sum([r["duration"] for r in resp]), "voice_name": voice_name, "model": model}
 
 
 async def sambert_tts_real(texts: str, model: str, voice_name: str) -> dict:
@@ -49,11 +48,10 @@ async def sambert_tts_real(texts: str, model: str, voice_name: str) -> dict:
         return_bytes (bool, optional): If True, return audio bytes. Defaults to False.
 
     Returns:
-        {"url": str, "duration": int, "caption": str}
+        {"voice": str, "duration": int, "voice_name": str, "model": str}
     """
     save_path = Path("/non-exist")
     duration = 0
-    caption = ""
     for api_key in strings_list(TTS.ALI_API_KEY, shuffle=True):
         try:
             logger.debug(f"TTS via {model}, voice: {voice_name}, texts: {texts}")
@@ -65,10 +63,9 @@ async def sambert_tts_real(texts: str, model: str, voice_name: str) -> dict:
             duration = 0
             if timestamps := response.get_timestamps():
                 duration = glom(timestamps, "-1.end_time", default=0) / 1000
-            caption = f"🗣音色: {voice_name}\n🤖引擎: {model}\n{blockquote(texts[: CAPTION_LENGTH - 20])}"
         except Exception as e:
             logger.error(e)
-        return {"voice": save_path, "duration": duration, "caption": caption}
+        return {"voice": save_path, "duration": duration, "voice_name": voice_name, "model": model}
     return {}
 
 
src/tts/tts.py
@@ -5,30 +5,30 @@ from pathlib import Path
 from pyrogram.client import Client
 from pyrogram.types import Message
 
-from config import DOWNLOAD_DIR, PREFIX, TTS
+from config import CAPTION_LENGTH, DOWNLOAD_DIR, PREFIX, TTS
 from messages.parser import parse_msg
 from messages.sender import send2tg
-from messages.utils import blockquote, equal_prefix, set_reaction, startswith_prefix
+from messages.utils import blockquote, equal_prefix, set_reaction, smart_split, startswith_prefix
 from tts.edge import edge_tts
 from tts.engines import get_tts_config, list_engines
 from tts.gemini import gemini_tts
 from tts.qwen import qwen_tts
 from tts.sambert import sambert_tts
-from utils import read_text
+from utils import markdown_to_text, read_text
 
 HELP = f"""🗣**文字转语音**
 使用说明:
 1. `{PREFIX.TTS}` + 文字或txt文件
 2. `{PREFIX.TTS}` 回复 `文字消息` 或 `txt文件消息`
-3. `{PREFIX.TTS} @音色名` 可以指定音色, 默认音色: {TTS.GEMINI_VOICE}
+3. `{PREFIX.TTS} @音色名` 可以指定音色, 默认音色: {TTS.EDGE_VOICE}
 
 特殊用法:
 - `{PREFIX.TTS}` + @男 或 @male: 随机一款男声
 - `{PREFIX.TTS}` + @女 或 @female: 随机一款女声
+- `{PREFIX.TTS} @edge`: 随机一款MS Edge音色
 - `{PREFIX.TTS} @gemini`: 随机一款Gemini音色
 - `{PREFIX.TTS} @qwen`: 随机一款通义千问音色
 - `{PREFIX.TTS} @sambert`: 随机一款阿里Sambert音色
-- `{PREFIX.TTS} @edge`: 随机一款MS Edge音色
 {blockquote(list_engines())}
 """
 
@@ -72,10 +72,15 @@ async def text_to_speech(client: Client, message: Message, **kwargs):
         raise ValueError(msg)
 
     path = Path(resp.get("voice", ""))
-    if path.is_file():
-        resp["duration"] = round(resp["duration"])
-        await message.reply_voice(**resp, quote=True)
-        await set_reaction(client, reaction_msg, reaction="")
-    else:
+    if not path.is_file():
         await set_reaction(client, reaction_msg, reaction="💔")
+        return
+
+    duration = round(resp["duration"])
+    caption = f"🗣音色: {resp['voice_name']}\n🤖引擎: {resp['model']}\n{blockquote(markdown_to_text(texts))}"
+    if len(await smart_split(caption, CAPTION_LENGTH)) == 1:
+        await message.reply_voice(path.as_posix(), duration=duration, caption=caption, quote=True)
+    else:
+        await message.reply_voice(path.as_posix(), duration=duration, caption=f"🗣音色: {resp['voice_name']}\n🤖引擎: {resp['model']}", quote=True)
+    await set_reaction(client, reaction_msg, reaction="")
     path.unlink(missing_ok=True)
src/config.py
@@ -346,6 +346,7 @@ class FAVORITE:
 
 class TTS:
     # TTS related
+    DEFAULT_ENGINE = os.getenv("TTS_DEFAULT_ENGINE", "edge")  # edge, gemini, qwen, sambert
     GEMINI_MODEL = os.getenv("TTS_GEMINI_MODEL", "gemini-2.5-flash-preview-tts")
     GEMINI_INPUT_TOKEN_LIMIT = int(os.getenv("TTS_GEMINI_INPUT_TOKEN_LIMIT", "8192"))  # token limit of the tts model
     GEMINI_SPLIT_LENGTH = int(os.getenv("TTS_GEMINI_SPLIT_LENGTH", "8192"))  # split token limit of the tts model
@@ -359,7 +360,7 @@ class TTS:
     SAMBERT_MODEL = os.getenv("TTS_SAMBERT_MODEL", "ramdom")  # comma separated models for load balance. use "random" to randomly choose a model
     SAMBERT_LENGTH_LIMIT = int(os.getenv("TTS_SAMBERT_LENGTH_LIMIT", "20000"))  # token limit of the tts model
     EDGE_DOMAIN = os.getenv("TTS_EDGE_DOMAIN", "https://tts.wangwangit.com")
-    EDGE_VOICE = os.getenv("TTS_EDGE_VOICE", "zh-CN-XiaoxiaoNeural")
+    EDGE_VOICE = os.getenv("TTS_EDGE_VOICE", "晓晓")
     EDGE_MODEL = os.getenv("TTS_EDGE_MODEL", "zh-CN-XiaoxiaoNeural")
     EDGE_PROXY = os.getenv("TTS_EDGE_PROXY", None)