Commit 781af41

benny-dou <60535774+benny-dou@users.noreply.github.com>
2026-04-14 03:06:15
chore(ai): use `deep_merge` to merge dicts
1 parent 1064501
src/ai/texts/models.py
@@ -5,7 +5,7 @@ import re
 from loguru import logger
 from pyrogram.types import Message
 
-from ai.utils import BOT_TIPS, EMOJI_TEXT_BOT, text_generation_docs
+from ai.utils import BOT_TIPS, EMOJI_TEXT_BOT, deep_merge, text_generation_docs
 from config import AI, PREFIX
 from database.kv import get_cf_kv
 from messages.utils import startswith_prefix
@@ -156,7 +156,7 @@ async def get_config_by_model_alias(model_alias: str, *, fallback_to_default: bo
 
     if config := kv.get(model_alias, {}):
         common_config = config.get("common_config", {})
-        return [common_config | model_config for model_config in config.get("models", [])]
+        return [deep_merge(common_config, model_config) for model_config in config.get("models", [])]
 
     if not fallback_to_default:
         return []
@@ -202,7 +202,7 @@ async def get_config_by_model_name(model_name: str) -> list[dict]:
             continue
         common_config = config.get("common_config", {})
         for model in config.get("models", []):
-            model_config = common_config | model
+            model_config = deep_merge(common_config, model)
             if model_config.get("model_name", "") == model_name:
                 model_configs.append(model_config)
     return model_configs
src/ai/texts/openai_response.py
@@ -12,7 +12,7 @@ from pyrogram.parser.markdown import BLOCKQUOTE_DELIM
 from pyrogram.types import Message, ReplyParameters
 
 from ai.texts.contexts import get_openai_response_contexts
-from ai.utils import BOT_TIPS, EMOJI_REASONING_BEGIN, EMOJI_TEXT_BOT, beautify_llm_response, literal_eval, load_skills, trim_none
+from ai.utils import BOT_TIPS, EMOJI_REASONING_BEGIN, EMOJI_TEXT_BOT, beautify_llm_response, deep_merge, literal_eval, load_skills, trim_none
 from config import AI, PROXY, TEXT_LENGTH
 from database.r2 import set_cf_r2
 from messages.parser import get_thread_id
@@ -66,9 +66,9 @@ async def openai_responses_api(
         if literal_eval(openai_client_config):
             openai_client |= literal_eval(openai_client_config)
         if literal_eval(openai_default_headers):
-            openai_client |= {"default_headers": literal_eval(openai_default_headers)}
+            openai_client = deep_merge(openai_client, {"default_headers": literal_eval(openai_default_headers)})
         if openai_proxy:
-            openai_client |= {"http_client": DefaultAsyncHttpxClient(proxy=openai_proxy)}
+            openai_client = deep_merge(openai_client, {"http_client": DefaultAsyncHttpxClient(proxy=openai_proxy)})
     except Exception as e:
         logger.error(f"OpenAI client setup error: {e}")
         return {"progress": status_msg} if isinstance(status_msg, Message) else {}
src/ai/main.py
@@ -15,7 +15,7 @@ from ai.texts.models import get_config_by_model_alias, get_text_model_configs
 from ai.texts.openai_chat import openai_chat_completions
 from ai.texts.openai_response import openai_responses_api
 from ai.texts.tool_call import get_tool_call_results
-from ai.utils import img_generation_docs, video_generation_docs
+from ai.utils import deep_merge, img_generation_docs, video_generation_docs
 from ai.videos.models import get_video_model_configs
 from ai.videos.post import http_post_video_generation
 from config import AI, PREFIX
@@ -50,7 +50,7 @@ async def ai_text_generation(client: Client, message: Message, **kwargs) -> dict
 
     for model_config in model_configs:
         api_type = model_config["api_type"]
-        params: dict = model_config | kwargs
+        params = deep_merge(model_config, kwargs)
         res = {}
         if api_type == "gemini":
             res = await gemini_chat_completion(client, message, **params)
@@ -62,13 +62,13 @@ async def ai_text_generation(client: Client, message: Message, **kwargs) -> dict
             if model_config.get("openai_enable_tool_call", True):
                 tool_configs = await get_config_by_model_alias(AI.TOOL_CALL_MODEL_ALIAS, fallback_to_default=False)
                 for tool_config in tool_configs:
-                    tool_params = params | tool_config
+                    tool_params = deep_merge(params, tool_config)
                     tool_results = await get_tool_call_results(client, message, **tool_params)
                     if isinstance(tool_results.get("progress"), Message):
                         kwargs["progress"] = tool_results["progress"]
                         params["progress"] = tool_results["progress"]
                     if tool_results.get("success", False):
-                        params |= tool_results
+                        params = deep_merge(params, tool_results)
                         break
             res = await openai_chat_completions(client, message, **params)
         if successful_res := handle_response(res, kwargs):
@@ -95,7 +95,7 @@ async def ai_image_generation(client: Client, message: Message, **kwargs) -> Non
     params: dict = {"success": False, "progress": None}
     for model_config in model_configs:
         api_type = model_config["api_type"]
-        params |= model_config
+        params = deep_merge(params, model_config)
         if api_type == "openai":
             params |= await openai_image_generation(client, message, **params)
         elif api_type == "post":
src/ai/utils.py
@@ -4,6 +4,8 @@ import ast
 import contextlib
 import json
 import re
+from collections.abc import Mapping
+from copy import deepcopy
 from datetime import datetime
 
 from anthropic import AsyncAnthropic, DefaultAioHttpClient
@@ -215,3 +217,23 @@ async def clean_anthropic_files():
             if delta.total_seconds() > AI.ANTHROPIC_FILES_TTL:
                 logger.debug(f"Delete Anthropic file: {f.filename}")
                 await anthropic.beta.files.delete(file_id=f.id)
+
+
+def deep_merge(base_dict: dict, *update_dicts: dict) -> dict:
+    """Deep merge multiple dicts into a new dict.
+
+    Args:
+        base_dict: The base dictionary to merge into
+        *update_dicts: Dictionaries to merge into the base
+
+    Returns:
+        A new dictionary with all values merged
+    """
+    result = deepcopy(base_dict)
+    for update_dict in update_dicts:
+        for k, v in update_dict.items():
+            if isinstance(v, Mapping) and isinstance(result.get(k), Mapping):
+                result[k] = deep_merge(result[k], v)
+            else:
+                result[k] = v
+    return result
src/asr/corrector.py
@@ -97,10 +97,6 @@ async def asr_corrector(inputs: str, reference: str | None = None, corrector_mod
         Message(id=rand_number(), chat=Chat(id=rand_number()), text=Str(f"{PREFIX.AI_TEXT_GENERATION} @{corrector_model} {texts}")),
         openai_responses_config={
             "instructions": SYSTEM_PROMPT,
-            "max_output_tokens": 65536,
-            "extra_body": {"thinking": {"type": "enabled"}},
-            "tools": [{"type": "web_search", "max_keyword": 5, "limit": 10}],
-            "max_tool_calls": 10,
             "text": {
                 "format": {
                     "type": "json_schema",
@@ -112,9 +108,6 @@ async def asr_corrector(inputs: str, reference: str | None = None, corrector_mod
             },
         },
         gemini_generate_content_config={
-            "max_output_tokens": 65536,
-            "thinking_config": {"include_thoughts": True, "thinking_level": "high"},
-            "tools": [{"google_search": {}}, {"code_execution": {}}],
             "system_instruction": SYSTEM_PROMPT,
             "responseMimeType": "application/json",
             "responseJsonSchema": JSON_SCHEMA,