main
  1#!/usr/bin/env python
  2# -*- coding: utf-8 -*-
  3import re
  4
  5from pyrogram.client import Client
  6from pyrogram.types import Message
  7
  8from ai.images.gemini import gemini_image_generation
  9from ai.images.models import get_image_model_configs
 10from ai.images.openai_img import openai_image_generation
 11from ai.images.post import http_post_image_generation
 12from ai.texts.claude import anthropic_responses
 13from ai.texts.gemini import gemini_chat_completion
 14from ai.texts.models import get_config_by_model_alias, get_text_model_configs, reorder_model_configs
 15from ai.texts.openai_chat import openai_chat_completions
 16from ai.texts.openai_response import openai_responses_api
 17from ai.texts.tool_call import get_tool_call_results
 18from ai.utils import deep_merge, img_generation_docs, video_generation_docs
 19from ai.videos.models import get_video_model_configs
 20from ai.videos.post import http_post_video_generation
 21from config import AI, PREFIX
 22from messages.sender import send2tg
 23from messages.utils import startswith_prefix
 24
 25
 26async def ai_text_generation(client: Client, message: Message, **kwargs) -> dict:
 27    texts = str(message.content).strip()
 28    this_msg = message
 29    prompt = texts.removeprefix(PREFIX.AI_TEXT_GENERATION).strip()
 30    prompt = re.sub(r"^@([a-zA-Z0-9_\-\.]+)(\s+)?", "", prompt, flags=re.DOTALL).strip()
 31    if not prompt:
 32        if this_msg.media is not None:  # `/ai` + [meida]
 33            pass
 34        elif message.reply_to_message:  # `/ai` reply to [message]
 35            message = this_msg.reply_to_message
 36    model_configs = await get_text_model_configs(this_msg)
 37    if not model_configs:
 38        return {}
 39
 40    model_configs = await reorder_model_configs(client, message, model_configs, kwargs)
 41
 42    def handle_response(resp: dict, current_kwargs: dict) -> dict | None:
 43        """Handle API response.
 44
 45        Update current_kwargs with progress message if available.
 46        """
 47        if isinstance(resp.get("progress"), Message):
 48            current_kwargs["progress"] = resp["progress"]
 49        if resp.get("success", False):
 50            return resp
 51        return None
 52
 53    for model_config in model_configs:
 54        api_type = model_config["api_type"]
 55        params = deep_merge(model_config, kwargs)
 56        res = {}
 57        if api_type == "gemini":
 58            res = await gemini_chat_completion(client, message, **params)
 59        elif api_type == "anthropic":
 60            res = await anthropic_responses(client, message, **params)
 61        elif api_type == "openai_responses":
 62            res = await openai_responses_api(client, message, **params)
 63        elif api_type == "openai_chat":
 64            if model_config.get("openai_enable_tool_call", True):
 65                tool_configs = await get_config_by_model_alias(AI.TOOL_CALL_MODEL_ALIAS, fallback_to_default=False)
 66                for tool_config in tool_configs:
 67                    tool_params = deep_merge(params, tool_config)
 68                    tool_results = await get_tool_call_results(client, message, **tool_params)
 69                    if isinstance(tool_results.get("progress"), Message):
 70                        kwargs["progress"] = tool_results["progress"]
 71                        params["progress"] = tool_results["progress"]
 72                    if tool_results.get("success", False):
 73                        params = deep_merge(params, tool_results)
 74                        break
 75            res = await openai_chat_completions(client, message, **params)
 76        if successful_res := handle_response(res, kwargs):
 77            return successful_res
 78    return res
 79
 80
 81async def ai_image_generation(client: Client, message: Message, **kwargs) -> None:
 82    if not startswith_prefix(message.content, PREFIX.AI_IMG_GENERATION):
 83        return
 84    texts = str(message.content).strip()
 85    this_msg = message
 86    prompt = texts.removeprefix(PREFIX.AI_IMG_GENERATION).strip()
 87    prompt = re.sub(r"^@([a-zA-Z0-9_\-\.]+)(\s+)?", "", prompt, flags=re.DOTALL).strip()
 88    if not prompt:
 89        if not message.reply_to_message:
 90            await send2tg(client, message, texts=await img_generation_docs(), **kwargs)
 91            return
 92        message = this_msg.reply_to_message
 93    model_configs = await get_image_model_configs(this_msg)
 94    if not model_configs:
 95        return
 96
 97    params: dict = {"success": False, "progress": None}
 98    for model_config in model_configs:
 99        api_type = model_config["api_type"]
100        params = deep_merge(params, model_config)
101        if api_type == "openai":
102            params |= await openai_image_generation(client, message, **params)
103        elif api_type == "post":
104            params |= await http_post_image_generation(client, message, **params)
105        elif api_type == "gemini":
106            params |= await gemini_image_generation(client, message, **params)
107        if params.get("success"):
108            return
109
110
async def ai_video_generation(client: Client, message: Message, **kwargs) -> None:
    """Generate a video with the first configured video model that accepts the request.

    Ignores messages whose content does not start with
    ``PREFIX.AI_VIDEO_GENERATION``. The prefix and an optional leading
    ``@alias`` token are stripped. If nothing is left, the replied-to message
    becomes the generation input. When there is no replied-to message, the
    video usage docs are sent back and the function returns.

    Model configs are looked up from the command message itself. Each config
    is passed to ``http_post_video_generation`` in turn until a call returns a
    truthy value.

    NOTE(review): success is judged by the truthiness of the return value. If
    ``http_post_video_generation`` returns a non-empty dict on failure, later
    configs are never tried. Confirm its return type.
    """
    if not startswith_prefix(message.content, PREFIX.AI_VIDEO_GENERATION):
        return

    command_msg = message
    body = str(command_msg.content).strip().removeprefix(PREFIX.AI_VIDEO_GENERATION).strip()
    body = re.sub(r"^@([a-zA-Z0-9_\-\.]+)(\s+)?", "", body, flags=re.DOTALL).strip()

    source_msg = command_msg
    if not body:
        replied = command_msg.reply_to_message
        if not replied:
            docs = await video_generation_docs()
            await send2tg(client, command_msg, texts=docs, **kwargs)
            return
        source_msg = replied

    configs = await get_video_model_configs(command_msg)
    if not configs:
        return
    for config in configs:
        if await http_post_video_generation(client, source_msg, **config):
            return