#!/usr/bin/env python
# -*- coding: utf-8 -*-
import asyncio
import wave
from pathlib import Path

from glom import glom
from google import genai
from google.genai import types
from google.genai.types import HttpOptions
from loguru import logger
from pyrogram.enums import ParseMode

from ai.utils import literal_eval
from config import AI, DOWNLOAD_DIR, PROXY, TTS
from messages.utils import smart_split
from utils import markdown_to_text, rand_string, strings_list
19
async def gemini_tts(texts: str, model: str = "", voice_name: str = "") -> dict:
    """Convert *texts* to speech with Gemini TTS, splitting long input.

    https://ai.google.dev/gemini-api/docs/speech-generation

    Input below the model's token limit is synthesized in one request;
    longer input is split, the chunks are synthesized concurrently, and
    the raw PCM chunks are concatenated into a single WAV file.

    Args:
        texts: Markdown text to synthesize (converted to plain text first).
        model: TTS model id; defaults to ``TTS.GEMINI_MODEL``.
        voice_name: Prebuilt voice name; defaults to ``TTS.GEMINI_VOICE``.

    Returns:
        {"voice": Path, "duration": int, "voice_name": str, "model": str},
        or {} when synthesis failed for every chunk.
    """
    model = model or TTS.GEMINI_MODEL
    voice_name = voice_name or TTS.GEMINI_VOICE
    raw_texts = markdown_to_text(texts)
    num_token = await count_token(raw_texts, model)
    if num_token < TTS.GEMINI_INPUT_TOKEN_LIMIT:
        return await gemini_tts_real(texts, model, voice_name, return_bytes=False)
    # Too long for one request: split and synthesize chunks concurrently.
    text_list = await smart_split(texts, chars_per_string=TTS.GEMINI_SPLIT_LENGTH, mode=ParseMode.DISABLED)
    resp = await asyncio.gather(*[gemini_tts_real(text, model, voice_name) for text in text_list])
    # gemini_tts_real returns {} on failure; drop those so a single bad
    # chunk cannot raise KeyError below.
    pcm_chunks = [r["voice"] for r in resp if r]
    if not pcm_chunks:
        return {}
    combined_data = b"".join(pcm_chunks)
    wav_path = Path(DOWNLOAD_DIR) / f"{rand_string(16)}.wav"
    save_wave_file(wav_path, combined_data)
    return {"voice": wav_path, "duration": calculate_duration(combined_data), "voice_name": voice_name, "model": model}
41
42
async def gemini_tts_real(texts: str, model: str, voice_name: str, *, return_bytes: bool = True) -> dict:
    """Single Gemini TTS request, rotating through API keys until one succeeds.

    Args:
        texts: Markdown text to synthesize (converted to plain text first).
        model: TTS model id.
        voice_name: Prebuilt Gemini voice name.
        return_bytes: If True (the default), return the raw PCM bytes;
            otherwise write them to a WAV file and return its path.

    Returns:
        {"voice": bytes or Path, "duration": int, "voice_name": str, "model": str},
        or {} when every API key failed or no audio data was returned.
    """
    for api_key in strings_list(AI.GEMINI_API_KEYS, shuffle=True):
        try:
            logger.debug(f"TTS via {model}, proxy={PROXY.GOOGLE}, voice: {voice_name}, texts: {texts}")
            app = genai.Client(
                api_key=api_key,
                http_options=HttpOptions(
                    base_url=AI.GEMINI_BASE_URL,
                    headers=literal_eval(AI.GEMINI_DEFAULT_HEADERS),
                    async_client_args={"proxy": PROXY.GOOGLE},
                ),
            )
            try:
                response = await app.aio.models.generate_content(
                    model=model,
                    contents=markdown_to_text(texts),
                    config=types.GenerateContentConfig(
                        response_modalities=["AUDIO"],
                        speech_config=types.SpeechConfig(
                            voice_config=types.VoiceConfig(
                                prebuilt_voice_config=types.PrebuiltVoiceConfig(voice_name=voice_name),
                            ),
                        ),
                    ),
                )
            finally:
                # Release the HTTP client even when the request raises,
                # otherwise each failed key leaks a client.
                await app.aio.aclose()
            if data := glom(response, "candidates.0.content.parts.0.inline_data.data", default=None):
                if return_bytes:
                    return {"voice": data, "duration": calculate_duration(data), "voice_name": voice_name, "model": model}
                wav_path = Path(DOWNLOAD_DIR) / f"{rand_string(16)}.wav"
                save_wave_file(wav_path, data)
                return {"voice": wav_path, "duration": calculate_duration(data), "voice_name": voice_name, "model": model}
        except Exception as e:
            # Best-effort: log and fall through to the next API key.
            logger.error(e)
    return {}
85
86
87def save_wave_file(path: Path | str, pcm: bytes, channels: int = 1, rate: float = 24000, sample_width: int = 2):
88 """Save PCM data to a wave file."""
89 path = Path(path).as_posix()
90 with wave.open(path, "wb") as wf:
91 wf.setnchannels(channels)
92 wf.setsampwidth(sample_width)
93 wf.setframerate(rate)
94 wf.writeframes(pcm)
95
96
def calculate_duration(pcm: bytes, channels: int = 1, rate: float = 24000, sample_width: int = 2) -> int:
    """Return the duration of raw PCM audio, rounded to whole seconds.

    Returns 0 when the frame size is zero (degenerate channel/width args).
    """
    frame_size = sample_width * channels
    if not frame_size:
        return 0
    frames = len(pcm) / frame_size
    return round(frames / rate)
104
105
async def count_token(texts: str, model_id: str = "") -> int:
    """Count Gemini tokens for *texts* using the count_tokens API.

    Args:
        texts: Plain text to tokenize.
        model_id: Model id; defaults to ``TTS.GEMINI_MODEL``.

    Returns:
        Total token count, or 0 when the API reports none.
    """
    model = model_id or TTS.GEMINI_MODEL
    app = genai.Client(
        api_key=strings_list(AI.GEMINI_API_KEYS, shuffle=True)[0],
        http_options=HttpOptions(
            base_url=AI.GEMINI_BASE_URL,
            headers=literal_eval(AI.GEMINI_DEFAULT_HEADERS),
            async_client_args={"proxy": PROXY.GOOGLE},
        ),
    )
    try:
        response = await app.aio.models.count_tokens(model=model, contents=texts)
    finally:
        # Close the HTTP client even on failure — the original leaked it,
        # unlike gemini_tts_real which closes its client.
        await app.aio.aclose()
    return response.total_tokens or 0