main
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3import re
4
5from pyrogram.client import Client
6from pyrogram.types import Message
7
8from ai.images.gemini import gemini_image_generation
9from ai.images.models import get_image_model_configs
10from ai.images.openai_img import openai_image_generation
11from ai.images.post import http_post_image_generation
12from ai.texts.claude import anthropic_responses
13from ai.texts.gemini import gemini_chat_completion
14from ai.texts.models import get_config_by_model_alias, get_text_model_configs
15from ai.texts.openai_chat import openai_chat_completions
16from ai.texts.openai_response import openai_responses_api
17from ai.texts.tool_call import get_tool_call_results
18from ai.utils import deep_merge, img_generation_docs, video_generation_docs
19from ai.videos.models import get_video_model_configs
20from ai.videos.post import http_post_video_generation
21from config import AI, PREFIX
22from messages.sender import send2tg
23from messages.utils import startswith_prefix
24
25
26async def ai_text_generation(client: Client, message: Message, **kwargs) -> dict:
27 texts = str(message.content).strip()
28 this_msg = message
29 prompt = texts.removeprefix(PREFIX.AI_TEXT_GENERATION).strip()
30 prompt = re.sub(r"^@([a-zA-Z0-9_\-\.]+)(\s+)?", "", prompt, flags=re.DOTALL).strip()
31 if not prompt:
32 if this_msg.media is not None: # `/ai` + [meida]
33 pass
34 elif message.reply_to_message: # `/ai` reply to [message]
35 message = this_msg.reply_to_message
36 model_configs = await get_text_model_configs(this_msg)
37 if not model_configs:
38 return {}
39
40 def handle_response(resp: dict, current_kwargs: dict) -> dict | None:
41 """Handle API response.
42
43 Update current_kwargs with progress message if available.
44 """
45 if isinstance(resp.get("progress"), Message):
46 current_kwargs["progress"] = resp["progress"]
47 if resp.get("success", False):
48 return resp
49 return None
50
51 for model_config in model_configs:
52 api_type = model_config["api_type"]
53 params = deep_merge(model_config, kwargs)
54 res = {}
55 if api_type == "gemini":
56 res = await gemini_chat_completion(client, message, **params)
57 elif api_type == "anthropic":
58 res = await anthropic_responses(client, message, **params)
59 elif api_type == "openai_responses":
60 res = await openai_responses_api(client, message, **params)
61 elif api_type == "openai_chat":
62 if model_config.get("openai_enable_tool_call", True):
63 tool_configs = await get_config_by_model_alias(AI.TOOL_CALL_MODEL_ALIAS, fallback_to_default=False)
64 for tool_config in tool_configs:
65 tool_params = deep_merge(params, tool_config)
66 tool_results = await get_tool_call_results(client, message, **tool_params)
67 if isinstance(tool_results.get("progress"), Message):
68 kwargs["progress"] = tool_results["progress"]
69 params["progress"] = tool_results["progress"]
70 if tool_results.get("success", False):
71 params = deep_merge(params, tool_results)
72 break
73 res = await openai_chat_completions(client, message, **params)
74 if successful_res := handle_response(res, kwargs):
75 return successful_res
76 return res
77
78
async def ai_image_generation(client: Client, message: Message, **kwargs) -> None:
    """Handle the image-generation command by trying each configured image model.

    Does nothing unless the message content starts with
    ``PREFIX.AI_IMG_GENERATION``. The prefix and an optional leading ``@alias``
    token are stripped. If no prompt remains, the replied-to message is used as
    the input. When there is no reply either, the image-generation usage docs
    are sent instead. Model configs are resolved from the original command
    message, and backends are tried in order until one reports a truthy
    ``"success"``.

    Args:
        client: The Pyrogram client.
        message: The incoming command message.
        **kwargs: Forwarded to ``send2tg`` only when the usage docs are sent.
    """
    if not startswith_prefix(message.content, PREFIX.AI_IMG_GENERATION):
        return
    texts = str(message.content).strip()
    this_msg = message
    prompt = texts.removeprefix(PREFIX.AI_IMG_GENERATION).strip()
    prompt = re.sub(r"^@([a-zA-Z0-9_\-\.]+)(\s+)?", "", prompt, flags=re.DOTALL).strip()
    if not prompt:
        if not message.reply_to_message:
            await send2tg(client, message, texts=await img_generation_docs(), **kwargs)
            return
        # `/img` replying to a message: generate from the replied-to message.
        message = this_msg.reply_to_message
    model_configs = await get_image_model_configs(this_msg)
    if not model_configs:
        return

    # Only the progress message is carried from one attempt to the next. Each
    # attempt otherwise starts from its own model config, so keys from an
    # earlier model's config (or an earlier failed result) never leak into a
    # different backend's request.
    progress = None
    for model_config in model_configs:
        api_type = model_config["api_type"]
        params = deep_merge({"success": False, "progress": progress}, model_config)
        if api_type == "openai":
            params |= await openai_image_generation(client, message, **params)
        elif api_type == "post":
            params |= await http_post_image_generation(client, message, **params)
        elif api_type == "gemini":
            params |= await gemini_image_generation(client, message, **params)
        progress = params.get("progress")
        if params.get("success"):
            return
107
108
async def ai_video_generation(client: Client, message: Message, **kwargs) -> None:
    """Handle the video-generation command by trying each configured video model.

    Returns immediately unless the message content starts with
    ``PREFIX.AI_VIDEO_GENERATION``. After stripping the prefix and an optional
    leading ``@alias`` token, an empty prompt means "use the replied-to
    message". If there is no reply either, the video-generation usage docs are
    sent. Model configs are looked up from the original command message and
    tried in order, stopping at the first truthy result.

    Args:
        client: The Pyrogram client.
        message: The incoming command message.
        **kwargs: Forwarded to ``send2tg`` only when the usage docs are sent.
    """
    if not startswith_prefix(message.content, PREFIX.AI_VIDEO_GENERATION):
        return

    command_msg = message
    content = str(command_msg.content).strip()
    prompt = re.sub(
        r"^@([a-zA-Z0-9_\-\.]+)(\s+)?",
        "",
        content.removeprefix(PREFIX.AI_VIDEO_GENERATION).strip(),
        flags=re.DOTALL,
    ).strip()

    source_msg = command_msg
    if not prompt:
        replied = command_msg.reply_to_message
        if not replied:
            await send2tg(client, command_msg, texts=await video_generation_docs(), **kwargs)
            return
        source_msg = replied

    configs = await get_video_model_configs(command_msg)
    if not configs:
        return

    for config in configs:
        # NOTE(review): success is judged by the truthiness of the return value.
        # The sibling handlers check a "success" key instead; if this helper
        # returns a dict, a non-empty failure dict would stop the fallback
        # loop early -- confirm its return type.
        if await http_post_video_generation(client, source_msg, **config):
            return