Commit 87c7c14
src/llm/gpt.py
@@ -48,6 +48,7 @@ async def gpt_response(
message: Message,
*,
custom_model_id: str = "",
+ custom_model_name: str = "",
enable_tools: bool = True,
**kwargs,
) -> dict:
@@ -60,6 +61,7 @@ async def gpt_response(
client (Client): The Pyrogram client.
message (Message): The trigger message object.
custom_model_id (str, optional): Custom model id.
+ custom_model_name (str, optional): Custom model name.
enable_tools (bool, optional): Whether to enable tools. Defaults to True.
Returns:
@@ -77,9 +79,11 @@ async def gpt_response(
if info["mtype"] == "text" and equal_prefix(info["text"], prefix=PREFIX.GENIMG) and not message.reply_to_message:
await send2tg(client, message, texts=TEXT2IMG_HELP, **kwargs)
return {}
- model_id, resp_modality = get_model_id(info, message)
+ model_id, is_custom_id, resp_modality = get_model_id(info, message)
if not model_id:
return {}
+ if is_custom_id:
+ custom_model_id = model_id
# cache the media_group message so each group is processed only once
if media_group_id := message.media_group_id:
if cache.get(f"gpt-{info['cid']}-{media_group_id}"):
@@ -98,7 +102,7 @@ async def gpt_response(
await send2tg(client, message, texts=f"⚠️不支持自定义模型: {custom_model_id}\n\n⚙️支持自定义模型列表:\n{'\n'.join(allowed_model_ids)}", **kwargs)
return {}
if custom_model_id.lower() in [x.lower() for x in strings_list(GEMINI.ALLOWED_CUSTOM_MODEL_IDS)]:
- return await gemini_chat_completion(client, message, model_id=custom_model_id, model_name=custom_model_id, enable_tools=enable_tools, **kwargs)
+ return await gemini_chat_completion(client, message, model_id=custom_model_id, model_name=custom_model_name or custom_model_id, enable_tools=enable_tools, **kwargs)
if model_id == GEMINI.TEXT_MODEL and not custom_model_id:
return await gemini_chat_completion(client, message, enable_tools=enable_tools, **kwargs)
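Taken together, the gpt.py changes let callers pass a display name separately from the model id, with the name falling back to the id when omitted. A minimal sketch of the intended call, assuming the bot's async handler context; the id and name values below are hypothetical, and the id would need to appear in GEMINI.ALLOWED_CUSTOM_MODEL_IDS:

# Sketch only -- runs inside an async handler that already has `client` and `message`.
await gpt_response(
    client,                                # pyrogram Client
    message,                               # triggering Message
    custom_model_id="gemini-exp-custom",   # hypothetical custom model id
    custom_model_name="My Custom Gemini",  # optional; `custom_model_name or custom_model_id` is used
)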
src/llm/models.py
@@ -1,6 +1,7 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
+import re
from openai import DefaultAsyncHttpxClient
from pyrogram.types import Message
@@ -25,39 +26,39 @@ def get_context_type(conversations: list[Message]) -> str:
return context_type
-def get_model_id(minfo: dict, message: Message) -> tuple[str, str]:
+def get_model_id(minfo: dict, message: Message) -> tuple[str, bool, str]:
"""Get model id with response modality.
Returns:
- (model_id, response_modality)
+ (model_id, is_custom_id, response_modality)
"""
# to avoid a potential infinite loop,
# we do not respond to bot messages & GPT responses.
if minfo["is_bot"]:
- return "", ""
+ return "", False, ""
if BOT_TIPS in minfo["text"]:
- return "", ""
+ return "", False, ""
model_id, response_modality = get_model_id_from_envars(minfo)
if model_id:
- return model_id, response_modality
+ return model_id, False, response_modality
- model_id, response_modality = get_model_id_from_prefix(minfo)
+ model_id, is_custom_id, response_modality = get_model_id_from_prefix(minfo)
if not model_id:
- return "", ""
+ return "", is_custom_id, ""
# early return for non-text generation
if response_modality != "text":
- return model_id, response_modality
+ return model_id, is_custom_id, response_modality
# check if we need to fall back to the omni model
conversations = get_conversations(message)
context_type = get_context_type(conversations) # e.g. "text", "image", "video", "audio", "voice"
if context_type == "text": # no need to fall back if context type is text
- return model_id, response_modality
+ return model_id, is_custom_id, response_modality
if context_type in ["video", "audio", "voice"]: # currently, only Gemini supports audio/video
- return GEMINI.TEXT_MODEL, "text"
+ return GEMINI.TEXT_MODEL, is_custom_id, "text"
if (
(model_id == GPT.OPENAI_MODEL and not GPT.OPENAI_ACCEPT_IMAGE)
@@ -79,9 +80,9 @@ def get_model_id(minfo: dict, message: Message) -> tuple[str, str]:
text_providers, _ = enabled_providers()
# prefer gemini if OMNI_PROVIDER is not set
model_id = omni_providers.get(GPT.OMNI_PROVIDER.lower()) or GEMINI.TEXT_MODEL or omni_providers[text_providers[0]]
- return model_id, "text"
+ return model_id, is_custom_id, "text"
- return model_id, response_modality
+ return model_id, is_custom_id, response_modality
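get_model_id now returns a 3-tuple, and every early return threads is_custom_id through unchanged, so callers can tell built-in models (False, from envars or a /prefix) apart from reply-derived custom ids (True). A sketch of how a caller consumes the new contract:

# Sketch of the new (model_id, is_custom_id, response_modality) contract.
model_id, is_custom_id, resp_modality = get_model_id(info, message)
if not model_id:
    pass  # bot message, BOT_TIPS echo, or no matching prefix: nothing to answer
elif is_custom_id:
    pass  # id was parsed from a reply header and must pass the allow-list check
else:
    pass  # one of the built-in models; resp_modality is "text" or "image"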
def get_model_id_from_envars(minfo: dict) -> tuple[str, str]:
@@ -130,26 +131,30 @@ def get_model_id_from_envars(minfo: dict) -> tuple[str, str]:
return "", ""
-def get_model_id_from_prefix(minfo: dict) -> tuple[str, str]:
+def get_model_id_from_prefix(minfo: dict) -> tuple[str, bool, str]:
text_providers, img_providers = enabled_providers()
+ model_id = ""
+ resp_modality = "text"
+ # message starts with a model /prefix
if startswith_prefix(minfo["text"], prefix="/gpt") and "openai" in text_providers:
- return GPT.OPENAI_MODEL, "text"
- if startswith_prefix(minfo["text"], prefix="/gemini") and "gemini" in text_providers:
- return GEMINI.TEXT_MODEL, "text"
- if startswith_prefix(minfo["text"], prefix="/ds") and "deepseek" in text_providers:
- return GPT.DEEPSEEK_MODEL, "text"
- if startswith_prefix(minfo["text"], prefix="/doubao") and "doubao" in text_providers:
- return GPT.DOUBAO_MODEL, "text"
- if startswith_prefix(minfo["text"], prefix="/qwen") and "qwen" in text_providers:
- return GPT.QWEN_MODEL, "text"
- if startswith_prefix(minfo["text"], prefix="/kimi") and "kimi" in text_providers:
- return GPT.KIMI_MODEL, "text"
- if startswith_prefix(minfo["text"], prefix="/grok") and "grok" in text_providers:
- return GPT.GROK_MODEL, "text"
- if startswith_prefix(minfo["text"], prefix=PREFIX.GENIMG) and "gemini" in img_providers:
- return GEMINI.IMG_MODEL, "image"
+ model_id = GPT.OPENAI_MODEL
+ elif startswith_prefix(minfo["text"], prefix="/gemini") and "gemini" in text_providers:
+ model_id = GEMINI.TEXT_MODEL
+ elif startswith_prefix(minfo["text"], prefix="/ds") and "deepseek" in text_providers:
+ model_id = GPT.DEEPSEEK_MODEL
+ elif startswith_prefix(minfo["text"], prefix="/doubao") and "doubao" in text_providers:
+ model_id = GPT.DOUBAO_MODEL
+ elif startswith_prefix(minfo["text"], prefix="/qwen") and "qwen" in text_providers:
+ model_id = GPT.QWEN_MODEL
+ elif startswith_prefix(minfo["text"], prefix="/kimi") and "kimi" in text_providers:
+ model_id = GPT.KIMI_MODEL
+ elif startswith_prefix(minfo["text"], prefix="/grok") and "grok" in text_providers:
+ model_id = GPT.GROK_MODEL
+ elif startswith_prefix(minfo["text"], prefix=PREFIX.GENIMG) and "gemini" in img_providers:
+ model_id = GEMINI.IMG_MODEL
+ resp_modality = "image"
# message starts with /ai: auto-detect the model_id
- if startswith_prefix(minfo["text"], prefix="/ai") and text_providers:
+ elif startswith_prefix(minfo["text"], prefix="/ai") and text_providers:
providers = {
"openai": GPT.OPENAI_MODEL,
"deepseek": GPT.DEEPSEEK_MODEL,
@@ -161,26 +166,30 @@ def get_model_id_from_prefix(minfo: dict) -> tuple[str, str]:
}
# prefer gemini if DEFAULT_PROVIDER is not set
model_id = providers.get(GPT.DEFAULT_PROVIDER.lower()) or GEMINI.TEXT_MODEL or providers[text_providers[0]]
- return model_id, "text"
+ if model_id:
+ return model_id, False, resp_modality
# message is replying to an AI response
if startswith_prefix(minfo["reply_text"], prefix=f"🤖{GPT.OPENAI_MODEL_NAME}:{BOT_TIPS}") and "openai" in text_providers:
- return GPT.OPENAI_MODEL, "text"
- if startswith_prefix(minfo["reply_text"], prefix=f"🤖{GEMINI.TEXT_MODEL_NAME}:{BOT_TIPS}") and "gemini" in text_providers:
- return GEMINI.TEXT_MODEL, "text"
- if startswith_prefix(minfo["reply_text"], prefix=f"🤖{GPT.DEEPSEEK_MODEL_NAME}:{BOT_TIPS}") and "deepseek" in text_providers:
- return GPT.DEEPSEEK_MODEL, "text"
- if startswith_prefix(minfo["reply_text"], prefix=f"🤖{GPT.DOUBAO_MODEL_NAME}:{BOT_TIPS}") and "doubao" in text_providers:
- return GPT.DOUBAO_MODEL, "text"
- if startswith_prefix(minfo["reply_text"], prefix=f"🤖{GPT.QWEN_MODEL_NAME}:{BOT_TIPS}") and "qwen" in text_providers:
- return GPT.QWEN_MODEL, "text"
- if startswith_prefix(minfo["reply_text"], prefix=f"🤖{GPT.KIMI_MODEL_NAME}:{BOT_TIPS}") and "kimi" in text_providers:
- return GPT.KIMI_MODEL, "text"
- if startswith_prefix(minfo["reply_text"], prefix=f"🤖{GPT.GROK_MODEL_NAME}:{BOT_TIPS}") and "grok" in text_providers:
- return GPT.GROK_MODEL, "text"
- if startswith_prefix(minfo["reply_text"], prefix=f"🤖{GEMINI.IMG_MODEL_NAME}:{BOT_TIPS}") and "gemini" in img_providers:
- return GEMINI.IMG_MODEL, "image"
- return "", ""
+ model_id = GPT.OPENAI_MODEL
+ elif startswith_prefix(minfo["reply_text"], prefix=f"🤖{GEMINI.TEXT_MODEL_NAME}:{BOT_TIPS}") and "gemini" in text_providers:
+ model_id = GEMINI.TEXT_MODEL
+ elif startswith_prefix(minfo["reply_text"], prefix=f"🤖{GPT.DEEPSEEK_MODEL_NAME}:{BOT_TIPS}") and "deepseek" in text_providers:
+ model_id = GPT.DEEPSEEK_MODEL
+ elif startswith_prefix(minfo["reply_text"], prefix=f"🤖{GPT.DOUBAO_MODEL_NAME}:{BOT_TIPS}") and "doubao" in text_providers:
+ model_id = GPT.DOUBAO_MODEL
+ elif startswith_prefix(minfo["reply_text"], prefix=f"🤖{GPT.QWEN_MODEL_NAME}:{BOT_TIPS}") and "qwen" in text_providers:
+ model_id = GPT.QWEN_MODEL
+ elif startswith_prefix(minfo["reply_text"], prefix=f"🤖{GPT.KIMI_MODEL_NAME}:{BOT_TIPS}") and "kimi" in text_providers:
+ model_id = GPT.KIMI_MODEL
+ elif startswith_prefix(minfo["reply_text"], prefix=f"🤖{GPT.GROK_MODEL_NAME}:{BOT_TIPS}") and "grok" in text_providers:
+ model_id = GPT.GROK_MODEL
+ elif startswith_prefix(minfo["reply_text"], prefix=f"🤖{GEMINI.IMG_MODEL_NAME}:{BOT_TIPS}") and "gemini" in img_providers:
+ model_id = GEMINI.IMG_MODEL
+ resp_modality = "image"
+ elif matched := re.match(rf"^🤖(.*?):{BOT_TIPS}", minfo["reply_text"]):
+ return matched.group(1).lower(), True, "text"
+ return model_id, False, resp_modality
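The new trailing elif is what enables reply-driven custom models: a reply header of the form 🤖<model-id>:<BOT_TIPS> that matched none of the known model names is reparsed, and its id is returned lowercased with is_custom_id=True. A standalone check of the pattern, using a placeholder BOT_TIPS value (the real constant is defined by the project):

import re

BOT_TIPS = "reply to me to continue"  # placeholder; the project defines the real BOT_TIPS
reply_text = f"🤖MyFineTune-v2:{BOT_TIPS} ...previous answer..."

if matched := re.match(rf"^🤖(.*?):{BOT_TIPS}", reply_text):
    print(matched.group(1).lower(), True, "text")  # myfinetune-v2 True text

Note that BOT_TIPS is interpolated into the pattern unescaped, so this fallback assumes the constant contains no regex metacharacters; otherwise it would need re.escape().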
def get_gpt_config(model_id: str = "") -> dict: