Commit 10c4aca

benny-dou <60535774+benny-dou@users.noreply.github.com>
2025-08-28 09:22:21
feat(ai): support custom LLM model ids
1 parent 590c6bc
src/llm/gemini/chat.py
@@ -25,6 +25,8 @@ async def gemini_chat_completion(
     client: Client,
     message: Message,
     *,
+    model_id: str = GEMINI.TEXT_MODEL,
+    model_name: str = GEMINI.TEXT_MODEL_NAME,
     enable_tools: bool = True,
     append_grounding: bool = True,
     disable_thinking: bool = False,
@@ -53,8 +55,8 @@ async def gemini_chat_completion(
         extra_config_str = GEMINI.TEXT_CONFIG
         genconfig = json.loads(extra_config_str)
     try:
-        real_prompt = clean_cmd_prefix(info["text"]) or clean_cmd_prefix(info["reply_text"])
-        msg = f"🤖**{GEMINI.TEXT_MODEL_NAME}**: 思考中...\n👤**[{info['full_name'] or info['ctitle']}](tg://user?id={info['uid']})**: “{real_prompt}”"[:TEXT_LENGTH]
+        real_prompt = clean_cmd_prefix(info["text"], model_id) or clean_cmd_prefix(info["reply_text"], model_id)
+        msg = f"🤖**{model_name}**: 思考中...\n👤**[{info['full_name'] or info['ctitle']}](tg://user?id={info['uid']})**: “{real_prompt}”"[:TEXT_LENGTH]
         if not silent and kwargs.get("show_progress"):
             kwargs["progress"] = (await send2tg(client, message, texts=msg, **kwargs))[0]
         genconfig |= {"response_modalities": ["TEXT"]}
@@ -68,9 +70,9 @@ async def gemini_chat_completion(
         if GEMINI.TEXT_THINKING_BUDGET is not None and not disable_thinking:
             thinking_budget = min(round(float(GEMINI.TEXT_THINKING_BUDGET)), GEMINI.MAX_THINKING_BUDGET)
             genconfig |= {"thinking_config": types.ThinkingConfig(include_thoughts=include_thoughts, thinking_budget=thinking_budget)}
-        params = {"model": GEMINI.TEXT_MODEL, "conversations": get_conversations(message), "config": types.GenerateContentConfig(**genconfig)}
+        params = {"model": model_id, "conversations": get_conversations(message), "config": types.GenerateContentConfig(**genconfig)}
         logger.trace(params)
-        return await gemini_stream(client, message, GEMINI.TEXT_MODEL_NAME, params, append_grounding=append_grounding, silent=silent, **kwargs)
+        return await gemini_stream(client, message, model_name, params, append_grounding=append_grounding, silent=silent, **kwargs)
     except Exception as e:
         logger.error(e)
     return {}
@@ -126,7 +128,7 @@ async def gemini_stream(
         app = genai.Client(api_key=api_key, http_options=http_options)
         # Construct the request params
         if "conversations" in params:  # convert conversations to contents
-            params["contents"] = await get_conversation_contexts(client, params["conversations"], ctx_format="gemini", app=app)
+            params["contents"] = await get_conversation_contexts(client, params["conversations"], model_id=params["model"], ctx_format="gemini", app=app)
         gemini_logging(params["contents"])
         tokens = await app.aio.models.count_tokens(model=params["model"], contents=params["contents"])  # type: ignore
         num_tokens = tokens.total_tokens or 0
src/llm/gemini/text2img.py
@@ -89,7 +89,7 @@ async def gemini_non_stream(
         app = genai.Client(api_key=api_key, http_options=http_options)
         # Construct the request params
         if "conversations" in params:  # convert conversations to contents
-            params["contents"] = await get_conversation_contexts(client, params.pop("conversations"), ctx_format="gemini", app=app)
+            params["contents"] = await get_conversation_contexts(client, params.pop("conversations"), model_id=params["model"], ctx_format="gemini", app=app)
         clean_gemini_sourcemarks(params["contents"])
         genai_params = {"model": params["model"], "contents": params["contents"], "config": params["config"]}
         response = await app.aio.models.generate_content(**genai_params)
src/llm/contexts.py
@@ -36,6 +36,7 @@ def get_conversations(message: Message) -> list[Message]:
 async def get_conversation_contexts(
     client: Client,
     conversations: list[Message],
+    model_id: str = "",
     ctx_format: str = "openai",
     app: genai.Client | AsyncOpenAI | None = None,
 ) -> list[dict]:
@@ -45,16 +46,16 @@ async def get_conversation_contexts(
     """
     # parse context for each message
     if ctx_format.lower() == "openai":
-        contexts = [await single_gpt_context(client, message) for message in conversations]
+        contexts = [await single_gpt_context(client, message, model_id) for message in conversations]
         contexts = [x for x in contexts if x.get("content")]
     else:
-        contexts = [await single_gemini_context(client, message, app) for message in conversations]  # type: ignore
+        contexts = [await single_gemini_context(client, message, app, model_id) for message in conversations]  # type: ignore
         contexts = [x for x in contexts if x.get("parts")]
 
     return contexts[: int(GPT.HISTORY_CONTEXT)]
 
 
-async def single_gpt_context(client: Client, message: Message) -> dict:
+async def single_gpt_context(client: Client, message: Message, model_id: str = "") -> dict:
     """Generate GPT contexts for a single message (Without considering reply message).
 
     Returns:
@@ -106,7 +107,7 @@ async def single_gpt_context(client: Client, message: Message) -> dict:
                         }
                     )
             # user message has entity urls, use full html
-            clean_texts = clean_context(info["html"]) if role == "user" and info["entity_urls"] else clean_context(info["text"])
+            clean_texts = clean_context(info["html"], model_id) if role == "user" and info["entity_urls"] else clean_context(info["text"], model_id)
             if not clean_texts:
                 continue
             texts = f"[username]: {sender}\n[message]:\n{clean_texts}" if role == "user" and sender else clean_texts
@@ -117,7 +118,7 @@ async def single_gpt_context(client: Client, message: Message) -> dict:
     return {"role": role, "content": contexts} if contexts else {}
 
 
-async def single_gemini_context(client: Client, message: Message, app: genai.Client) -> dict:
+async def single_gemini_context(client: Client, message: Message, app: genai.Client, model_id: str = "") -> dict:
     """Generate Gemini contexts for a single message (Without considering reply message).
 
     Returns:
@@ -170,7 +171,7 @@ async def single_gemini_context(client: Client, message: Message, app: genai.Cli
                     Path(fpath).unlink(missing_ok=True)
                     parts.append(Part.from_text(text=f"[filename]: {info['file_name']}\n[file content]:\n{text.strip()}"))
             # user message has entity urls, use full html
-            clean_texts = clean_context(info["html"]) if role == "user" and info["entity_urls"] else clean_context(info["text"])
+            clean_texts = clean_context(info["html"], model_id) if role == "user" and info["entity_urls"] else clean_context(info["text"], model_id)
             if not clean_texts:
                 continue
             texts = f"[username]: {sender}\n[message]:\n{clean_texts}" if role == "user" and sender else clean_texts
src/llm/gpt.py
@@ -1,6 +1,9 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
 
+import re
+
+from loguru import logger
 from pyrogram.client import Client
 from pyrogram.types import Message
 
@@ -16,6 +19,7 @@ from messages.parser import parse_msg
 from messages.progress import modify_progress
 from messages.sender import send2tg
 from messages.utils import count_without_entities, equal_prefix
+from utils import strings_list
 
 HELP = f"""🤖**GPT对话**
 `{PREFIX.GPT}` 后接提示词即可与GPT对话
@@ -43,16 +47,20 @@ async def gpt_response(
     client: Client,
     message: Message,
     *,
+    custom_model_id: str = "",
     enable_tools: bool = True,
     **kwargs,
 ) -> dict:
     """Get GPT response from Various API.
 
+    `/ai text`: get response from LLM
+    `/ai @gemini-2.5-flash text`: get response from gemini-2.5-flash (custom model id)
+
     Args:
         client (Client): The Pyrogram client.
         message (Message): The trigger message object.
-        gpt_stream (bool): Whether to use stream mode.
-        enable_tools (bool): use tools.
+        custom_model_id (str, optional): Model id overriding the default model, e.g. "gemini-2.5-flash". Defaults to "".
+        enable_tools (bool, optional): Whether to enable tools. Defaults to True.
 
     Returns:
         dict: {"texts": str, "thoughts": str, "prefix": str, "model_name": str, "sent_messages": list[Message]}
@@ -80,14 +88,28 @@ async def gpt_response(
     kwargs["message_info"] = info  # save trigger message info
     if resp_modality == "image":
         return await text2img(client, message, enable_tools=enable_tools, **kwargs)
-    if model_id == GEMINI.TEXT_MODEL:
+
+    # handle custom model_id here
+    if matched := re.match(r"^/ai @([a-zA-Z0-9_\-\.]+)(\s+)?", info["text"]):  # match /ai @custom_model_id
+        custom_model_id = matched.group(1).strip()
+        logger.warning(f"Custom model id: {custom_model_id}")
+    allowed_model_ids = [x.lower() for x in strings_list(GEMINI.ALLOWED_CUSTOM_MODEL_IDS) + strings_list(GPT.ALLOWED_CUSTOM_MODEL_IDS)]
+    if custom_model_id and custom_model_id.lower() not in allowed_model_ids:
+        await send2tg(client, message, texts=f"⚠️不支持自定义模型: {custom_model_id}\n\n⚙️支持自定义模型列表:\n{'\n'.join(allowed_model_ids)}", **kwargs)
+        return {}
+    if custom_model_id.lower() in [x.lower() for x in strings_list(GEMINI.ALLOWED_CUSTOM_MODEL_IDS)]:
+        return await gemini_chat_completion(client, message, model_id=custom_model_id, model_name=custom_model_id, enable_tools=enable_tools, **kwargs)
+    if model_id == GEMINI.TEXT_MODEL and not custom_model_id:
         return await gemini_chat_completion(client, message, enable_tools=enable_tools, **kwargs)
 
     # GPT models
+    if custom_model_id:
+        model_id = custom_model_id
     config = get_gpt_config(model_id)
+    config["friendly_name"] = custom_model_id or config["friendly_name"]
     conversations = get_conversations(message)
-    config["completions"]["messages"] = await get_conversation_contexts(client, conversations, ctx_format="openai")
-    real_prompt = clean_cmd_prefix(info["text"]) or clean_cmd_prefix(info["reply_text"])
+    config["completions"]["messages"] = await get_conversation_contexts(client, conversations, model_id=model_id, ctx_format="openai")
+    real_prompt = clean_cmd_prefix(info["text"], model_id) or clean_cmd_prefix(info["reply_text"], model_id)
     msg = f"🤖**{config['friendly_name']}**: 思考中...\n👤**[{info['full_name'] or info['ctitle']}](tg://user?id={info['uid']})**: “{real_prompt}”"[:TEXT_LENGTH]
     status_msg = (await send2tg(client, message, texts=msg, **kwargs))[0]
     kwargs["progress"] = status_msg
src/llm/models.py
@@ -186,19 +186,24 @@ def get_model_id_from_prefix(minfo: dict) -> tuple[str, str]:
 def get_gpt_config(model_id: str = "") -> dict:
     """Get GPT configurations."""
     model_factory = {
-        GPT.OPENAI_MODEL: {"api_key": sample_key(GPT.OPENAI_API_KEY), "base_url": GPT.OPENAI_BASE_URL, "model_name": GPT.OPENAI_MODEL_NAME},
-        GPT.DEEPSEEK_MODEL: {"api_key": sample_key(GPT.DEEPSEEK_API_KEY), "base_url": GPT.DEEPSEEK_BASE_URL, "model_name": GPT.DEEPSEEK_MODEL_NAME},
-        GPT.QWEN_MODEL: {"api_key": sample_key(GPT.QWEN_API_KEY), "base_url": GPT.QWEN_BASE_URL, "model_name": GPT.QWEN_MODEL_NAME},
-        GPT.DOUBAO_MODEL: {"api_key": sample_key(GPT.DOUBAO_API_KEY), "base_url": GPT.DOUBAO_BASE_URL, "model_name": GPT.DOUBAO_MODEL_NAME},
-        GPT.GROK_MODEL: {"api_key": sample_key(GPT.GROK_API_KEY), "base_url": GPT.GROK_BASE_URL, "model_name": GPT.GROK_MODEL_NAME},
-        GPT.KIMI_MODEL: {"api_key": sample_key(GPT.KIMI_API_KEY), "base_url": GPT.KIMI_BASE_URL, "model_name": GPT.KIMI_MODEL_NAME},
+        "gpt,chatgpt,o1,o3,o4": {"api_key": sample_key(GPT.OPENAI_API_KEY), "base_url": GPT.OPENAI_BASE_URL, "model_name": GPT.OPENAI_MODEL_NAME},
+        "deepseek": {"api_key": sample_key(GPT.DEEPSEEK_API_KEY), "base_url": GPT.DEEPSEEK_BASE_URL, "model_name": GPT.DEEPSEEK_MODEL_NAME},
+        "qwen,qvq,qwq": {"api_key": sample_key(GPT.QWEN_API_KEY), "base_url": GPT.QWEN_BASE_URL, "model_name": GPT.QWEN_MODEL_NAME},
+        "doubao": {"api_key": sample_key(GPT.DOUBAO_API_KEY), "base_url": GPT.DOUBAO_BASE_URL, "model_name": GPT.DOUBAO_MODEL_NAME},
+        "grok": {"api_key": sample_key(GPT.GROK_API_KEY), "base_url": GPT.GROK_BASE_URL, "model_name": GPT.GROK_MODEL_NAME},
+        "kimi": {"api_key": sample_key(GPT.KIMI_API_KEY), "base_url": GPT.KIMI_BASE_URL, "model_name": GPT.KIMI_MODEL_NAME},
     }
 
     client = {"http_client": DefaultAsyncHttpxClient(proxy=PROXY.GPT)}
     if GPT.TIMEOUT is not None:
         client |= {"timeout": int(GPT.TIMEOUT)}
 
-    model_id_config = model_factory.get(model_id, {})
+    model_id_config = {}
+    for prefix, config in model_factory.items():
+        if startswith_prefix(model_id, prefix):
+            model_id_config = config
+            break
+
     model_name = model_id_config.get("model_name", "")
     model_id_config.pop("model_name", None)
     client |= model_id_config
src/llm/summary.py
@@ -176,15 +176,11 @@ async def ai_summary(client: Client, message: Message, summary_prefix: str | Non
         return
     await modify_progress(text=f"🤖AI总结中...\n{msg}", force_update=True, **kwargs)
     # Construct a message to call GPT
-    ai_msg = Message(
-        id=rand_number(),
-        chat=message.chat,
-        text=Str(GPT.SUMMARY_CMD),
-        reply_to_message=Message(id=rand_number(), chat=message.chat, text=Str(parsed["history"])),
-    )
+    ai_msg = Message(id=0, chat=message.chat, text=Str(f"/ai {parsed['history']}"))
     response = await gpt_response(
         client,
         ai_msg,
+        custom_model_id=GPT.CHAT_SUMMARY_MODEL_ID,
         system_prompt=SYSTEM_PROMPT,
         enable_tools=False,
         include_thoughts=False,
src/llm/utils.py
@@ -209,9 +209,11 @@ def image_emoji(capability: bool) -> str:  # noqa: FBT001
     return "🏞" if capability else ""
 
 
-def clean_cmd_prefix(text: str) -> str:
+def clean_cmd_prefix(text: str, model_id: str = "") -> str:
     for prefix in [*strings_list(PREFIX.GPT), PREFIX.GENIMG]:
         text = text.removeprefix(prefix).lstrip()
+    if model_id:
+        text = text.removeprefix(f"@{model_id}").lstrip()
     return text
 
 
@@ -229,11 +231,11 @@ def clean_reasoning(text: str) -> str:
     return text.removeprefix(BLOCKQUOTE_EXPANDABLE_END_DELIM).lstrip()
 
 
-def clean_context(text: str) -> str:
+def clean_context(text: str, model_id: str = "") -> str:
     """Remove bot prefix and reasoning content."""
     text = re.sub(r"^👤@.*?\/\/", "", text)  # remove markdown send_from_user
     text = re.sub(r"^👤\<a.*?tg://user\?id=\d+.*?@.*?</a>//", "", text)  # remove html send_from_user
-    text = clean_cmd_prefix(text)
+    text = clean_cmd_prefix(text, model_id)
     text = clean_bot_tips(text)
     return clean_reasoning(text)
 
src/others/podcast.py
@@ -19,7 +19,7 @@ from pyrogram.types import Chat, Message
 from pyrogram.types.messages_and_media.message import Str
 
 from asr.voice_recognition import asr_file
-from config import DB, DOWNLOAD_DIR, PODCAST, PREFIX, READING_SPEED, TZ, cache
+from config import DB, DOWNLOAD_DIR, GPT, PODCAST, READING_SPEED, TZ, cache
 from database.alist import upload_alist
 from database.r2 import get_cf_r2, set_cf_r2
 from llm.gpt import gpt_response
@@ -100,16 +100,15 @@ async def summary_pods(client: Client):
                 prompt = f"这是播客栏目《{feed_title}》的一期节目详情:\n节目标题: {entry['title']}\n节目播出日期: {pubdate}"
                 prompt += f"\n节目时长: {readable_time(entry['itunes_duration'])}\n节目简介: {desc}"
                 prompt += "\n请解读该播客内容, 只需关注内容本身, 不用概述播客的基本信息, 例如播客的标题, 日期, 时长等"
-                ai_cmd = next((x.strip() for x in PREFIX.GPT.split(",") if x.strip()), "")
                 # Construct a message to call GPT
                 cache.delete(f"parse_msg-{txt_msg.chat.id}-{txt_msg.id}")
                 ai_msg = Message(
                     id=txt_msg.id,
                     chat=txt_msg.chat,
-                    text=Str(f"{ai_cmd} {remove_img(prompt)}"),
+                    text=Str(f"/ai {remove_img(prompt)}"),
                     reply_to_message=Message(id=rand_number(), chat=message.chat, text=Str(subtitles)),
                 )
-                gpt_res = await gpt_response(client, ai_msg, include_thoughts=False, append_grounding=False, show_progress=True)
+                gpt_res = await gpt_response(client, ai_msg, custom_model_id=GPT.PODCAST_SUMMARY_MODEL_ID, include_thoughts=False, append_grounding=False, show_progress=True)
                 cache.delete(f"parse_msg-{txt_msg.chat.id}-{txt_msg.id}")
                 feed_item = match_item(feed_xml, entry)
                 update_item(saved_xml, feed_item, prefix_desc=gpt_res.get("texts", ""))
src/subtitles/subtitle.py
@@ -11,7 +11,7 @@ from pyrogram.types import Message
 from pyrogram.types.messages_and_media.message import Str
 
 from asr.voice_recognition import asr_file
-from config import ASR, DOWNLOAD_DIR, PREFIX, READING_SPEED, TEXT_LENGTH, cache
+from config import ASR, DOWNLOAD_DIR, GPT, PREFIX, READING_SPEED, TEXT_LENGTH, cache
 from llm.gpt import gpt_response
 from messages.parser import parse_msg
 from messages.progress import modify_progress
@@ -119,16 +119,17 @@ async def get_subtitle(client: Client, message: Message, *, to_telegraph: bool =
         if description.strip():
             prompt += f"节目简介: {description}\n"
         prompt += "\n请解读本期节目内容。要求: 直接输出节目内容解读, 以“该节目讲述了”开头"
-        ai_cmd = next((x.strip() for x in PREFIX.GPT.split(",") if x.strip()), "")
         # Construct a message to call GPT
         ai_msg = Message(
             id=subtitle_msg.id,
             chat=subtitle_msg.chat,
-            text=Str(f"{ai_cmd} {prompt}"),
+            text=Str(f"/ai {prompt}"),
             reply_to_message=Message(id=rand_number(), chat=subtitle_msg.chat, text=Str(subtitles)),
         )
-        kwargs["include_thoughts"] = False
-        await gpt_response(client, ai_msg, **kwargs)
+        kwargs |= {"include_thoughts": False, "append_grounding": False, "silent": True, "custom_model_id": GPT.SUBTITLE_SUMMARY_MODEL_ID}
+        res = await gpt_response(client, ai_msg, **kwargs)
+        if res.get("texts"):
+            await send2tg(client, ai_msg, texts=res["prefix"] + res["texts"], **kwargs)
     with contextlib.suppress(Exception):
         [await delete_message(msg) for msg in res.get("sent_messages", [])]
         await delete_message(kwargs.get("progress"))
src/config.py
@@ -360,6 +360,7 @@ class GPT:
     MAX_RETRY = int(os.getenv("GPT_MAX_RETRY", "2"))
     HELICONE_API_KEY = os.getenv("HELICONE_API_KEY", "")  # https://docs.helicone.ai/getting-started/integration-method/gateway
     COLLAPSE_LENGTH = int(os.getenv("GPT_COLLAPSE_LENGTH", "500"))  # Collapse the response if the length is larger than this value
+    ALLOWED_CUSTOM_MODEL_IDS = os.getenv("GPT_ALLOWED_CUSTOM_MODEL_IDS", "")  # comma separated OpenAI compatible model ids
     # comma separated fallback models for OpenRouter (e.g. openai/gpt-4o,anthropic/claude-3.5-sonnet)
     OPENROUTER_FALLBACK_MODELS = os.getenv("GPT_OPENROUTER_FALLBACK_MODELS", "")
 
@@ -408,10 +409,12 @@ class GPT:
     KIMI_BASE_URL = os.getenv("GPT_KIMI_BASE_URL", "https://api.moonshot.ai/v1")
     KIMI_ACCEPT_IMAGE = os.getenv("GPT_KIMI_ACCEPT_IMAGE", "1").lower() in ["1", "y", "yes", "t", "true", "on"]
 
-    # AI summary (/summary)
-    SUMMARY_CMD = os.getenv("GPT_SUMMARY_CMD", "/gemini")  # add this command prefix to call AI summary
+    # AI summary
     # comma separated chat ids that are allowed to use `cid` as the chatid for the summary
     SUMMARY_WHITELIST_CUSTOM_CHATS = os.getenv("GPT_SUMMARY_WHITELIST_CUSTOM_CHATS", "")
+    CHAT_SUMMARY_MODEL_ID = os.getenv("CHAT_SUMMARY_MODEL_ID", "")  # Specify the model id for `/summary` command (If not set, use the default model)
+    PODCAST_SUMMARY_MODEL_ID = os.getenv("PODCAST_SUMMARY_MODEL_ID", "")  # for generating podcast summary (If not set, use the default AI model)
+    SUBTITLE_SUMMARY_MODEL_ID = os.getenv("SUBTITLE_SUMMARY_MODEL_ID", "")  # for generating subtitle summary (If not set, use the default AI model)
     # For tool_call. Some models doesn't support tool call, so we use this model to do the tool_call first.
     # Then construct the new questions for the original model.
     TOOLS_MODEL = os.getenv("GPT_TOOLS_MODEL", "gpt-4o-mini")  # this model should be fast and cheap
@@ -427,6 +430,7 @@ class GEMINI:  # Official Gemini
     PREFER_LANG = os.getenv("GEMINI_PREFER_LANG", "")  # Set a prefer response language for Gemini
     MAX_THINKING_BUDGET = int(os.getenv("GEMINI_MAX_THINKING_BUDGET", "24576"))  # 24K
     CLEAN_FILES_AFTER_SECONDS = int(os.getenv("GEMINI_CLEAN_FILES_AFTER_SECONDS", "172800"))  # default to 48 hours
+    ALLOWED_CUSTOM_MODEL_IDS = os.getenv("GEMINI_ALLOWED_CUSTOM_MODEL_IDS", "")  # comma separated Gemini model ids allowed as `/ai @model` overrides
 
     # response modality: text
     TEXT_MODEL = os.getenv("GEMINI_TEXT_MODEL", "gemini-2.5-pro")