Commit 23c1a05
Changed files (7)
src/llm/contexts.py
@@ -70,7 +70,9 @@ async def single_context(client: Client, message: Message) -> dict:
def clean_text(text: str) -> str:
if not text:
return ""
- return re.sub(rf"(.*?){BOT_TIPS}\)", "", text.removeprefix(PREFIX.GPT), flags=re.DOTALL).strip()
+ for prefix in [PREFIX.GPT, "/gpt", "/gemini", "/ds"]:
+ text = text.removeprefix(prefix).strip()
+ return re.sub(rf"(.*?){BOT_TIPS}\)", "", text, flags=re.DOTALL).strip()
info = parse_msg(message, silent=True)
role = "assistant" if f"{BOT_TIPS})" in info["text"] else "user"
src/llm/gpt.py
@@ -10,7 +10,7 @@ from config import DOWNLOAD_DIR, ENABLE, GPT, PREFIX, cache
from llm.contexts import get_conversation_contexts, get_conversations
from llm.models import get_model_config_with_contexts, get_model_type
from llm.response import merge_tools_response, send_to_gpt
-from llm.utils import llm_cleanup_files
+from llm.utils import BOT_TIPS, llm_cleanup_files
from messages.parser import parse_msg
from messages.progress import modify_progress
from messages.sender import send2tg
@@ -18,11 +18,15 @@ from messages.utils import equal_prefix, startswith_prefix
from utils import rand_number, save_txt
HELP = f"""🤖**GPT对话**
-当前模型:
+`{PREFIX.GPT}` 命令当前模型:
- 文本模型: **{GPT.TEXT_MODEL_NAME}**
- 图片模型: **{GPT.IMAGE_MODEL_NAME}**
- 视频模型(暂时禁用): **{GPT.VIDEO_MODEL_NAME}**
+`/gpt` 命令强制使用: **{GPT.OPENAI_MODEL_NAME}**
+`/gemini` 命令强制使用: **{GPT.GEMINI_MODEL_NAME}**
+`/ds` 命令强制使用: **{GPT.DEEPSEEK_MODEL_NAME}**
+
使用说明:
1. 在 `{PREFIX.GPT}` 后接提示词即可与GPT对话
2. 以 `{PREFIX.GPT}` 回复消息可将其加入上下文
@@ -32,7 +36,7 @@ HELP = f"""🤖**GPT对话**
def is_gpt_conversation(message: Message) -> bool:
info = parse_msg(message)
- if startswith_prefix(info["text"], prefix=[PREFIX.GPT]):
+ if startswith_prefix(info["text"], prefix=[PREFIX.GPT, "/gpt", "/gemini", "/ds"]):
return True
# is replying to gpt-bot response message?
if not message.reply_to_message:
@@ -51,17 +55,33 @@ async def gpt_response(client: Client, message: Message, **kwargs):
client (Client): The Pyrogram client.
message (Message): The trigger message object.
"""
+ # ruff: noqa: RET502, RET503
if not ENABLE.GPT:
return
info = parse_msg(message)
# send docs if message == "/ai", without reply
- if equal_prefix(info["text"], prefix=[PREFIX.GPT]) and not message.reply_to_message:
+ if equal_prefix(info["text"], prefix=[PREFIX.GPT, "/gpt", "/gemini", "/ds"]) and not message.reply_to_message:
await send2tg(client, message, texts=HELP, **kwargs)
return
if not is_gpt_conversation(message):
return
+ # /gpt = OpenAI, /gemini = Gemini, /ds = DeepSeek
+ force_model = "N/A"
+ if startswith_prefix(info["text"], prefix=["/gpt"]):
+ force_model = GPT.OPENAI_MODEL
+ if not GPT.OPENAI_API_KEY:
+ return await send2tg(client, message, texts=f"⚠️GPT暂时禁用, 请尝试其他命令\n\n{HELP}", **kwargs)
+ elif startswith_prefix(info["text"], prefix=["/gemini"]):
+ force_model = GPT.GEMINI_MODEL
+ if not GPT.GEMINI_API_KEY:
+ return await send2tg(client, message, texts=f"⚠️Gemini暂时禁用, 请尝试其他命令\n\n{HELP}", **kwargs)
+ elif startswith_prefix(info["text"], prefix=["/ds"]):
+ force_model = GPT.DEEPSEEK_MODEL
+ if not GPT.DEEPSEEK_API_KEY:
+ return await send2tg(client, message, texts=f"⚠️DeepSeek暂时禁用, 请尝试其他命令\n\n{HELP}", **kwargs)
+
# cache media_group message, only process once
if media_group_id := message.media_group_id:
if cache.get(f"gpt-{info['cid']}-{media_group_id}"):
@@ -74,7 +94,7 @@ async def gpt_response(client: Client, message: Message, **kwargs):
await send2tg(client, message, texts=model_type, **kwargs)
return
contexts = await get_conversation_contexts(client, conversations)
- config = get_model_config_with_contexts(model_type, contexts)
+ config = get_model_config_with_contexts(model_type, contexts, force_model)
msg = f"🤖{config['friendly_name']}: 思考中..."
if kwargs.get("show_progress"):
res = await send2tg(client, message, texts=msg, **kwargs)
@@ -86,7 +106,7 @@ async def gpt_response(client: Client, message: Message, **kwargs):
reasoning_model = f"推理模型: {response['reasoning_model']}\n\n" if response.get("reasoning_model") else ""
media = [{"document": save_txt(f"{reasoning_model}{reasoning}", f"{DOWNLOAD_DIR}/GPT-Reasoning-{rand_number()}.txt")}]
if content := response.get("content"):
- texts = f"{response['bot_msg_prefix']}\n\n{content}"
+ texts = f"🤖**{response['model']}**: ({BOT_TIPS})\n\n{content}"
logger.debug(texts)
await send2tg(client, message, texts=texts, media=media, **kwargs)
await modify_progress(del_status=True, **kwargs)
src/llm/models.py
@@ -6,7 +6,7 @@ from openai import DefaultAsyncHttpxClient
from pyrogram.types import Message
from config import GPT, PROXY
-from llm.utils import BOT_TIPS, change_system_prompt
+from llm.utils import change_system_prompt
from messages.parser import parse_msg
@@ -28,7 +28,7 @@ def get_model_type(conversations: list[Message]) -> str:
return model_type
-def get_model_config_with_contexts(model_type: str, contexts: list[dict]) -> dict:
+def get_model_config_with_contexts(model_type: str, contexts: list[dict], force_model: str = "N/A") -> dict:
"""Get GPT model config based on contexts, and return the config and adjusted contexts.
contexts:
@@ -49,6 +49,8 @@ def get_model_config_with_contexts(model_type: str, contexts: list[dict]) -> dic
apis = {"text": GPT.TEXT_API_KEY, "image": GPT.IMAGE_API_KEY, "video": GPT.VIDEO_API_KEY}
urls = {"text": GPT.TEXT_BASE_URL, "image": GPT.IMAGE_BASE_URL, "video": GPT.VIDEO_BASE_URL}
+ model = force_model if force_model != "N/A" else models[model_type]
+ model_name = model_names[model_type]
# setup configs
# params for OpenAI client
client = {
@@ -58,13 +60,28 @@ def get_model_config_with_contexts(model_type: str, contexts: list[dict]) -> dic
"http_client": DefaultAsyncHttpxClient(proxy=PROXY.GPT),
}
+ if force_model == GPT.OPENAI_MODEL:
+ client["api_key"] = GPT.OPENAI_API_KEY
+ client["base_url"] = GPT.OPENAI_BASE_URL
+ model_name = GPT.OPENAI_MODEL_NAME
+ elif force_model == GPT.GEMINI_MODEL:
+ client["api_key"] = GPT.GEMINI_API_KEY
+ client["base_url"] = GPT.GEMINI_BASE_URL
+ model_name = GPT.GEMINI_MODEL_NAME
+ elif force_model == GPT.DEEPSEEK_MODEL:
+ client["api_key"] = GPT.DEEPSEEK_API_KEY
+ client["base_url"] = GPT.DEEPSEEK_BASE_URL
+ model_name = GPT.DEEPSEEK_MODEL_NAME
+
# params for `openai.chat.completions.create()`
- completions = {"model": models[model_type], "messages": contexts, "temperature": float(GPT.TEMPERATURE)}
+ completions = {"model": model, "messages": contexts, "temperature": float(GPT.TEMPERATURE)}
completions = model_hook(completions)
- completions |= openrouter_hook(client["base_url"])
+ completions |= openrouter_hook(client["base_url"]) # must run after `force_model` overrides `base_url` above
+
+ if force_model != "N/A" and completions.get("extra_body"): # remove models fallback
+ completions["extra_body"].pop("models", None) # should be after hooks
return {
- "friendly_name": model_names[model_type],
- "bot_msg_prefix": f"🤖**{model_names[model_type]}**: ({BOT_TIPS})",
+ "friendly_name": model_name,
"client": client,
"completions": completions,
}
@@ -90,7 +107,8 @@ def model_hook(params: dict) -> dict:
# hook for deepseek-r1.
# Ref: https://github.com/deepseek-ai/DeepSeek-R1/tree/97612c28d06139aa25bb8bca5d632e1fccd70ffd?tab=readme-ov-file#usage-recommendations
# Ref: https://linux.do/t/topic/408247
- if "deepseek-r1" in params.get("model", "").lower():
+ model = params.get("model", "").lower()
+ if any(x in model for x in ["deepseek-r1", "think", "o1", "o3"]):
params["messages"] = change_system_prompt(
context=params.get("messages", []),
prompt="In every output, response using the following format:\n<think>\n{reasoning_content}\n</think>\n\n{content}",
src/llm/response.py
@@ -30,7 +30,6 @@ async def merge_tools_response(config: dict, **kwargs) -> dict:
completions |= openrouter_hook(GPT.TOOLS_BASE_URL, for_tools=True)
tools_config = {
"friendly_name": config["friendly_name"],
- "bot_msg_prefix": config["bot_msg_prefix"],
"client": {"base_url": GPT.TOOLS_BASE_URL, "api_key": GPT.TOOLS_API_KEY, "http_client": config["client"]["http_client"]},
"completions": add_tools(completions),
}
@@ -65,7 +64,7 @@ async def send_to_gpt(config: dict, retry: int = 0, **kwargs) -> dict[str, str]:
retry: int, number of retries
Returns:
- {"content": str, "reasoning": str, "reasoning_model": str, "bot_msg_prefix": str}
+ {"content": str, "reasoning": str, "model": str, "reasoning_model": str}
"""
try:
openai = AsyncOpenAI(**config["client"])
@@ -83,7 +82,7 @@ async def send_to_gpt(config: dict, retry: int = 0, **kwargs) -> dict[str, str]:
await modify_progress(text=error, force_update=True, **kwargs)
if retry < GPT.MAX_RETRY:
return await send_to_gpt(config, retry=retry + 1, **kwargs)
- return {"content": "", "reasoning": "", "reasoning_model": "", "bot_msg_prefix": ""}
+ return {"content": "", "reasoning": "", "reasoning_model": ""}
async def parse_error(resp: dict, retry: int, **kwargs) -> dict:
@@ -116,12 +115,12 @@ async def parse_response(config: dict, response: dict) -> dict[str, str]:
"""Parse GPT response.
Returns:
- {"content": str, "reasoning": str, "reasoning_model": str, "bot_msg_prefix": str}
+ {"content": str, "reasoning": str, "model": str, "reasoning_model": str}
"""
logger.debug(response)
choice = glom(response, "choices.0", default={})
if glom(choice, "message.tool_calls.0", default={}): # this is a function call response
- return response | {"content": "", "reasoning": "", "reasoning_model": "", "bot_msg_prefix": config["bot_msg_prefix"]}
+ return response | {"content": "", "reasoning": "", "reasoning_model": ""}
try:
content = glom(choice, "message.content", default="") or ""
reasoning, content = extract_reasoning(content) # extract reasoning from content (<think>...</think>)
@@ -129,13 +128,13 @@ async def parse_response(config: dict, response: dict) -> dict[str, str]:
reasoning = glom(choice, "message.reasoning", default="") or ""
primary_model = glom(config, "completions.model", default="") or ""
used_model = glom(response, "model", default="") or ""
- response = {"content": content.strip(), "reasoning": reasoning.strip(), "reasoning_model": used_model, "bot_msg_prefix": config["bot_msg_prefix"]}
+ response = {"content": content.strip(), "model": config["friendly_name"], "reasoning": reasoning.strip(), "reasoning_model": used_model}
if not (used_model in primary_model or primary_model in used_model):
# do not use `!=` to compare. (deepseek/deepseek-r1:free != deepseek/deepseek-r1, gpt-4o != gpt-4o-2024-07-18)
used_model = beautify_model_name(used_model)
logger.warning(f"Fallback model {primary_model} -> {used_model}")
if ENABLE.GPT_WARN_FALLBACK:
- response["bot_msg_prefix"] = response["bot_msg_prefix"].replace(config["friendly_name"], used_model)
+ response["model"] = used_model
except Exception as e:
logger.error(f"Parse GPT response failed: {e}")
raise
src/llm/utils.py
@@ -84,23 +84,37 @@ def beautify_model_name(name: str) -> str:
Returns:
beautified model name
"""
- # example: openai/o1-preview:online
+ if not name:
+ return name
+ # example: openai/gpt-4o:online
# remove suffix ":"
- name = "".join(name.split(":")[:-1]) # openai/o1-preview
+ parts = name.split(":")
+ if len(parts) > 1:
+ name = "".join(parts[:-1]) # openai/gpt-4o
# remove prefix "/"
- name = name.split("/")[-1] # o1-preview
+ name = name.split("/")[-1] # gpt-4o
# remove "-latest"
name = name.replace("-latest", "")
- return name.replace("gpt", "GPT").replace("deepseek", "DeepSeek").title() # O1-Preview
+ return name.replace("gpt", "GPT").replace("gemini", "Gemini").replace("deepseek", "DeepSeek") # GPT-4o
def extract_reasoning(text: str) -> tuple[str, str]:
- pattern = r"<think>(.*?)</think>"
+ """Extract reasoning from text.
+
+ "<think>
+ {reasoning_content}
+ </think>
+
+ {content}"
+ """
reasoning = ""
- if matched := re.search(pattern, text, re.DOTALL):
+ if matched := re.search(r"<think>(.*?)</think>", text, re.DOTALL):
+ reasoning = matched.group(1)
+ text = re.sub(r"<think>(.*?)</think>", "", text, count=1, flags=re.DOTALL) # remove <think>...</think>
+ if matched := re.search(r"<thinking>(.*?)</thinking>", text, re.DOTALL):
reasoning = matched.group(1)
- text = re.sub(pattern, "", text, count=1, flags=re.DOTALL) # remove <think>...</think>
- return reasoning.strip(), text.strip()
+ text = re.sub(r"<thinking>(.*?)</thinking>", "", text, count=1, flags=re.DOTALL)
+ return reasoning.strip(), text.strip().removeprefix("{content}").strip()
src/config.py
@@ -143,9 +143,9 @@ class GPT: # see `llm/README.md`
# comma separated fallback models for OpenRouter (e.g. openai/gpt-4o,anthropic/claude-3.5-sonnet)
FALLBACK_MODELS = os.getenv("GPT_FALLBACK_MODELS", "")
FALLBACK_TOOLS_MODELS = os.getenv("GPT_FALLBACK_TOOLS_MODELS", "") # comma separated fallback tool models for OpenRouter
- TEXT_MODEL_NAME = os.getenv("GPT_TEXT_MODEL_NAME", "gpt-4o") # custom name
- IMAGE_MODEL_NAME = os.getenv("GPT_IMAGE_MODEL_NAME", "gpt-4o")
- VIDEO_MODEL_NAME = os.getenv("GPT_VIDEO_MODEL_NAME", "glm-4v-plus")
+ TEXT_MODEL_NAME = os.getenv("GPT_TEXT_MODEL_NAME", "GPT-4o") # custom name
+ IMAGE_MODEL_NAME = os.getenv("GPT_IMAGE_MODEL_NAME", "GPT-4o")
+ VIDEO_MODEL_NAME = os.getenv("GPT_VIDEO_MODEL_NAME", "GLM-4V-Plus")
GLM_API_KEY = os.getenv("GPT_GLM_API_KEY", "")
GLM_BASE_URL = os.getenv("GPT_GLM_BASE_URL", "https://open.bigmodel.cn/api/paas/v4")
SEARCH_NUM_RESULTS = os.getenv("GPT_SEARCH_NUM_RESULTS", "5")
@@ -167,6 +167,21 @@ class GPT: # see `llm/README.md`
TOOLS_BASE_URL = os.getenv("GPT_TOOLS_BASE_URL", "https://api.openai.com/v1")
TOKEN_ENCODING = os.getenv("GPT_TOKEN_ENCODING", "o200k_base") # https://github.com/openai/tiktoken
MAX_RETRY = int(os.getenv("GPT_MAX_RETRY", "2"))
+ # /gemini command
+ GEMINI_MODEL = os.getenv("GPT_GEMINI_MODEL", "gemini-2.0-flash")
+ GEMINI_MODEL_NAME = os.getenv("GPT_GEMINI_MODEL_NAME", "Gemini-2.0-Flash")
+ GEMINI_API_KEY = os.getenv("GPT_GEMINI_API_KEY", "")
+ GEMINI_BASE_URL = os.getenv("GPT_GEMINI_BASE_URL", "https://generativelanguage.googleapis.com/v1beta/openai")
+ # /gpt command
+ OPENAI_MODEL = os.getenv("GPT_OPENAI_MODEL", "gpt-4o")
+ OPENAI_MODEL_NAME = os.getenv("GPT_OPENAI_MODEL_NAME", "GPT-4o")
+ OPENAI_API_KEY = os.getenv("GPT_OPENAI_API_KEY", "")
+ OPENAI_BASE_URL = os.getenv("GPT_OPENAI_BASE_URL", "https://api.openai.com/v1")
+ # /ds command
+ DEEPSEEK_MODEL = os.getenv("GPT_DEEPSEEK_MODEL", "deepseek-r1")
+ DEEPSEEK_MODEL_NAME = os.getenv("GPT_DEEPSEEK_MODEL_NAME", "DeepSeek-R1")
+ DEEPSEEK_API_KEY = os.getenv("GPT_DEEPSEEK_API_KEY", "")
+ DEEPSEEK_BASE_URL = os.getenv("GPT_DEEPSEEK_BASE_URL", "https://api.deepseek.com/v1")
class TID:
src/handler.py
@@ -77,7 +77,7 @@ async def handle_utilities(
"""
kwargs |= {"target_chat": target_chat, "reply_msg_id": reply_msg_id, "show_progress": show_progress, "detail_progress": detail_progress}
if ai:
- await gpt_response(client, message, **kwargs) # /ai
+ await gpt_response(client, message, **kwargs) # /ai /gpt /gemini /ds
if asr:
await voice_to_text(client, message, **kwargs) # /asr
if audio:
@@ -275,7 +275,7 @@ def get_social_media_help(prefixes: list[str] | None = None):
if ENABLE.AUDIO:
msg += f"\n🎧**视频转音频**: `{PREFIX.AUDIO}` 回复视频消息"
if ENABLE.GPT:
- msg += f"\n🤖**GPT对话**: `{PREFIX.GPT}` + 提示词"
+ msg += f"\n🤖**GPT对话**: `{PREFIX.GPT} /gpt /gemini /ds` + 提示词"
if ENABLE.SUBTITLE:
msg += f"\n📃**提取字幕**: `{PREFIX.SUBTITLE}` + 油管链接 (或回复油管链接)"
if ENABLE.WGET: