  1#!/usr/bin/env python
  2# -*- coding: utf-8 -*-
  3import re
  4
  5from pyrogram.client import Client
  6from pyrogram.types import Message
  7
  8from ai.images.gemini import gemini_image_generation
  9from ai.images.models import get_image_model_configs
 10from ai.images.openai_img import openai_image_generation
 11from ai.images.post import http_post_image_generation
 12from ai.texts.claude import anthropic_responses
 13from ai.texts.gemini import gemini_chat_completion
 14from ai.texts.models import get_config_by_model_alias, get_text_model_configs
 15from ai.texts.openai_chat import openai_chat_completions
 16from ai.texts.openai_response import openai_responses_api
 17from ai.texts.tool_call import get_tool_call_results
 18from ai.utils import deep_merge, img_generation_docs, video_generation_docs
 19from ai.videos.models import get_video_model_configs
 20from ai.videos.post import http_post_video_generation
 21from config import AI, PREFIX
 22from messages.sender import send2tg
 23from messages.utils import startswith_prefix
 24
 25
async def ai_text_generation(client: Client, message: Message, **kwargs) -> dict:
    """Generate a text reply for *message*, trying each configured model in order.

    Strips the command prefix and an optional leading ``@model_alias`` selector
    from the message text. When the command carries no prompt, falls back to
    the replied-to message (if any) as the generation input. Returns the first
    successful backend response dict, the last (failed) response when every
    model fails, or ``{}`` when no model config is available.
    """
    texts = str(message.content).strip()
    this_msg = message
    prompt = texts.removeprefix(PREFIX.AI_TEXT_GENERATION).strip()
    # Drop a leading "@model_alias" selector (plus trailing whitespace) from the prompt.
    prompt = re.sub(r"^@([a-zA-Z0-9_\-\.]+)(\s+)?", "", prompt, flags=re.DOTALL).strip()
    if not prompt:
        if this_msg.media is not None:  # `/ai` + [media]: keep the command message itself as input
            pass
        elif message.reply_to_message:  # `/ai` reply to [message]: generate from the replied-to message
            message = this_msg.reply_to_message
    model_configs = await get_text_model_configs(this_msg)
    if not model_configs:
        return {}

    def handle_response(resp: dict, current_kwargs: dict) -> dict | None:
        """Handle API response.

        Update current_kwargs with progress message if available.
        Return resp on success, None otherwise (caller tries the next model).
        """
        if isinstance(resp.get("progress"), Message):
            current_kwargs["progress"] = resp["progress"]
        if resp.get("success", False):
            return resp
        return None

    for model_config in model_configs:
        api_type = model_config["api_type"]
        # Merge caller kwargs into this model's config; presumably the second
        # argument's keys win in deep_merge — confirm against its definition.
        params = deep_merge(model_config, kwargs)
        res = {}
        if api_type == "gemini":
            res = await gemini_chat_completion(client, message, **params)
        elif api_type == "anthropic":
            res = await anthropic_responses(client, message, **params)
        elif api_type == "openai_responses":
            res = await openai_responses_api(client, message, **params)
        elif api_type == "openai_chat":
            if model_config.get("openai_enable_tool_call", True):
                # Optional tool-call pre-pass: the first tool config that
                # succeeds has its results merged into the chat parameters.
                tool_configs = await get_config_by_model_alias(AI.TOOL_CALL_MODEL_ALIAS, fallback_to_default=False)
                for tool_config in tool_configs:
                    tool_params = deep_merge(params, tool_config)
                    tool_results = await get_tool_call_results(client, message, **tool_params)
                    if isinstance(tool_results.get("progress"), Message):
                        # Record the progress message in both kwargs (so later
                        # model attempts reuse it) and params (this attempt).
                        kwargs["progress"] = tool_results["progress"]
                        params["progress"] = tool_results["progress"]
                    if tool_results.get("success", False):
                        params = deep_merge(params, tool_results)
                        break
            res = await openai_chat_completions(client, message, **params)
        if successful_res := handle_response(res, kwargs):
            return successful_res
    return res
 77
 78
 79async def ai_image_generation(client: Client, message: Message, **kwargs) -> None:
 80    if not startswith_prefix(message.content, PREFIX.AI_IMG_GENERATION):
 81        return
 82    texts = str(message.content).strip()
 83    this_msg = message
 84    prompt = texts.removeprefix(PREFIX.AI_IMG_GENERATION).strip()
 85    prompt = re.sub(r"^@([a-zA-Z0-9_\-\.]+)(\s+)?", "", prompt, flags=re.DOTALL).strip()
 86    if not prompt:
 87        if not message.reply_to_message:
 88            await send2tg(client, message, texts=await img_generation_docs(), **kwargs)
 89            return
 90        message = this_msg.reply_to_message
 91    model_configs = await get_image_model_configs(this_msg)
 92    if not model_configs:
 93        return
 94
 95    params: dict = {"success": False, "progress": None}
 96    for model_config in model_configs:
 97        api_type = model_config["api_type"]
 98        params = deep_merge(params, model_config)
 99        if api_type == "openai":
100            params |= await openai_image_generation(client, message, **params)
101        elif api_type == "post":
102            params |= await http_post_image_generation(client, message, **params)
103        elif api_type == "gemini":
104            params |= await gemini_image_generation(client, message, **params)
105        if params.get("success"):
106            return
107
108
async def ai_video_generation(client: Client, message: Message, **kwargs) -> None:
    """Run the video-generation command through the HTTP-POST backend.

    Each configured video model is tried in order; the first call that
    returns a truthy result ends the dispatch. With no prompt and no
    replied-to message, the usage docs are sent back instead.
    """
    if not startswith_prefix(message.content, PREFIX.AI_VIDEO_GENERATION):
        return
    raw_text = str(message.content).strip()
    original_msg = message
    prompt = raw_text.removeprefix(PREFIX.AI_VIDEO_GENERATION).strip()
    # Strip an optional leading "@model_alias" selector.
    prompt = re.sub(r"^@([a-zA-Z0-9_\-\.]+)(\s+)?", "", prompt, flags=re.DOTALL).strip()
    if not prompt:
        # No inline prompt: fall back to the replied-to message, or show docs.
        if not message.reply_to_message:
            await send2tg(client, message, texts=await video_generation_docs(), **kwargs)
            return
        message = original_msg.reply_to_message
    model_configs = await get_video_model_configs(original_msg)
    if not model_configs:
        return
    for model_config in model_configs:
        done = await http_post_video_generation(client, message, **model_config)
        if done:
            return
126            return