Commit b172bf1

benny-dou <60535774+benny-dou@users.noreply.github.com>
2026-02-14 02:06:47
chore(ai): maintain a single progress message across all API calls
1 parent 6072580
src/ai/images/gemini.py
@@ -35,8 +35,14 @@ async def gemini_image_generation(
     gemini_proxy: str | None = PROXY.GOOGLE,
     max_reference_img: int = 0,
     **kwargs,
-) -> bool:
-    """Get Gemini Image Generation."""
+) -> dict[str, Message | bool]:
+    """Get Gemini Image Generation.
+
+    Return:
+        dict[str, Message | bool]:
+            {"success": True}  # Generated image successfully
+            {"progress": Message}  # Failed to generate image
+    """
     status_msg = kwargs.get("progress") or await message.reply(f"🍌**{model_name}**:\nζ­£εœ¨η”Ÿζˆε›Ύεƒ...", quote=True)
 
     for api_key in strings_list(gemini_api_keys, shuffle=True):
@@ -63,12 +69,12 @@ async def gemini_image_generation(
             if media:
                 await send2tg(client, message, texts=f"🍌**{model_name}**:\n{texts}", media=media, caption_above=True, **kwargs)
                 await delete_message(status_msg)
-                return True
+                return {"success": True}
             await modify_progress(status_msg, text=f"❌{response.model_dump()}", force_update=True, **kwargs)
         except Exception as e:
             logger.error(f"Gemini API error: {e}")
             await modify_progress(status_msg, text=f"❌{e}", force_update=True, **kwargs)
-    return False
+    return {"progress": status_msg} if isinstance(status_msg, Message) else {}
 
 
 async def get_gemini_contexts(client: Client, message: Message, model_name: str, *, max_reference_img: int = 0) -> list[dict]:
src/ai/images/openai_img.py
@@ -1,6 +1,5 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
-import asyncio
 import base64
 from pathlib import Path
 from typing import Literal
@@ -39,8 +38,14 @@ async def openai_image_generation(
     max_height: int = int(1e16),
     max_size: int = int(1e32),
     **kwargs,
-) -> bool:
-    """Get OpenAI Image Generation."""
+) -> dict[str, Message | bool]:
+    """Get OpenAI Image Generation.
+
+    Return:
+    dict[str, Message | bool]:
+        {"success": True}  # Generated image successfully
+        {"progress": Message}  # Failed to generate image
+    """
     status_msg = kwargs.get("progress") or await message.reply(f"{EMOJI_IMG_BOT}**{model_name}**:\nζ­£εœ¨η”Ÿζˆε›Ύεƒ...", quote=True)
     try:
         openai_client = {}
@@ -51,7 +56,7 @@ async def openai_image_generation(
         prompt, reference_images = await get_image_contexts(client, message, max_reference_img=max_reference_img)
         if not prompt:
             await modify_progress(status_msg, text=f"❌**{model_name}**:\n请提供提瀺词", force_update=True, **kwargs)
-            return False
+            return {"progress": status_msg} if isinstance(status_msg, Message) else {}
         params = {}
         if literal_eval(generate_config):
             params |= literal_eval(generate_config)
@@ -66,9 +71,7 @@ async def openai_image_generation(
     except Exception as e:
         logger.error(f"OpenAI client setup error: {e}")
         await modify_progress(status_msg, text=f"❌{e}", force_update=True, **kwargs)
-        await asyncio.sleep(10)
-        await delete_message(status_msg)
-        return False
+        return {"progress": status_msg} if isinstance(status_msg, Message) else {}
     resp = {}
     for api_key in strings_list(api_keys, shuffle=True):
         try:
@@ -85,12 +88,11 @@ async def openai_image_generation(
                 await modify_progress(status_msg, text=prettify(resp), force_update=True, **kwargs)
                 await send2tg(client, message, texts=caption, media=[{"photo": img["path"]} for img in images], **kwargs)
                 await delete_message(status_msg)
-                return True
+                return {"success": True}
         except Exception as e:
             logger.error(f"OpenAI Image Generation error: {e}\n\n{prettify(resp)}")
             await modify_progress(status_msg, text=f"❌{e}\n\n{prettify(resp)}", force_update=True, **kwargs)
-    await delete_message(status_msg)
-    return False
+    return {"progress": status_msg} if isinstance(status_msg, Message) else {}
 
 
 async def download_generated_images(response: dict, proxy: str | None = None) -> list[dict]:
src/ai/images/post.py
@@ -40,14 +40,20 @@ async def http_post_image_generation(
     max_height: int = int(1e16),
     max_size: int = int(1e32),
     **kwargs,
-) -> bool:
-    """Get HTTP Post Image Generation."""
+) -> dict[str, Message | bool]:
+    """Get HTTP Post Image Generation.
+
+    Return:
+        dict[str, Message | bool]:
+            {"success": True}  # Generated image successfully
+            {"progress": Message}  # Failed to generate image
+    """
     status_msg = kwargs.get("progress") or await message.reply(f"{EMOJI_IMG_BOT}**{model_name}**:\nζ­£εœ¨η”Ÿζˆε›Ύεƒ...", quote=True)
     try:
         prompt, reference_images = await get_image_contexts(client, message, max_reference_img=max_reference_img)
         if not prompt:
             await modify_progress(status_msg, text=f"❌**{model_name}**:\n请提供提瀺词", force_update=True, **kwargs)
-            return False
+            return {"progress": status_msg} if isinstance(status_msg, Message) else {}
         params = {}
         api_paths = api_paths or {}
         if headers:
@@ -68,7 +74,7 @@ async def http_post_image_generation(
         resp = await hx_req(**params)
         if error := resp.get("hx_error"):
             await modify_progress(status_msg, text=f"❌**{model_name}**:\n{error}", force_update=True, **kwargs)
-            return False
+            return {"progress": status_msg} if isinstance(status_msg, Message) else {}
 
         image_urls: list[str] = []
         metadata = ""
@@ -85,12 +91,11 @@ async def http_post_image_generation(
             await modify_progress(status_msg, text=caption, force_update=True, **kwargs)
             await send2tg(client, message, texts=caption, media=[{"photo": img["path"]} for img in images], **kwargs)
             await delete_message(status_msg)
-            return True
+            return {"success": True}
     except Exception as e:
         logger.error(f"HTTP Post Image Generation error: {e}")
         await modify_progress(status_msg, text=f"❌{e}", force_update=True, **kwargs)
-    await delete_message(status_msg)
-    return False
+    return {"progress": status_msg} if isinstance(status_msg, Message) else {}
 
 
 def extract_metadata(response: dict) -> str:
src/ai/texts/gemini.py
@@ -73,6 +73,7 @@ async def gemini_chat_completion(
             if resp.get("texts"):
                 sent_messages.extend(resp.get("sent_messages", []))
                 return {
+                    "success": True,
                     "texts": resp["texts"],
                     "thoughts": resp["thoughts"],
                     "prefix": prefix,
@@ -82,7 +83,7 @@ async def gemini_chat_completion(
         except Exception as e:
             logger.error(f"Gemini API error: {e}")
             await modify_progress(status_msg, text=f"❌{e}", force_update=True, **kwargs)
-    return {}
+    return {"progress": status_msg} if isinstance(status_msg, Message) else {}
 
 
 async def single_api_generate_content(
src/ai/texts/openai_chat.py
@@ -72,7 +72,8 @@ async def openai_chat_completions(
     except Exception as e:
         logger.error(f"OpenAI client setup error: {e}")
         await modify_progress(status_msg, text=f"❌{e}", force_update=True, **kwargs)
-        return {}
+        return {"progress": status_msg} if isinstance(status_msg, Message) else {}
+
     for api_key in strings_list(openai_api_keys, shuffle=True):
         try:
             openai_client |= {"base_url": openai_base_url, "api_key": api_key}
@@ -89,11 +90,18 @@ async def openai_chat_completions(
                 **kwargs,
             )
             if resp.get("texts") or resp.get("tool_name"):
-                return resp | {"prefix": prefix, "model_name": model_name, "sent_messages": [m for m in sent_messages + resp["sent_messages"] if isinstance(m, Message)]}
+                resp |= {
+                    "success": True,
+                    "prefix": prefix,
+                    "model_name": model_name,
+                    "sent_messages": [m for m in sent_messages + resp["sent_messages"] if isinstance(m, Message)],
+                }
+                resp |= {"progress": status_msg} if isinstance(status_msg, Message) else {}
+                return resp
         except Exception as e:
             logger.error(f"OpenAI API error: {e}")
             await modify_progress(status_msg, text=f"❌{e}", force_update=True, **kwargs)
-    return {}
+    return {"progress": status_msg} if isinstance(status_msg, Message) else {}
 
 
 async def single_api_chat_completions(
src/ai/texts/openai_response.py
@@ -69,7 +69,7 @@ async def openai_responses_api(
             openai_client |= {"http_client": DefaultAsyncHttpxClient(proxy=openai_proxy)}
     except Exception as e:
         logger.error(f"OpenAI client setup error: {e}")
-        return {}
+        return {"progress": status_msg} if isinstance(status_msg, Message) else {}
 
     for api_key in strings_list(openai_api_keys, shuffle=True):
         try:
@@ -126,6 +126,7 @@ async def openai_responses_api(
                         silent=silent,
                     )
             return {
+                "success": True,
                 "texts": resp["texts"],
                 "thoughts": resp["thoughts"],
                 "response_id": resp["response_id"],
@@ -136,7 +137,7 @@ async def openai_responses_api(
         except Exception as e:
             logger.error(f"OpenAI API error: {e}")
             await modify_progress(status_msg, text=f"❌{e}", force_update=True, **kwargs)
-    return {}
+    return {"progress": status_msg} if isinstance(status_msg, Message) else {}
 
 
 async def single_api_response(
src/ai/texts/tool_call.py
@@ -99,7 +99,7 @@ Use the `web_search` tool to access up-to-date information from the web or when
         "openai_system_prompt": system_prompt,
     }
     resp = await openai_chat_completions(client, message, **kwargs)
-    kwargs["progress"] = kwargs.get("progress", glom(resp, "sent_messages.0", default=None))
+    kwargs["progress"] = kwargs.get("progress") or resp.get("progress") or glom(resp, "sent_messages.0", default=None)
     while resp.get("tool_name"):
         tool_name = resp["tool_name"].strip()
         tool_args = literal_eval(resp.get("tool_args", "{}"))
@@ -111,12 +111,15 @@ Use the `web_search` tool to access up-to-date information from the web or when
         if not kwargs["openai_tools"]:
             break
         resp = await openai_chat_completions(client, message, **kwargs)
-    result_texts = glom(kwargs, "openai_contexts.0.content", default="").strip()
-    return {
-        "openai_system_prompt": result_texts.removeprefix(system_prompt).strip(),  # add tool results to system prompt
-        "openai_tools": None,  # disable tools after tool call
-        "progress": kwargs["progress"],
-    }
+    if texts := glom(kwargs, "openai_contexts.0.content", default="").strip().removeprefix(system_prompt).strip():
+        return {
+            "success": True,
+            "openai_system_prompt": texts,  # add tool results to system prompt
+            "openai_tools": None,  # disable tools after tool call
+            "progress": kwargs["progress"],
+        }
+    status_msg = kwargs.get("progress") or resp.get("progress")
+    return {"progress": status_msg} if isinstance(status_msg, Message) else {}
 
 
 def add_search_results(contexts: list[dict], search_results: list[dict]) -> list[dict]:
src/ai/main.py
@@ -35,26 +35,42 @@ async def ai_text_generation(client: Client, message: Message, **kwargs) -> dict
     model_configs = await get_text_model_configs(this_msg)
     if not model_configs:
         return {}
+
+    def handle_response(resp: dict, current_kwargs: dict) -> dict | None:
+        """Handle API response.
+
+        Update current_kwargs with progress message if available.
+        """
+        if isinstance(resp.get("progress"), Message):
+            current_kwargs["progress"] = resp["progress"]
+        if resp.get("success", False):
+            return resp
+        return None
+
     for model_config in model_configs:
+        api_type = model_config["api_type"]
         params: dict = model_config | kwargs
-        match model_config["api_type"]:
-            case "gemini":
-                if res := await gemini_chat_completion(client, message, **params):
-                    return res
-            case "openai_responses":
-                if res := await openai_responses_api(client, message, **params):
-                    return res
-            case "openai_chat":
-                if params.get("openai_enable_tool_call", True):
-                    tool_configs = await get_config_by_model_alias(AI.TOOL_CALL_MODEL_ALIAS, fallback_to_default=False)
-                    for tool_config in tool_configs:
-                        tool_params = params | tool_config
-                        if tool_results := await get_tool_call_results(client, message, **tool_params):
-                            params |= tool_results
-                            break
-                if res := await openai_chat_completions(client, message, **params):
-                    return res
-    return {}
+        res = {}
+        if api_type == "gemini":
+            res = await gemini_chat_completion(client, message, **params)
+        elif api_type == "openai_responses":
+            res = await openai_responses_api(client, message, **params)
+        elif api_type == "openai_chat":
+            if model_config.get("openai_enable_tool_call", True):
+                tool_configs = await get_config_by_model_alias(AI.TOOL_CALL_MODEL_ALIAS, fallback_to_default=False)
+                for tool_config in tool_configs:
+                    tool_params = params | tool_config
+                    tool_results = await get_tool_call_results(client, message, **tool_params)
+                    if isinstance(tool_results.get("progress"), Message):
+                        kwargs["progress"] = tool_results["progress"]
+                        params["progress"] = tool_results["progress"]
+                    if tool_results.get("success", False):
+                        params |= tool_results
+                        break
+            res = await openai_chat_completions(client, message, **params)
+        if successful_res := handle_response(res, kwargs):
+            return successful_res
+    return res
 
 
 async def ai_image_generation(client: Client, message: Message, **kwargs) -> None:
@@ -72,17 +88,19 @@ async def ai_image_generation(client: Client, message: Message, **kwargs) -> Non
     model_configs = await get_image_model_configs(this_msg)
     if not model_configs:
         return
+
+    params: dict = {"success": False, "progress": None}
     for model_config in model_configs:
-        match model_config["api_type"]:
-            case "openai":
-                if await openai_image_generation(client, message, **model_config):
-                    return
-            case "post":
-                if await http_post_image_generation(client, message, **model_config):
-                    return
-            case "gemini":
-                if await gemini_image_generation(client, message, **model_config):
-                    return
+        api_type = model_config["api_type"]
+        params |= model_config
+        if api_type == "openai":
+            params |= await openai_image_generation(client, message, **params)
+        elif api_type == "post":
+            params |= await http_post_image_generation(client, message, **params)
+        elif api_type == "gemini":
+            params |= await gemini_image_generation(client, message, **params)
+        if params.get("success"):
+            return
 
 
 async def ai_video_generation(client: Client, message: Message, **kwargs) -> None: