Commit 439b4f3

benny-dou <60535774+benny-dou@users.noreply.github.com>
2025-06-15 03:57:33
feat(gpt): support custom system prompt and tool usage
1 parent f6238f8
src/llm/gemini.py
@@ -50,9 +50,11 @@ async def gemini_response(
     conversations: list[Message],
     modality: str = "image",
     *,
+    enable_tools: bool = True,
     append_grounding: bool = True,
     disable_thinking: bool = False,
     include_thoughts: bool = True,
+    system_prompt: str | None = None,
     **kwargs,
 ) -> dict:
     r"""Get Gemini response.
@@ -73,7 +75,7 @@ async def gemini_response(
         await send2tg(client, message, texts="⚠️**未配置Gemini API, 请尝试其他模型", **kwargs)
     response_modalities = ["TEXT", "IMAGE"] if modality == "image" else ["TEXT"]
     thinking_budget = GEMINI.IMG_THINKING_BUDGET if modality == "image" else GEMINI.TEXT_THINKING_BUDGET
-    tools = [types.Tool(url_context=types.UrlContext()), types.Tool(google_search=types.GoogleSearch())] if modality == "text" else None
+    tools = [types.Tool(url_context=types.UrlContext()), types.Tool(google_search=types.GoogleSearch())]
     # parse config from environment variable
     genconfig = {}
     with contextlib.suppress(Exception):
@@ -85,10 +87,13 @@ async def gemini_response(
         status_msg = (await send2tg(client, message, texts=msg, **kwargs))[0]
         kwargs["progress"] = status_msg
         genconfig |= {"response_modalities": response_modalities}
-        if tools:
+        if enable_tools and modality == "text":
             genconfig |= {"tools": tools}
-        if GEMINI.PREFER_LANG and modality == "text":
+        if system_prompt is not None:
+            genconfig |= {"system_instruction": system_prompt}
+        elif GEMINI.PREFER_LANG and modality == "text":
             genconfig |= {"system_instruction": f"请优先使用{GEMINI.PREFER_LANG}思考和回复"}
+
         if thinking_budget is not None and not disable_thinking:
             thinking_budget = min(round(float(thinking_budget)), GEMINI.MAX_THINKING_BUDGET)
             genconfig |= {"thinking_config": types.ThinkingConfig(include_thoughts=include_thoughts, thinking_budget=thinking_budget)}
src/llm/gpt.py
@@ -86,13 +86,25 @@ def is_gpt_conversation(minfo: dict) -> bool:
     return startswith_prefix(minfo["reply_text"], prefix=[f"🤖{x}".lower() for x in model_names])
 
 
-async def gpt_response(client: Client, message: Message, *, gpt_stream: bool = True, **kwargs) -> dict:
+async def gpt_response(
+    client: Client,
+    message: Message,
+    *,
+    gpt_stream: bool = True,
+    system_prompt: str | None = None,
+    enable_gpt_tools: bool = True,
+    enable_gemini_tools: bool = True,
+    **kwargs,
+) -> dict:
     """Get GPT response from Various API.
 
     Args:
         client (Client): The Pyrogram client.
         message (Message): The trigger message object.
         gpt_stream (bool): Whether to use stream mode.
+        system_prompt (str | None): System prompt.
+        enable_gpt_tools (bool): Whether GPT tool usage is allowed.
+        enable_gemini_tools (bool): Whether Gemini tool usage is allowed.
 
     Returns:
         dict: {"texts": str, "thoughts": str, "prefix": str, "model_name": str, "sent_messages": list[Message]}
@@ -122,8 +134,17 @@ async def gpt_response(client: Client, message: Message, *, gpt_stream: bool = T
     context_type = get_context_type(conversations)  # {"type": "text", "error": None}  # text, image
     model_id, resp_modality, sdk = get_model_id(info["text"], info["reply_text"], context_type)
     if "gemini" in model_id.lower() and sdk == "gemini":
-        return await gemini_response(client, message, conversations, resp_modality, **kwargs)
-
+        return await gemini_response(
+            client,
+            message,
+            conversations,
+            resp_modality,
+            system_prompt=system_prompt,
+            enable_tools=enable_gemini_tools,
+            **kwargs,
+        )
+
+    # GPT models
     config = get_gpt_config(model_id)
     if not config["client"]["api_key"].strip():
         await send2tg(client, message, texts=f"⚠️**{config['friendly_name']}** 未配置API Key, 请尝试其他命令\n\n{HELP}", **kwargs)
@@ -137,29 +158,31 @@ async def gpt_response(client: Client, message: Message, *, gpt_stream: bool = T
     msg = f"🤖**{config['friendly_name']}**: 思考中...\n👤**[{info['full_name'] or info['ctitle']}](tg://user?id={info['uid']})**: “{real_prompt}”"[:TEXT_LENGTH]
     status_msg = (await send2tg(client, message, texts=msg, **kwargs))[0]
     kwargs["progress"] = status_msg
-    config, response = await merge_tools_response(config, **kwargs)
-    # skip send a new request if tool_model is the same as the current model
-    if response and config["completions"]["model"] == GPT.TOOLS_MODEL and response.get("content"):
-        texts = f"🤖**{config['friendly_name']}**:{BOT_TIPS}\n{response['content']}"
-        length = await count_without_entities(texts)
-        if length <= TEXT_LENGTH:
-            await modify_progress(message=status_msg, text=texts, force_update=True, **kwargs)
-            final = {
-                "texts": response["content"],
-                "prefix": f"🤖**{config['friendly_name']}**:{BOT_TIPS}\n",
-                "model_name": config["friendly_name"],
-                "sent_messages": [status_msg],
-            }
-        else:
-            final = {
-                "texts": response["content"],
-                "prefix": f"🤖**{config['friendly_name']}**:{BOT_TIPS}\n",
-                "model_name": config["friendly_name"],
-                "sent_messages": await send2tg(client, message, texts=texts, **kwargs),
-            }
-            await modify_progress(message=status_msg, del_status=True, **kwargs)
-        llm_cleanup_files(config["completions"]["messages"])
-        return final
+
+    if enable_gpt_tools:
+        config, response = await merge_tools_response(config, **kwargs)
+        # skip send a new request if tool_model is the same as the current model
+        if response and config["completions"]["model"] == GPT.TOOLS_MODEL and response.get("content"):
+            texts = f"🤖**{config['friendly_name']}**:{BOT_TIPS}\n{response['content']}"
+            length = await count_without_entities(texts)
+            if length <= TEXT_LENGTH:
+                await modify_progress(message=status_msg, text=texts, force_update=True, **kwargs)
+                final = {
+                    "texts": response["content"],
+                    "prefix": f"🤖**{config['friendly_name']}**:{BOT_TIPS}\n",
+                    "model_name": config["friendly_name"],
+                    "sent_messages": [status_msg],
+                }
+            else:
+                final = {
+                    "texts": response["content"],
+                    "prefix": f"🤖**{config['friendly_name']}**:{BOT_TIPS}\n",
+                    "model_name": config["friendly_name"],
+                    "sent_messages": await send2tg(client, message, texts=texts, **kwargs),
+                }
+                await modify_progress(message=status_msg, del_status=True, **kwargs)
+            llm_cleanup_files(config["completions"]["messages"])
+            return final
     final = {}
     if not gpt_stream:
         response = await send_to_gpt(config, **kwargs)
@@ -179,6 +202,6 @@ async def gpt_response(client: Client, message: Message, *, gpt_stream: bool = T
             }
             await modify_progress(message=status_msg, del_status=True, **kwargs)
     else:
-        final = await send_to_gpt_stream(client, status_msg, config, **kwargs)  # type: ignore
+        final = await send_to_gpt_stream(client, status_msg, config, system_prompt=system_prompt, **kwargs)  # type: ignore
     llm_cleanup_files(config["completions"]["messages"])
     return final
src/llm/hooks.py
@@ -9,10 +9,12 @@ from messages.parser import parse_msg
 from utils import unicode_to_ascii
 
 
-def pre_hooks(client: dict, completions: dict, message_info: dict | None = None):
+def pre_hooks(client: dict, completions: dict, message_info: dict | None = None, system_prompt: str | None = None):
     pre_openrouter_hook(client, completions)
     pre_helicone_hook(client, message_info)
-    if GEMINI.PREFER_LANG and "gemini" in completions["model"].lower():
+    if system_prompt is not None:
+        modify_prompts(completions["messages"], prompt=system_prompt, role="system", method="overwrite")
+    elif GEMINI.PREFER_LANG and "gemini" in completions["model"].lower():
         modify_prompts(completions["messages"], prompt=f"请使用{GEMINI.PREFER_LANG}回复。", role="system", method="append")
     completions["messages"] = refine_prompts(completions["messages"])
 
src/llm/response_stream.py
@@ -18,7 +18,15 @@ from messages.progress import modify_progress
 from messages.utils import blockquote, count_without_entities, smart_split
 
 
-async def send_to_gpt_stream(client: Client, status: Message, config: dict, retry: int = 0, **kwargs) -> dict:
+async def send_to_gpt_stream(
+    client: Client,
+    status: Message,
+    config: dict,
+    *,
+    retry: int = 0,
+    system_prompt: str | None = None,
+    **kwargs,
+) -> dict:
     """Get GPT response in stream mode.
 
     Returns:
@@ -28,7 +36,7 @@ async def send_to_gpt_stream(client: Client, status: Message, config: dict, retr
     prefix = f"🤖**{config['friendly_name']}**:{BOT_TIPS}\n"
     final = {"prefix": prefix, "model_name": config["friendly_name"], "sent_messages": [status]}
     try:
-        pre_hooks(config["client"], config["completions"], message_info=kwargs.get("message_info"))
+        pre_hooks(config["client"], config["completions"], message_info=kwargs.get("message_info"), system_prompt=system_prompt)
         openai = AsyncOpenAI(**config["client"])
         logger.trace(config)
         answers = prefix