Commit 781af41
Changed files (5)
src/ai/texts/models.py
@@ -5,7 +5,7 @@ import re
from loguru import logger
from pyrogram.types import Message
-from ai.utils import BOT_TIPS, EMOJI_TEXT_BOT, text_generation_docs
+from ai.utils import BOT_TIPS, EMOJI_TEXT_BOT, deep_merge, text_generation_docs
from config import AI, PREFIX
from database.kv import get_cf_kv
from messages.utils import startswith_prefix
@@ -156,7 +156,7 @@ async def get_config_by_model_alias(model_alias: str, *, fallback_to_default: bo
    if config := kv.get(model_alias, {}):
        common_config = config.get("common_config", {})
-        return [common_config | model_config for model_config in config.get("models", [])]
+        return [deep_merge(common_config, model_config) for model_config in config.get("models", [])]
    if not fallback_to_default:
        return []
@@ -202,7 +202,7 @@ async def get_config_by_model_name(model_name: str) -> list[dict]:
            continue
        common_config = config.get("common_config", {})
        for model in config.get("models", []):
-            model_config = common_config | model
+            model_config = deep_merge(common_config, model)
            if model_config.get("model_name", "") == model_name:
                model_configs.append(model_config)
    return model_configs
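
Note: the shallow `|` merge replaced nested values wholesale, so a model entry that overrode a single key inside a nested dict silently dropped the rest of the shared settings from `common_config`. A minimal sketch with hypothetical config values:

    common_config = {"openai_client_config": {"timeout": 30, "max_retries": 2}}
    model = {"model_name": "gpt-4o", "openai_client_config": {"timeout": 60}}

    common_config | model             # nested dict replaced: {"timeout": 60}
    deep_merge(common_config, model)  # nested dicts merged: {"timeout": 60, "max_retries": 2}
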
src/ai/texts/openai_response.py
@@ -12,7 +12,7 @@ from pyrogram.parser.markdown import BLOCKQUOTE_DELIM
from pyrogram.types import Message, ReplyParameters
from ai.texts.contexts import get_openai_response_contexts
-from ai.utils import BOT_TIPS, EMOJI_REASONING_BEGIN, EMOJI_TEXT_BOT, beautify_llm_response, literal_eval, load_skills, trim_none
+from ai.utils import BOT_TIPS, EMOJI_REASONING_BEGIN, EMOJI_TEXT_BOT, beautify_llm_response, deep_merge, literal_eval, load_skills, trim_none
from config import AI, PROXY, TEXT_LENGTH
from database.r2 import set_cf_r2
from messages.parser import get_thread_id
@@ -66,9 +66,9 @@ async def openai_responses_api(
        if literal_eval(openai_client_config):
            openai_client |= literal_eval(openai_client_config)
        if literal_eval(openai_default_headers):
-            openai_client |= {"default_headers": literal_eval(openai_default_headers)}
+            openai_client = deep_merge(openai_client, {"default_headers": literal_eval(openai_default_headers)})
        if openai_proxy:
-            openai_client |= {"http_client": DefaultAsyncHttpxClient(proxy=openai_proxy)}
+            openai_client = deep_merge(openai_client, {"http_client": DefaultAsyncHttpxClient(proxy=openai_proxy)})
    except Exception as e:
        logger.error(f"OpenAI client setup error: {e}")
        return {"progress": status_msg} if isinstance(status_msg, Message) else {}
src/ai/main.py
@@ -15,7 +15,7 @@ from ai.texts.models import get_config_by_model_alias, get_text_model_configs
from ai.texts.openai_chat import openai_chat_completions
from ai.texts.openai_response import openai_responses_api
from ai.texts.tool_call import get_tool_call_results
-from ai.utils import img_generation_docs, video_generation_docs
+from ai.utils import deep_merge, img_generation_docs, video_generation_docs
from ai.videos.models import get_video_model_configs
from ai.videos.post import http_post_video_generation
from config import AI, PREFIX
@@ -50,7 +50,7 @@ async def ai_text_generation(client: Client, message: Message, **kwargs) -> dict
    for model_config in model_configs:
        api_type = model_config["api_type"]
-        params: dict = model_config | kwargs
+        params = deep_merge(model_config, kwargs)
        res = {}
        if api_type == "gemini":
            res = await gemini_chat_completion(client, message, **params)
@@ -62,13 +62,13 @@ async def ai_text_generation(client: Client, message: Message, **kwargs) -> dict
if model_config.get("openai_enable_tool_call", True):
tool_configs = await get_config_by_model_alias(AI.TOOL_CALL_MODEL_ALIAS, fallback_to_default=False)
for tool_config in tool_configs:
- tool_params = params | tool_config
+ tool_params = deep_merge(params, tool_config)
tool_results = await get_tool_call_results(client, message, **tool_params)
if isinstance(tool_results.get("progress"), Message):
kwargs["progress"] = tool_results["progress"]
params["progress"] = tool_results["progress"]
if tool_results.get("success", False):
- params |= tool_results
+ params = deep_merge(params, tool_results)
break
res = await openai_chat_completions(client, message, **params)
if successful_res := handle_response(res, kwargs):
@@ -95,7 +95,7 @@ async def ai_image_generation(client: Client, message: Message, **kwargs) -> Non
params: dict = {"success": False, "progress": None}
for model_config in model_configs:
api_type = model_config["api_type"]
- params |= model_config
+ params = deep_merge(params, model_config)
if api_type == "openai":
params |= await openai_image_generation(client, message, **params)
elif api_type == "post":
src/ai/utils.py
@@ -4,6 +4,8 @@ import ast
import contextlib
import json
import re
+from collections.abc import Mapping
+from copy import deepcopy
from datetime import datetime
from anthropic import AsyncAnthropic, DefaultAioHttpClient
@@ -215,3 +217,23 @@ async def clean_anthropic_files():
        if delta.total_seconds() > AI.ANTHROPIC_FILES_TTL:
            logger.debug(f"Delete Anthropic file: {f.filename}")
            await anthropic.beta.files.delete(file_id=f.id)
+
+
+def deep_merge(base_dict: dict, *update_dicts: dict) -> dict:
+    """Deep merge multiple dicts into a new dict.
+
+    Nested mappings are merged recursively; on conflicting non-mapping
+    values, later dicts take precedence. The inputs are never mutated.
+
+    Args:
+        base_dict: The base dictionary to merge into
+        *update_dicts: Dictionaries to merge into the base
+
+    Returns:
+        A new dictionary with all values merged
+    """
+    result = deepcopy(base_dict)
+    for update_dict in update_dicts:
+        for k, v in update_dict.items():
+            if isinstance(v, Mapping) and isinstance(result.get(k), Mapping):
+                result[k] = deep_merge(result[k], v)
+            else:
+                result[k] = v
+    return result
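
Note: a quick usage sketch of the new helper; the base is deep-copied first, so neither input is ever mutated:

    a = {"x": {"p": 1, "q": 2}, "y": 1}
    b = {"x": {"q": 3}, "z": 4}
    deep_merge(a, b)  # {"x": {"p": 1, "q": 3}, "y": 1, "z": 4}; a and b are unchanged
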
src/asr/corrector.py
@@ -97,10 +97,6 @@ async def asr_corrector(inputs: str, reference: str | None = None, corrector_mod
        Message(id=rand_number(), chat=Chat(id=rand_number()), text=Str(f"{PREFIX.AI_TEXT_GENERATION} @{corrector_model} {texts}")),
        openai_responses_config={
            "instructions": SYSTEM_PROMPT,
-            "max_output_tokens": 65536,
-            "extra_body": {"thinking": {"type": "enabled"}},
-            "tools": [{"type": "web_search", "max_keyword": 5, "limit": 10}],
-            "max_tool_calls": 10,
            "text": {
                "format": {
                    "type": "json_schema",
@@ -112,9 +108,6 @@ async def asr_corrector(inputs: str, reference: str | None = None, corrector_mod
            },
        },
        gemini_generate_content_config={
-            "max_output_tokens": 65536,
-            "thinking_config": {"include_thoughts": True, "thinking_level": "high"},
-            "tools": [{"google_search": {}}, {"code_execution": {}}],
            "system_instruction": SYSTEM_PROMPT,
            "responseMimeType": "application/json",
            "responseJsonSchema": JSON_SCHEMA,