Commit b172bf1
Changed files (8)
src/ai/images/gemini.py
@@ -35,8 +35,14 @@ async def gemini_image_generation(
gemini_proxy: str | None = PROXY.GOOGLE,
max_reference_img: int = 0,
**kwargs,
-) -> bool:
- """Get Gemini Image Generation."""
+) -> dict[str, Message | bool]:
+ """Get Gemini Image Generation.
+
+ Returns:
+ dict[str, Message | bool]:
+ {"success": True} # generated image successfully
+ {"progress": Message} # failed to generate; empty dict if no status message exists
+ """
status_msg = kwargs.get("progress") or await message.reply(f"🎨**{model_name}**:\nGenerating image...", quote=True)
for api_key in strings_list(gemini_api_keys, shuffle=True):
@@ -63,12 +69,12 @@ async def gemini_image_generation(
if media:
await send2tg(client, message, texts=f"π**{model_name}**:\n{texts}", media=media, caption_above=True, **kwargs)
await delete_message(status_msg)
- return True
+ return {"success": True}
await modify_progress(status_msg, text=f"β{response.model_dump()}", force_update=True, **kwargs)
except Exception as e:
logger.error(f"Gemini API error: {e}")
await modify_progress(status_msg, text=f"β{e}", force_update=True, **kwargs)
- return False
+ return {"progress": status_msg} if isinstance(status_msg, Message) else {}
async def get_gemini_contexts(client: Client, message: Message, model_name: str, *, max_reference_img: int = 0) -> list[dict]:
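The bool return becomes a small result dict. A minimal caller-side sketch of the new contract (the retry wrapper and the params dict here are illustrative, not part of this commit), as it would appear inside an async handler:

result = await gemini_image_generation(client, message, **params)
if result.get("success"):
    pass  # image sent; the generator already deleted its status message
elif isinstance(result.get("progress"), Message):
    # The status message survived the failure; hand it to the next provider
    # so it edits the same Telegram message instead of posting a new reply.
    params["progress"] = result["progress"]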
src/ai/images/openai_img.py
@@ -1,6 +1,5 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
-import asyncio
import base64
from pathlib import Path
from typing import Literal
@@ -39,8 +38,14 @@ async def openai_image_generation(
max_height: int = int(1e16),
max_size: int = int(1e32),
**kwargs,
-) -> bool:
- """Get OpenAI Image Generation."""
+) -> dict[str, Message | bool]:
+ """Get OpenAI Image Generation.
+
+ Returns:
+ dict[str, Message | bool]:
+ {"success": True} # generated image successfully
+ {"progress": Message} # failed to generate; empty dict if no status message exists
+ """
status_msg = kwargs.get("progress") or await message.reply(f"{EMOJI_IMG_BOT}**{model_name}**:\nGenerating image...", quote=True)
try:
openai_client = {}
@@ -51,7 +56,7 @@ async def openai_image_generation(
prompt, reference_images = await get_image_contexts(client, message, max_reference_img=max_reference_img)
if not prompt:
await modify_progress(status_msg, text=f"β**{model_name}**:\nθ―·ζδΎζη€Ίθ―", force_update=True, **kwargs)
- return False
+ return {"progress": status_msg} if isinstance(status_msg, Message) else {}
params = {}
if literal_eval(generate_config):
params |= literal_eval(generate_config)
@@ -66,9 +71,7 @@ async def openai_image_generation(
except Exception as e:
logger.error(f"OpenAI client setup error: {e}")
await modify_progress(status_msg, text=f"β{e}", force_update=True, **kwargs)
- await asyncio.sleep(10)
- await delete_message(status_msg)
- return False
+ return {"progress": status_msg} if isinstance(status_msg, Message) else {}
resp = {}
for api_key in strings_list(api_keys, shuffle=True):
try:
@@ -85,12 +88,11 @@ async def openai_image_generation(
await modify_progress(status_msg, text=prettify(resp), force_update=True, **kwargs)
await send2tg(client, message, texts=caption, media=[{"photo": img["path"]} for img in images], **kwargs)
await delete_message(status_msg)
- return True
+ return {"success": True}
except Exception as e:
logger.error(f"OpenAI Image Generation error: {e}\n\n{prettify(resp)}")
await modify_progress(status_msg, text=f"β{e}\n\n{prettify(resp)}", force_update=True, **kwargs)
- await delete_message(status_msg)
- return False
+ return {"progress": status_msg} if isinstance(status_msg, Message) else {}
async def download_generated_images(response: dict, proxy: str | None = None) -> list[dict]:
src/ai/images/post.py
@@ -40,14 +40,20 @@ async def http_post_image_generation(
max_height: int = int(1e16),
max_size: int = int(1e32),
**kwargs,
-) -> bool:
- """Get HTTP Post Image Generation."""
+) -> dict[str, Message | bool]:
+ """Get HTTP Post Image Generation.
+
+ Returns:
+ dict[str, Message | bool]:
+ {"success": True} # generated image successfully
+ {"progress": Message} # failed to generate; empty dict if no status message exists
+ """
status_msg = kwargs.get("progress") or await message.reply(f"{EMOJI_IMG_BOT}**{model_name}**:\nGenerating image...", quote=True)
try:
prompt, reference_images = await get_image_contexts(client, message, max_reference_img=max_reference_img)
if not prompt:
await modify_progress(status_msg, text=f"β**{model_name}**:\nθ―·ζδΎζη€Ίθ―", force_update=True, **kwargs)
- return False
+ return {"progress": status_msg} if isinstance(status_msg, Message) else {}
params = {}
api_paths = api_paths or {}
if headers:
@@ -68,7 +74,7 @@ async def http_post_image_generation(
resp = await hx_req(**params)
if error := resp.get("hx_error"):
await modify_progress(status_msg, text=f"β**{model_name}**:\n{error}", force_update=True, **kwargs)
- return False
+ return {"progress": status_msg} if isinstance(status_msg, Message) else {}
image_urls: list[str] = []
metadata = ""
@@ -85,12 +91,11 @@ async def http_post_image_generation(
await modify_progress(status_msg, text=caption, force_update=True, **kwargs)
await send2tg(client, message, texts=caption, media=[{"photo": img["path"]} for img in images], **kwargs)
await delete_message(status_msg)
- return True
+ return {"success": True}
except Exception as e:
logger.error(f"HTTP Post Image Generation error: {e}")
await modify_progress(status_msg, text=f"β{e}", force_update=True, **kwargs)
- await delete_message(status_msg)
- return False
+ return {"progress": status_msg} if isinstance(status_msg, Message) else {}
def extract_metadata(response: dict) -> str:
src/ai/texts/gemini.py
@@ -73,6 +73,7 @@ async def gemini_chat_completion(
if resp.get("texts"):
sent_messages.extend(resp.get("sent_messages", []))
return {
+ "success": True,
"texts": resp["texts"],
"thoughts": resp["thoughts"],
"prefix": prefix,
@@ -82,7 +83,7 @@ async def gemini_chat_completion(
except Exception as e:
logger.error(f"Gemini API error: {e}")
await modify_progress(status_msg, text=f"β{e}", force_update=True, **kwargs)
- return {}
+ return {"progress": status_msg} if isinstance(status_msg, Message) else {}
async def single_api_generate_content(
src/ai/texts/openai_chat.py
@@ -72,7 +72,8 @@ async def openai_chat_completions(
except Exception as e:
logger.error(f"OpenAI client setup error: {e}")
await modify_progress(status_msg, text=f"β{e}", force_update=True, **kwargs)
- return {}
+ return {"progress": status_msg} if isinstance(status_msg, Message) else {}
+
for api_key in strings_list(openai_api_keys, shuffle=True):
try:
openai_client |= {"base_url": openai_base_url, "api_key": api_key}
@@ -89,11 +90,18 @@ async def openai_chat_completions(
**kwargs,
)
if resp.get("texts") or resp.get("tool_name"):
- return resp | {"prefix": prefix, "model_name": model_name, "sent_messages": [m for m in sent_messages + resp["sent_messages"] if isinstance(m, Message)]}
+ resp |= {
+ "success": True,
+ "prefix": prefix,
+ "model_name": model_name,
+ "sent_messages": [m for m in sent_messages + resp["sent_messages"] if isinstance(m, Message)],
+ }
+ resp |= {"progress": status_msg} if isinstance(status_msg, Message) else {}
+ return resp
except Exception as e:
logger.error(f"OpenAI API error: {e}")
await modify_progress(status_msg, text=f"β{e}", force_update=True, **kwargs)
- return {}
+ return {"progress": status_msg} if isinstance(status_msg, Message) else {}
async def single_api_chat_completions(
src/ai/texts/openai_response.py
@@ -69,7 +69,7 @@ async def openai_responses_api(
openai_client |= {"http_client": DefaultAsyncHttpxClient(proxy=openai_proxy)}
except Exception as e:
logger.error(f"OpenAI client setup error: {e}")
- return {}
+ return {"progress": status_msg} if isinstance(status_msg, Message) else {}
for api_key in strings_list(openai_api_keys, shuffle=True):
try:
@@ -126,6 +126,7 @@ async def openai_responses_api(
silent=silent,
)
return {
+ "success": True,
"texts": resp["texts"],
"thoughts": resp["thoughts"],
"response_id": resp["response_id"],
@@ -136,7 +137,7 @@ async def openai_responses_api(
except Exception as e:
logger.error(f"OpenAI API error: {e}")
await modify_progress(status_msg, text=f"β{e}", force_update=True, **kwargs)
- return {}
+ return {"progress": status_msg} if isinstance(status_msg, Message) else {}
async def single_api_response(
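Note the truthiness shift across all three text backends: a failed call used to return {}, so "if res := await ..." was a reliable success test, but a failure dict carrying only "progress" is truthy. Callers should key off "success" instead; a hypothetical fragment from an async caller:

resp = await openai_responses_api(client, message, **params)
if resp.get("success"):
    texts, thoughts = resp["texts"], resp["thoughts"]
elif isinstance(resp.get("progress"), Message):
    params["progress"] = resp["progress"]  # let the next backend reuse the status message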
src/ai/texts/tool_call.py
@@ -99,7 +99,7 @@ Use the `web_search` tool to access up-to-date information from the web or when
"openai_system_prompt": system_prompt,
}
resp = await openai_chat_completions(client, message, **kwargs)
- kwargs["progress"] = kwargs.get("progress", glom(resp, "sent_messages.0", default=None))
+ kwargs["progress"] = kwargs.get("progress") or resp.get("progress") or glom(resp, "sent_messages.0", default=None)
while resp.get("tool_name"):
tool_name = resp["tool_name"].strip()
tool_args = literal_eval(resp.get("tool_args", "{}"))
@@ -111,12 +111,15 @@ Use the `web_search` tool to access up-to-date information from the web or when
if not kwargs["openai_tools"]:
break
resp = await openai_chat_completions(client, message, **kwargs)
- result_texts = glom(kwargs, "openai_contexts.0.content", default="").strip()
- return {
- "openai_system_prompt": result_texts.removeprefix(system_prompt).strip(), # add tool results to system prompt
- "openai_tools": None, # disable tools after tool call
- "progress": kwargs["progress"],
- }
+ if texts := glom(kwargs, "openai_contexts.0.content", default="").strip().removeprefix(system_prompt).strip():
+ return {
+ "success": True,
+ "openai_system_prompt": texts, # add tool results to system prompt
+ "openai_tools": None, # disable tools after tool call
+ "progress": kwargs["progress"],
+ }
+ status_msg = kwargs.get("progress") or resp.get("progress")
+ return {"progress": status_msg} if isinstance(status_msg, Message) else {}
def add_search_results(contexts: list[dict], search_results: list[dict]) -> list[dict]:
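The progress fallback chain leans on glom string paths, where a numeric segment indexes into a list and the default keyword absorbs missing paths. A quick illustration with toy data:

from glom import glom

resp = {"sent_messages": ["msg_a", "msg_b"]}
print(glom(resp, "sent_messages.0", default=None))  # msg_a: "0" indexes into the list
print(glom({}, "sent_messages.0", default=None))    # None: missing path falls back to the default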
src/ai/main.py
@@ -35,26 +35,42 @@ async def ai_text_generation(client: Client, message: Message, **kwargs) -> dict
model_configs = await get_text_model_configs(this_msg)
if not model_configs:
return {}
+
+ def handle_response(resp: dict, current_kwargs: dict) -> dict | None:
+ """Handle API response.
+
+ Update current_kwargs with progress message if available.
+ """
+ if isinstance(resp.get("progress"), Message):
+ current_kwargs["progress"] = resp["progress"]
+ if resp.get("success", False):
+ return resp
+ return None
+
for model_config in model_configs:
+ api_type = model_config["api_type"]
params: dict = model_config | kwargs
- match model_config["api_type"]:
- case "gemini":
- if res := await gemini_chat_completion(client, message, **params):
- return res
- case "openai_responses":
- if res := await openai_responses_api(client, message, **params):
- return res
- case "openai_chat":
- if params.get("openai_enable_tool_call", True):
- tool_configs = await get_config_by_model_alias(AI.TOOL_CALL_MODEL_ALIAS, fallback_to_default=False)
- for tool_config in tool_configs:
- tool_params = params | tool_config
- if tool_results := await get_tool_call_results(client, message, **tool_params):
- params |= tool_results
- break
- if res := await openai_chat_completions(client, message, **params):
- return res
- return {}
+ res = {}
+ if api_type == "gemini":
+ res = await gemini_chat_completion(client, message, **params)
+ elif api_type == "openai_responses":
+ res = await openai_responses_api(client, message, **params)
+ elif api_type == "openai_chat":
+ if model_config.get("openai_enable_tool_call", True):
+ tool_configs = await get_config_by_model_alias(AI.TOOL_CALL_MODEL_ALIAS, fallback_to_default=False)
+ for tool_config in tool_configs:
+ tool_params = params | tool_config
+ tool_results = await get_tool_call_results(client, message, **tool_params)
+ if isinstance(tool_results.get("progress"), Message):
+ kwargs["progress"] = tool_results["progress"]
+ params["progress"] = tool_results["progress"]
+ if tool_results.get("success", False):
+ params |= tool_results
+ break
+ res = await openai_chat_completions(client, message, **params)
+ if successful_res := handle_response(res, kwargs):
+ return successful_res
+ return res
async def ai_image_generation(client: Client, message: Message, **kwargs) -> None:
@@ -72,17 +88,19 @@ async def ai_image_generation(client: Client, message: Message, **kwargs) -> Non
model_configs = await get_image_model_configs(this_msg)
if not model_configs:
return
+
+ params: dict = {"success": False, "progress": None}
for model_config in model_configs:
- match model_config["api_type"]:
- case "openai":
- if await openai_image_generation(client, message, **model_config):
- return
- case "post":
- if await http_post_image_generation(client, message, **model_config):
- return
- case "gemini":
- if await gemini_image_generation(client, message, **model_config):
- return
+ api_type = model_config["api_type"]
+ params |= model_config
+ if api_type == "openai":
+ params |= await openai_image_generation(client, message, **params)
+ elif api_type == "post":
+ params |= await http_post_image_generation(client, message, **params)
+ elif api_type == "gemini":
+ params |= await gemini_image_generation(client, message, **params)
+ if params.get("success"):
+ return
async def ai_video_generation(client: Client, message: Message, **kwargs) -> None:
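Because params persists across iterations of the provider loop, a failed provider's "progress" entry rides into the next provider's kwargs, while each new model_config overrides only its own keys. A toy trace of the dict's evolution (model names and the status-message stand-in are illustrative):

status_msg = object()  # stand-in for the pyrogram Message left by a failed attempt
params = {"success": False, "progress": None}
params |= {"api_type": "openai", "model_name": "model_a"}  # first model_config
params |= {"progress": status_msg}  # the failed attempt hands back its status message
params |= {"api_type": "gemini", "model_name": "model_b"}  # next model_config replaces its own keys
assert params["progress"] is status_msg  # so the Gemini attempt edits the same message in place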