Commit 7b31448

benny-dou <60535774+benny-dou@users.noreply.github.com>
2026-05-11 02:23:06
feat(ai): add load balancing strategy for text generation models
1 parent 7032232
Changed files (1)
src
ai
texts
src/ai/texts/models.py
@@ -9,6 +9,7 @@ from ai.utils import BOT_TIPS, EMOJI_TEXT_BOT, deep_merge, text_generation_docs
 from config import AI, PREFIX
 from database.kv import get_cf_kv
 from messages.utils import startswith_prefix
+from utils import strings_list
 
 
 # ruff: noqa: RUF002
@@ -153,10 +154,16 @@ async def get_config_by_model_alias(model_alias: str, *, fallback_to_default: bo
         model_config
     """
     kv = await get_cf_kv(AI.TEXT_MODEL_CONFIG_KEY, cache_ttl=600, silent=True)
-
-    if config := kv.get(model_alias, {}):
-        common_config = config.get("common_config", {})
-        return [deep_merge(common_config, model_config) for model_config in config.get("models", [])]
+    if model_alias in kv:
+        common = kv[model_alias].get("common_config", {})
+        configs = []
+        for config in kv[model_alias].get("models", []):
+            merged = deep_merge(common, config)
+            shuffle = bool(merged.get("strategy", "fallback") == "load-balance")  # load-balance or fallback
+            for model_id in strings_list(merged["model_id"], shuffle=shuffle):
+                merged["model_id"] = model_id
+                configs.append(merged.copy())
+        return configs
 
     if not fallback_to_default:
         return []
@@ -203,6 +210,10 @@ async def get_config_by_model_name(model_name: str) -> list[dict]:
         common_config = config.get("common_config", {})
         for model in config.get("models", []):
             model_config = deep_merge(common_config, model)
-            if model_config.get("model_name", "") == model_name:
-                model_configs.append(model_config)
+            if model_config.get("model_name", "") != model_name:
+                continue
+            shuffle = bool(model_config.get("strategy", "fallback") == "load-balance")  # load-balance or fallback
+            for model_id in strings_list(model_config["model_id"], shuffle=shuffle):
+                model_config["model_id"] = model_id
+                model_configs.append(model_config.copy())
     return model_configs