main
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3import re
4
5from pyrogram.client import Client
6from pyrogram.types import Message
7
8from ai.images.gemini import gemini_image_generation
9from ai.images.models import get_image_model_configs
10from ai.images.openai_img import openai_image_generation
11from ai.images.post import http_post_image_generation
12from ai.texts.claude import anthropic_responses
13from ai.texts.gemini import gemini_chat_completion
14from ai.texts.models import get_config_by_model_alias, get_text_model_configs, reorder_model_configs
15from ai.texts.openai_chat import openai_chat_completions
16from ai.texts.openai_response import openai_responses_api
17from ai.texts.tool_call import get_tool_call_results
18from ai.utils import deep_merge, img_generation_docs, video_generation_docs
19from ai.videos.models import get_video_model_configs
20from ai.videos.post import http_post_video_generation
21from config import AI, PREFIX
22from messages.sender import send2tg
23from messages.utils import startswith_prefix
24
25
async def ai_text_generation(client: Client, message: Message, **kwargs) -> dict:
    """Generate a text reply using the first configured text model that succeeds.

    The ``PREFIX.AI_TEXT_GENERATION`` prefix and an optional leading ``@alias``
    token are stripped from the message content. If nothing remains and the
    command message has no media but replies to another message, the replied-to
    message becomes the input message (``/ai`` sent as a reply). With media
    (``/ai`` + [media]) the command message itself is kept.

    Model configs are looked up from the original command message, reordered by
    ``reorder_model_configs`` and tried in order. For ``openai_chat`` configs,
    tool-call configs (``AI.TOOL_CALL_MODEL_ALIAS``) are tried first unless the
    config sets ``openai_enable_tool_call`` to a falsy value. A successful
    tool-call result is merged into the chat parameters.

    Args:
        client: Pyrogram client passed through to the backend helpers.
        message: The incoming command message.
        **kwargs: Extra options deep-merged over every model config. When a
            helper returns a ``Message`` under ``"progress"``, it is stored back
            into ``kwargs`` so later attempts reuse it.

    Returns:
        The first response dict whose ``"success"`` is truthy. Otherwise the
        response of the last attempt (``{}`` for an unknown ``api_type``), or
        ``{}`` when no model config is available.
    """
    texts = str(message.content).strip()
    this_msg = message
    prompt = texts.removeprefix(PREFIX.AI_TEXT_GENERATION).strip()
    prompt = re.sub(r"^@([a-zA-Z0-9_\-\.]+)(\s+)?", "", prompt, flags=re.DOTALL).strip()
    # `/ai` + [media] keeps the command message; `/ai` as a reply to [message]
    # switches the input to the replied-to message.
    if not prompt and this_msg.media is None and this_msg.reply_to_message:
        message = this_msg.reply_to_message
    model_configs = await get_text_model_configs(this_msg)
    if not model_configs:
        return {}

    model_configs = await reorder_model_configs(client, message, model_configs, kwargs)
    # Reordering may leave nothing to try; without this guard the loop below is
    # skipped and `return res` would hit an unbound local.
    if not model_configs:
        return {}

    def handle_response(resp: dict, current_kwargs: dict) -> dict | None:
        """Return ``resp`` if it reports success, else ``None``.

        As a side effect, copies a ``Message`` found under ``"progress"`` into
        ``current_kwargs`` so later attempts reuse the same progress message.
        """
        if isinstance(resp.get("progress"), Message):
            current_kwargs["progress"] = resp["progress"]
        if resp.get("success", False):
            return resp
        return None

    for model_config in model_configs:
        api_type = model_config["api_type"]
        params = deep_merge(model_config, kwargs)
        res = {}  # an unknown api_type yields an empty (unsuccessful) response
        if api_type == "gemini":
            res = await gemini_chat_completion(client, message, **params)
        elif api_type == "anthropic":
            res = await anthropic_responses(client, message, **params)
        elif api_type == "openai_responses":
            res = await openai_responses_api(client, message, **params)
        elif api_type == "openai_chat":
            if model_config.get("openai_enable_tool_call", True):
                tool_configs = await get_config_by_model_alias(AI.TOOL_CALL_MODEL_ALIAS, fallback_to_default=False)
                # Tolerate an empty/None lookup result instead of failing on iteration.
                for tool_config in tool_configs or []:
                    tool_params = deep_merge(params, tool_config)
                    tool_results = await get_tool_call_results(client, message, **tool_params)
                    if isinstance(tool_results.get("progress"), Message):
                        kwargs["progress"] = tool_results["progress"]
                        params["progress"] = tool_results["progress"]
                    if tool_results.get("success", False):
                        # Feed the successful tool-call output into the chat request.
                        params = deep_merge(params, tool_results)
                        break
            res = await openai_chat_completions(client, message, **params)
        if successful_res := handle_response(res, kwargs):
            return successful_res
    return res
79
80
81async def ai_image_generation(client: Client, message: Message, **kwargs) -> None:
82 if not startswith_prefix(message.content, PREFIX.AI_IMG_GENERATION):
83 return
84 texts = str(message.content).strip()
85 this_msg = message
86 prompt = texts.removeprefix(PREFIX.AI_IMG_GENERATION).strip()
87 prompt = re.sub(r"^@([a-zA-Z0-9_\-\.]+)(\s+)?", "", prompt, flags=re.DOTALL).strip()
88 if not prompt:
89 if not message.reply_to_message:
90 await send2tg(client, message, texts=await img_generation_docs(), **kwargs)
91 return
92 message = this_msg.reply_to_message
93 model_configs = await get_image_model_configs(this_msg)
94 if not model_configs:
95 return
96
97 params: dict = {"success": False, "progress": None}
98 for model_config in model_configs:
99 api_type = model_config["api_type"]
100 params = deep_merge(params, model_config)
101 if api_type == "openai":
102 params |= await openai_image_generation(client, message, **params)
103 elif api_type == "post":
104 params |= await http_post_image_generation(client, message, **params)
105 elif api_type == "gemini":
106 params |= await gemini_image_generation(client, message, **params)
107 if params.get("success"):
108 return
109
110
async def ai_video_generation(client: Client, message: Message, **kwargs) -> None:
    """Generate a video by trying each configured video model in order.

    Does nothing unless the message starts with ``PREFIX.AI_VIDEO_GENERATION``.
    The prefix and an optional leading ``@alias`` token are stripped. If no
    prompt remains, the replied-to message becomes the input. When there is no
    reply either, the video-generation usage docs are sent instead (``kwargs`` is
    only forwarded to that ``send2tg`` call).

    Each model config is passed to ``http_post_video_generation``. The loop
    stops at the first truthy return value.
    """
    if not startswith_prefix(message.content, PREFIX.AI_VIDEO_GENERATION):
        return
    command_msg = message
    remainder = str(command_msg.content).strip().removeprefix(PREFIX.AI_VIDEO_GENERATION).strip()
    remainder = re.sub(r"^@([a-zA-Z0-9_\-\.]+)(\s+)?", "", remainder, flags=re.DOTALL).strip()
    if not remainder:
        if not command_msg.reply_to_message:
            await send2tg(client, command_msg, texts=await video_generation_docs(), **kwargs)
            return
        message = command_msg.reply_to_message

    configs = await get_video_model_configs(command_msg)
    if not configs:
        return
    for config in configs:
        done = await http_post_video_generation(client, message, **config)
        if done:
            return
128 return