main
  1#!/usr/bin/env python
  2# -*- coding: utf-8 -*-
  3import asyncio
  4import wave
  5from pathlib import Path
  6
  7from glom import glom
  8from google import genai
  9from google.genai import types
 10from google.genai.types import HttpOptions
 11from loguru import logger
 12from pyrogram.enums import ParseMode
 13
 14from ai.utils import literal_eval
 15from config import AI, DOWNLOAD_DIR, PROXY, TTS
 16from messages.utils import smart_split
 17from utils import markdown_to_text, rand_string, strings_list
 18
 19
 20async def gemini_tts(texts: str, model: str = "", voice_name: str = "") -> dict:
 21    """Gemini TTS.
 22
 23    https://ai.google.dev/gemini-api/docs/speech-generation
 24
 25    Returns:
 26        {"voice": str, "duration": int, "voice_name": str, "model": str}
 27    """
 28    model = model or TTS.GEMINI_MODEL
 29    voice_name = voice_name or TTS.GEMINI_VOICE
 30    raw_texts = markdown_to_text(texts)
 31    num_token = await count_token(raw_texts, model)
 32    if num_token < TTS.GEMINI_INPUT_TOKEN_LIMIT:
 33        return await gemini_tts_real(texts, model, voice_name, return_bytes=False)
 34    # split
 35    text_list = await smart_split(texts, chars_per_string=TTS.GEMINI_SPLIT_LENGTH, mode=ParseMode.DISABLED)
 36    resp = await asyncio.gather(*[gemini_tts_real(text, model, voice_name) for text in text_list])
 37    wav_path = Path(DOWNLOAD_DIR) / f"{rand_string(16)}.wav"
 38    combined_data = b"".join([r["voice"] for r in resp])
 39    save_wave_file(wav_path, combined_data)
 40    return {"voice": wav_path, "duration": calculate_duration(combined_data), "voice_name": voice_name, "model": model}
 41
 42
 43async def gemini_tts_real(texts: str, model: str, voice_name: str, *, return_bytes: bool = True) -> dict:
 44    """Gemini TTS.
 45
 46    Args:
 47        return_bytes (bool, optional): If True, return audio bytes. Defaults to False.
 48
 49    Returns:
 50        {"voice": str or bytes, "duration": int, "voice_name": str, "model": str}
 51    """
 52    for api_key in strings_list(AI.GEMINI_API_KEYS, shuffle=True):
 53        try:
 54            logger.debug(f"TTS via {model}, proxy={PROXY.GOOGLE}, voice: {voice_name}, texts: {texts}")
 55            app = genai.Client(
 56                api_key=api_key,
 57                http_options=HttpOptions(
 58                    base_url=AI.GEMINI_BASE_URL,
 59                    headers=literal_eval(AI.GEMINI_DEFAULT_HEADERS),
 60                    async_client_args={"proxy": PROXY.GOOGLE},
 61                ),
 62            )
 63            response = await app.aio.models.generate_content(
 64                model=model,
 65                contents=markdown_to_text(texts),
 66                config=types.GenerateContentConfig(
 67                    response_modalities=["AUDIO"],
 68                    speech_config=types.SpeechConfig(
 69                        voice_config=types.VoiceConfig(
 70                            prebuilt_voice_config=types.PrebuiltVoiceConfig(voice_name=voice_name),
 71                        ),
 72                    ),
 73                ),
 74            )
 75            await app.aio.aclose()
 76            if data := glom(response, "candidates.0.content.parts.0.inline_data.data", default=None):
 77                if return_bytes:
 78                    return {"voice": data, "duration": calculate_duration(data), "voice_name": voice_name, "model": model}
 79                wav_path = Path(DOWNLOAD_DIR) / f"{rand_string(16)}.wav"
 80                save_wave_file(wav_path, data)
 81                return {"voice": wav_path, "duration": calculate_duration(data), "voice_name": voice_name, "model": model}
 82        except Exception as e:
 83            logger.error(e)
 84    return {}
 85
 86
 87def save_wave_file(path: Path | str, pcm: bytes, channels: int = 1, rate: float = 24000, sample_width: int = 2):
 88    """Save PCM data to a wave file."""
 89    path = Path(path).as_posix()
 90    with wave.open(path, "wb") as wf:
 91        wf.setnchannels(channels)
 92        wf.setsampwidth(sample_width)
 93        wf.setframerate(rate)
 94        wf.writeframes(pcm)
 95
 96
 97def calculate_duration(pcm: bytes, channels: int = 1, rate: float = 24000, sample_width: int = 2) -> int:
 98    # calculate total frames
 99    bytes_per_frame = sample_width * channels
100    if bytes_per_frame == 0:
101        return 0
102    num_frames = len(pcm) / bytes_per_frame
103    return round(num_frames / rate)  # duration seconds
104
105
async def count_token(texts: str, model_id: str = "") -> int:
    """Count how many tokens ``texts`` occupies for a Gemini model.

    A single API key, the first entry of the shuffled ``AI.GEMINI_API_KEYS``
    list, is used for the request.

    Args:
        texts: Text to count.
        model_id: Model id. Defaults to ``TTS.GEMINI_MODEL``.

    Returns:
        ``total_tokens`` from the API response, or 0 if the API returned none.

    Raises:
        IndexError: presumably, when ``strings_list`` yields no API key (it is indexed with ``[0]``).
        Exception: Errors raised by the Gemini client propagate unchanged.
    """
    model = model_id or TTS.GEMINI_MODEL
    app = genai.Client(
        api_key=strings_list(AI.GEMINI_API_KEYS, shuffle=True)[0],
        http_options=HttpOptions(
            base_url=AI.GEMINI_BASE_URL,
            headers=literal_eval(AI.GEMINI_DEFAULT_HEADERS),
            async_client_args={"proxy": PROXY.GOOGLE},
        ),
    )
    try:
        response = await app.aio.models.count_tokens(model=model, contents=texts)
    finally:
        # release the async HTTP client whether or not the request succeeded
        await app.aio.aclose()
    return response.total_tokens or 0