Commit 7b31448
Changed files (1)
src
ai
texts
src/ai/texts/models.py
@@ -9,6 +9,7 @@ from ai.utils import BOT_TIPS, EMOJI_TEXT_BOT, deep_merge, text_generation_docs
from config import AI, PREFIX
from database.kv import get_cf_kv
from messages.utils import startswith_prefix
+from utils import strings_list
# ruff: noqa: RUF002
@@ -153,10 +154,16 @@ async def get_config_by_model_alias(model_alias: str, *, fallback_to_default: bo
model_config
"""
kv = await get_cf_kv(AI.TEXT_MODEL_CONFIG_KEY, cache_ttl=600, silent=True)
-
- if config := kv.get(model_alias, {}):
- common_config = config.get("common_config", {})
- return [deep_merge(common_config, model_config) for model_config in config.get("models", [])]
+ if model_alias in kv:
+ common = kv[model_alias].get("common_config", {})
+ configs = []
+ for config in kv[model_alias].get("models", []):
+ merged = deep_merge(common, config)
+ shuffle = bool(merged.get("strategy", "fallback") == "load-balance") # load-balance or fallback
+ for model_id in strings_list(merged["model_id"], shuffle=shuffle):
+ merged["model_id"] = model_id
+ configs.append(merged.copy())
+ return configs
if not fallback_to_default:
return []
@@ -203,6 +210,10 @@ async def get_config_by_model_name(model_name: str) -> list[dict]:
common_config = config.get("common_config", {})
for model in config.get("models", []):
model_config = deep_merge(common_config, model)
- if model_config.get("model_name", "") == model_name:
- model_configs.append(model_config)
+ if model_config.get("model_name", "") != model_name:
+ continue
+ shuffle = bool(model_config.get("strategy", "fallback") == "load-balance") # load-balance or fallback
+ for model_id in strings_list(model_config["model_id"], shuffle=shuffle):
+ model_config["model_id"] = model_id
+ model_configs.append(model_config.copy())
return model_configs