Commit 23c1a05

benny-dou <60535774+benny-dou@users.noreply.github.com>
2025-02-12 16:41:35
feat(gpt): add `/gpt, /gemini, /ds` commands
1 parent 51a1718
src/llm/contexts.py
@@ -70,7 +70,9 @@ async def single_context(client: Client, message: Message) -> dict:
     def clean_text(text: str) -> str:
         if not text:
             return ""
-        return re.sub(rf"(.*?){BOT_TIPS}\)", "", text.removeprefix(PREFIX.GPT), flags=re.DOTALL).strip()
+        for prefix in [PREFIX.GPT, "/gpt", "/gemini", "/ds"]:
+            text = text.removeprefix(prefix).strip()
+        return re.sub(rf"(.*?){BOT_TIPS}\)", "", text, flags=re.DOTALL).strip()
 
     info = parse_msg(message, silent=True)
     role = "assistant" if f"{BOT_TIPS})" in info["text"] else "user"
src/llm/gpt.py
@@ -10,7 +10,7 @@ from config import DOWNLOAD_DIR, ENABLE, GPT, PREFIX, cache
 from llm.contexts import get_conversation_contexts, get_conversations
 from llm.models import get_model_config_with_contexts, get_model_type
 from llm.response import merge_tools_response, send_to_gpt
-from llm.utils import llm_cleanup_files
+from llm.utils import BOT_TIPS, llm_cleanup_files
 from messages.parser import parse_msg
 from messages.progress import modify_progress
 from messages.sender import send2tg
@@ -18,11 +18,15 @@ from messages.utils import equal_prefix, startswith_prefix
 from utils import rand_number, save_txt
 
 HELP = f"""🤖**GPT对话**
-当前模型:
+`{PREFIX.GPT}` 命令当前模型:
 - 文本模型: **{GPT.TEXT_MODEL_NAME}**
 - 图片模型: **{GPT.IMAGE_MODEL_NAME}**
 - 视频模型(暂时禁用): **{GPT.VIDEO_MODEL_NAME}**
 
+`/gpt` 命令强制使用: **{GPT.OPENAI_MODEL_NAME}**
+`/gemini` 命令强制使用: **{GPT.GEMINI_MODEL_NAME}**
+`/ds` 命令强制使用: **{GPT.DEEPSEEK_MODEL_NAME}**
+
 使用说明:
 1. 在 `{PREFIX.GPT}` 后接提示词即可与GPT对话
 2. 以 `{PREFIX.GPT}` 回复消息可将其加入上下文
@@ -32,7 +36,7 @@ HELP = f"""🤖**GPT对话**
 
 def is_gpt_conversation(message: Message) -> bool:
     info = parse_msg(message)
-    if startswith_prefix(info["text"], prefix=[PREFIX.GPT]):
+    if startswith_prefix(info["text"], prefix=[PREFIX.GPT, "/gpt", "/gemini", "/ds"]):
         return True
     # is replying to gpt-bot response message?
     if not message.reply_to_message:
@@ -51,17 +55,33 @@ async def gpt_response(client: Client, message: Message, **kwargs):
         client (Client): The Pyrogram client.
         message (Message): The trigger message object.
     """
+    # ruff: noqa: RET502, RET503
     if not ENABLE.GPT:
         return
     info = parse_msg(message)
     # send docs if message == "/ai", without reply
-    if equal_prefix(info["text"], prefix=[PREFIX.GPT]) and not message.reply_to_message:
+    if equal_prefix(info["text"], prefix=[PREFIX.GPT, "/gpt", "/gemini", "/ds"]) and not message.reply_to_message:
         await send2tg(client, message, texts=HELP, **kwargs)
         return
 
     if not is_gpt_conversation(message):
         return
 
+    # /gpt = OpenAI, /gemini = Gemini, /ds = DeepSeek
+    force_model = "N/A"
+    if startswith_prefix(info["text"], prefix=["/gpt"]):
+        force_model = GPT.OPENAI_MODEL
+        if not GPT.OPENAI_API_KEY:
+            return await send2tg(client, message, texts=f"⚠️GPT暂时禁用, 请尝试其他命令\n\n{HELP}", **kwargs)
+    elif startswith_prefix(info["text"], prefix=["/gemini"]):
+        force_model = GPT.GEMINI_MODEL
+        if not GPT.GEMINI_API_KEY:
+            return await send2tg(client, message, texts=f"⚠️Gemini暂时禁用, 请尝试其他命令\n\n{HELP}", **kwargs)
+    elif startswith_prefix(info["text"], prefix=["/ds"]):
+        force_model = GPT.DEEPSEEK_MODEL
+        if not GPT.DEEPSEEK_API_KEY:
+            return await send2tg(client, message, texts=f"⚠️DeepSeek暂时禁用, 请尝试其他命令\n\n{HELP}", **kwargs)
+
     # cache media_group message, only process once
     if media_group_id := message.media_group_id:
         if cache.get(f"gpt-{info['cid']}-{media_group_id}"):
@@ -74,7 +94,7 @@ async def gpt_response(client: Client, message: Message, **kwargs):
         await send2tg(client, message, texts=model_type, **kwargs)
         return
     contexts = await get_conversation_contexts(client, conversations)
-    config = get_model_config_with_contexts(model_type, contexts)
+    config = get_model_config_with_contexts(model_type, contexts, force_model)
     msg = f"🤖{config['friendly_name']}: 思考中..."
     if kwargs.get("show_progress"):
         res = await send2tg(client, message, texts=msg, **kwargs)
@@ -86,7 +106,7 @@ async def gpt_response(client: Client, message: Message, **kwargs):
         reasoning_model = f"推理模型: {response['reasoning_model']}\n\n" if response.get("reasoning_model") else ""
         media = [{"document": save_txt(f"{reasoning_model}{reasoning}", f"{DOWNLOAD_DIR}/GPT-Reasoning-{rand_number()}.txt")}]
     if content := response.get("content"):
-        texts = f"{response['bot_msg_prefix']}\n\n{content}"
+        texts = f"🤖**{response['model']}**: ({BOT_TIPS})\n\n{content}"
         logger.debug(texts)
         await send2tg(client, message, texts=texts, media=media, **kwargs)
         await modify_progress(del_status=True, **kwargs)
src/llm/models.py
@@ -6,7 +6,7 @@ from openai import DefaultAsyncHttpxClient
 from pyrogram.types import Message
 
 from config import GPT, PROXY
-from llm.utils import BOT_TIPS, change_system_prompt
+from llm.utils import change_system_prompt
 from messages.parser import parse_msg
 
 
@@ -28,7 +28,7 @@ def get_model_type(conversations: list[Message]) -> str:
     return model_type
 
 
-def get_model_config_with_contexts(model_type: str, contexts: list[dict]) -> dict:
+def get_model_config_with_contexts(model_type: str, contexts: list[dict], force_model: str = "N/A") -> dict:
     """Get GPT model config based on contexts, and return the config and adjusted contexts.
 
     contexts:
@@ -49,6 +49,8 @@ def get_model_config_with_contexts(model_type: str, contexts: list[dict]) -> dic
     apis = {"text": GPT.TEXT_API_KEY, "image": GPT.IMAGE_API_KEY, "video": GPT.VIDEO_API_KEY}
     urls = {"text": GPT.TEXT_BASE_URL, "image": GPT.IMAGE_BASE_URL, "video": GPT.VIDEO_BASE_URL}
 
+    model = force_model if force_model != "N/A" else models[model_type]
+    model_name = model_names[model_type]
     # setup configs
     # params for OpenAI client
     client = {
@@ -58,13 +60,28 @@ def get_model_config_with_contexts(model_type: str, contexts: list[dict]) -> dic
         "http_client": DefaultAsyncHttpxClient(proxy=PROXY.GPT),
     }
 
+    if force_model == GPT.OPENAI_MODEL:
+        client["api_key"] = GPT.OPENAI_API_KEY
+        client["base_url"] = GPT.OPENAI_BASE_URL
+        model_name = GPT.OPENAI_MODEL_NAME
+    elif force_model == GPT.GEMINI_MODEL:
+        client["api_key"] = GPT.GEMINI_API_KEY
+        client["base_url"] = GPT.GEMINI_BASE_URL
+        model_name = GPT.GEMINI_MODEL_NAME
+    elif force_model == GPT.DEEPSEEK_MODEL:
+        client["api_key"] = GPT.DEEPSEEK_API_KEY
+        client["base_url"] = GPT.DEEPSEEK_BASE_URL
+        model_name = GPT.DEEPSEEK_MODEL_NAME
+
     # params for `openai.chat.completions.create()`
-    completions = {"model": models[model_type], "messages": contexts, "temperature": float(GPT.TEMPERATURE)}
+    completions = {"model": model, "messages": contexts, "temperature": float(GPT.TEMPERATURE)}
     completions = model_hook(completions)
-    completions |= openrouter_hook(client["base_url"])
+    completions |= openrouter_hook(client["base_url"])  # must run after the `force_model` overrides above, since it reads client["base_url"]
+
+    if force_model != "N/A" and completions.get("extra_body"):  # remove models fallback
+        completions["extra_body"].pop("models", None)  # must run after the hooks above, which may set this key
     return {
-        "friendly_name": model_names[model_type],
-        "bot_msg_prefix": f"🤖**{model_names[model_type]}**: ({BOT_TIPS})",
+        "friendly_name": model_name,
         "client": client,
         "completions": completions,
     }
@@ -90,7 +107,8 @@ def model_hook(params: dict) -> dict:
     # hook for deepseek-r1.
     # Ref: https://github.com/deepseek-ai/DeepSeek-R1/tree/97612c28d06139aa25bb8bca5d632e1fccd70ffd?tab=readme-ov-file#usage-recommendations
     # Ref: https://linux.do/t/topic/408247
-    if "deepseek-r1" in params.get("model", "").lower():
+    model = params.get("model", "").lower()
+    if any(x in model for x in ["deepseek-r1", "think", "o1", "o3"]):
         params["messages"] = change_system_prompt(
             context=params.get("messages", []),
             prompt="In every output, response using the following format:\n<think>\n{reasoning_content}\n</think>\n\n{content}",
src/llm/response.py
@@ -30,7 +30,6 @@ async def merge_tools_response(config: dict, **kwargs) -> dict:
     completions |= openrouter_hook(GPT.TOOLS_BASE_URL, for_tools=True)
     tools_config = {
         "friendly_name": config["friendly_name"],
-        "bot_msg_prefix": config["bot_msg_prefix"],
         "client": {"base_url": GPT.TOOLS_BASE_URL, "api_key": GPT.TOOLS_API_KEY, "http_client": config["client"]["http_client"]},
         "completions": add_tools(completions),
     }
@@ -65,7 +64,7 @@ async def send_to_gpt(config: dict, retry: int = 0, **kwargs) -> dict[str, str]:
         retry: int, number of retries
 
     Returns:
-        {"content": str, "reasoning": str, "reasoning_model": str, "bot_msg_prefix": str}
+        {"content": str, "reasoning": str, "model": str, "reasoning_model": str}
     """
     try:
         openai = AsyncOpenAI(**config["client"])
@@ -83,7 +82,7 @@ async def send_to_gpt(config: dict, retry: int = 0, **kwargs) -> dict[str, str]:
         await modify_progress(text=error, force_update=True, **kwargs)
         if retry < GPT.MAX_RETRY:
             return await send_to_gpt(config, retry=retry + 1, **kwargs)
-    return {"content": "", "reasoning": "", "reasoning_model": "", "bot_msg_prefix": ""}
+    return {"content": "", "reasoning": "", "reasoning_model": ""}
 
 
 async def parse_error(resp: dict, retry: int, **kwargs) -> dict:
@@ -116,12 +115,12 @@ async def parse_response(config: dict, response: dict) -> dict[str, str]:
     """Parse GPT response.
 
     Returns:
-        {"content": str, "reasoning": str, "reasoning_model": str, "bot_msg_prefix": str}
+        {"content": str, "reasoning": str, "model": str, "reasoning_model": str}
     """
     logger.debug(response)
     choice = glom(response, "choices.0", default={})
     if glom(choice, "message.tool_calls.0", default={}):  # this is a function call response
-        return response | {"content": "", "reasoning": "", "reasoning_model": "", "bot_msg_prefix": config["bot_msg_prefix"]}
+        return response | {"content": "", "reasoning": "", "reasoning_model": ""}
     try:
         content = glom(choice, "message.content", default="") or ""
         reasoning, content = extract_reasoning(content)  # extract reasoning from content (<think>...</think>)
@@ -129,13 +128,13 @@ async def parse_response(config: dict, response: dict) -> dict[str, str]:
             reasoning = glom(choice, "message.reasoning", default="") or ""
         primary_model = glom(config, "completions.model", default="") or ""
         used_model = glom(response, "model", default="") or ""
-        response = {"content": content.strip(), "reasoning": reasoning.strip(), "reasoning_model": used_model, "bot_msg_prefix": config["bot_msg_prefix"]}
+        response = {"content": content.strip(), "model": config["friendly_name"], "reasoning": reasoning.strip(), "reasoning_model": used_model}
         if not (used_model in primary_model or primary_model in used_model):
             # do not use `!=` to compare. (deepseek/deepseek-r1:free != deepseek/deepseek-r1,  gpt-4o != gpt-4o-2024-07-18)
             used_model = beautify_model_name(used_model)
             logger.warning(f"Fallback model {primary_model} -> {used_model}")
             if ENABLE.GPT_WARN_FALLBACK:
-                response["bot_msg_prefix"] = response["bot_msg_prefix"].replace(config["friendly_name"], used_model)
+                response["model"] = used_model
     except Exception as e:
         logger.error(f"Parse  GPT response failed: {e}")
         raise
src/llm/utils.py
@@ -84,23 +84,37 @@ def beautify_model_name(name: str) -> str:
     Returns:
         beautified model name
     """
-    # example: openai/o1-preview:online
+    if not name:
+        return name
+    # example: openai/gpt-4o:online
 
     # remove suffix ":"
-    name = "".join(name.split(":")[:-1])  # openai/o1-preview
+    parts = name.split(":")
+    if len(parts) > 1:
+        name = "".join(parts[:-1])  # openai/gpt-4o
 
     # remove prefix "/"
-    name = name.split("/")[-1]  # o1-preview
+    name = name.split("/")[-1]  # gpt-4o
     # remove "-latest"
     name = name.replace("-latest", "")
 
-    return name.replace("gpt", "GPT").replace("deepseek", "DeepSeek").title()  # O1-Preview
+    return name.replace("gpt", "GPT").replace("gemini", "Gemini").replace("deepseek", "DeepSeek")  # GPT-4o
 
 
 def extract_reasoning(text: str) -> tuple[str, str]:
-    pattern = r"<think>(.*?)</think>"
+    """Extract reasoning from text.
+
+    "<think>
+    {reasoning_content}
+    </think>
+
+    {content}"
+    """
     reasoning = ""
-    if matched := re.search(pattern, text, re.DOTALL):
+    if matched := re.search(r"<think>(.*?)</think>", text, re.DOTALL):
+        reasoning = matched.group(1)
+        text = re.sub(r"<think>(.*?)</think>", "", text, count=1, flags=re.DOTALL)  # remove <think>...</think>
+    if matched := re.search(r"<thinking>(.*?)</thinking>", text, re.DOTALL):
         reasoning = matched.group(1)
-        text = re.sub(pattern, "", text, count=1, flags=re.DOTALL)  # remove <think>...</think>
-    return reasoning.strip(), text.strip()
+        text = re.sub(r"<thinking>(.*?)</thinking>", "", text, count=1, flags=re.DOTALL)
+    return reasoning.strip(), text.strip().removeprefix("{content}").strip()
src/config.py
@@ -143,9 +143,9 @@ class GPT:  # see `llm/README.md`
     # comma separated fallback models for OpenRouter (e.g. openai/gpt-4o,anthropic/claude-3.5-sonnet)
     FALLBACK_MODELS = os.getenv("GPT_FALLBACK_MODELS", "")
     FALLBACK_TOOLS_MODELS = os.getenv("GPT_FALLBACK_TOOLS_MODELS", "")  # comma separated fallback tool models for OpenRouter
-    TEXT_MODEL_NAME = os.getenv("GPT_TEXT_MODEL_NAME", "gpt-4o")  # custom name
-    IMAGE_MODEL_NAME = os.getenv("GPT_IMAGE_MODEL_NAME", "gpt-4o")
-    VIDEO_MODEL_NAME = os.getenv("GPT_VIDEO_MODEL_NAME", "glm-4v-plus")
+    TEXT_MODEL_NAME = os.getenv("GPT_TEXT_MODEL_NAME", "GPT-4o")  # custom name
+    IMAGE_MODEL_NAME = os.getenv("GPT_IMAGE_MODEL_NAME", "GPT-4o")
+    VIDEO_MODEL_NAME = os.getenv("GPT_VIDEO_MODEL_NAME", "GLM-4V-Plus")
     GLM_API_KEY = os.getenv("GPT_GLM_API_KEY", "")
     GLM_BASE_URL = os.getenv("GPT_GLM_BASE_URL", "https://open.bigmodel.cn/api/paas/v4")
     SEARCH_NUM_RESULTS = os.getenv("GPT_SEARCH_NUM_RESULTS", "5")
@@ -167,6 +167,21 @@ class GPT:  # see `llm/README.md`
     TOOLS_BASE_URL = os.getenv("GPT_TOOLS_BASE_URL", "https://api.openai.com/v1")
     TOKEN_ENCODING = os.getenv("GPT_TOKEN_ENCODING", "o200k_base")  # https://github.com/openai/tiktoken
     MAX_RETRY = int(os.getenv("GPT_MAX_RETRY", "2"))
+    # /gemini command
+    GEMINI_MODEL = os.getenv("GPT_GEMINI_MODEL", "gemini-2.0-flash")
+    GEMINI_MODEL_NAME = os.getenv("GPT_GEMINI_MODEL_NAME", "Gemini-2.0-Flash")
+    GEMINI_API_KEY = os.getenv("GPT_GEMINI_API_KEY", "")
+    GEMINI_BASE_URL = os.getenv("GPT_GEMINI_BASE_URL", "https://generativelanguage.googleapis.com/v1beta/openai")
+    # /gpt command
+    OPENAI_MODEL = os.getenv("GPT_OPENAI_MODEL", "gpt-4o")
+    OPENAI_MODEL_NAME = os.getenv("GPT_OPENAI_MODEL_NAME", "GPT-4o")
+    OPENAI_API_KEY = os.getenv("GPT_OPENAI_API_KEY", "")
+    OPENAI_BASE_URL = os.getenv("GPT_OPENAI_BASE_URL", "https://api.openai.com/v1")
+    # /ds command
+    DEEPSEEK_MODEL = os.getenv("GPT_DEEPSEEK_MODEL", "deepseek-r1")
+    DEEPSEEK_MODEL_NAME = os.getenv("GPT_DEEPSEEK_MODEL_NAME", "DeepSeek-R1")
+    DEEPSEEK_API_KEY = os.getenv("GPT_DEEPSEEK_API_KEY", "")
+    DEEPSEEK_BASE_URL = os.getenv("GPT_DEEPSEEK_BASE_URL", "https://api.deepseek.com/v1")
 
 
 class TID:
src/handler.py
@@ -77,7 +77,7 @@ async def handle_utilities(
     """
     kwargs |= {"target_chat": target_chat, "reply_msg_id": reply_msg_id, "show_progress": show_progress, "detail_progress": detail_progress}
     if ai:
-        await gpt_response(client, message, **kwargs)  # /ai
+        await gpt_response(client, message, **kwargs)  # /ai /gpt /gemini /ds
     if asr:
         await voice_to_text(client, message, **kwargs)  # /asr
     if audio:
@@ -275,7 +275,7 @@ def get_social_media_help(prefixes: list[str] | None = None):
     if ENABLE.AUDIO:
         msg += f"\n🎧**视频转音频**: `{PREFIX.AUDIO}` 回复视频消息"
     if ENABLE.GPT:
-        msg += f"\n🤖**GPT对话**: `{PREFIX.GPT}` + 提示词"
+        msg += f"\n🤖**GPT对话**: `{PREFIX.GPT} /gpt /gemini /ds` + 提示词"
     if ENABLE.SUBTITLE:
         msg += f"\n📃**提取字幕**: `{PREFIX.SUBTITLE}` + 油管链接 (或回复油管链接)"
     if ENABLE.WGET: