Commit 439b4f3
Changed files (4)
src/llm/gemini.py
@@ -50,9 +50,11 @@ async def gemini_response(
conversations: list[Message],
modality: str = "image",
*,
+ enable_tools: bool = True,
append_grounding: bool = True,
disable_thinking: bool = False,
include_thoughts: bool = True,
+ system_prompt: str | None = None,
**kwargs,
) -> dict:
r"""Get Gemini response.
@@ -73,7 +75,7 @@ async def gemini_response(
await send2tg(client, message, texts="⚠️**未配置Gemini API, 请尝试其他模型", **kwargs)
response_modalities = ["TEXT", "IMAGE"] if modality == "image" else ["TEXT"]
thinking_budget = GEMINI.IMG_THINKING_BUDGET if modality == "image" else GEMINI.TEXT_THINKING_BUDGET
- tools = [types.Tool(url_context=types.UrlContext()), types.Tool(google_search=types.GoogleSearch())] if modality == "text" else None
+ tools = [types.Tool(url_context=types.UrlContext()), types.Tool(google_search=types.GoogleSearch())]
# parse config from environment variable
genconfig = {}
with contextlib.suppress(Exception):
@@ -85,10 +87,13 @@ async def gemini_response(
status_msg = (await send2tg(client, message, texts=msg, **kwargs))[0]
kwargs["progress"] = status_msg
genconfig |= {"response_modalities": response_modalities}
- if tools:
+ if enable_tools and modality == "text":
genconfig |= {"tools": tools}
- if GEMINI.PREFER_LANG and modality == "text":
+ if system_prompt is not None:
+ genconfig |= {"system_instruction": system_prompt}
+ elif GEMINI.PREFER_LANG and modality == "text":
genconfig |= {"system_instruction": f"请优先使用{GEMINI.PREFER_LANG}思考和回复"}
+
if thinking_budget is not None and not disable_thinking:
thinking_budget = min(round(float(thinking_budget)), GEMINI.MAX_THINKING_BUDGET)
genconfig |= {"thinking_config": types.ThinkingConfig(include_thoughts=include_thoughts, thinking_budget=thinking_budget)}
src/llm/gpt.py
@@ -86,13 +86,25 @@ def is_gpt_conversation(minfo: dict) -> bool:
return startswith_prefix(minfo["reply_text"], prefix=[f"🤖{x}".lower() for x in model_names])
-async def gpt_response(client: Client, message: Message, *, gpt_stream: bool = True, **kwargs) -> dict:
+async def gpt_response(
+ client: Client,
+ message: Message,
+ *,
+ gpt_stream: bool = True,
+ system_prompt: str | None = None,
+ enable_gpt_tools: bool = True,
+ enable_gemini_tools: bool = True,
+ **kwargs,
+) -> dict:
"""Get GPT response from Various API.
Args:
client (Client): The Pyrogram client.
message (Message): The trigger message object.
gpt_stream (bool): Whether to use stream mode.
+ system_prompt (str | None): System prompt.
+ use_gpt_tools (bool): can use GPT tools.
+ use_gemini_tools (bool): can use Gemini tools.
Returns:
dict: {"texts": str, "thoughts": str, "prefix": str, "model_name": str, "sent_messages": list[Message]}
@@ -122,8 +134,17 @@ async def gpt_response(client: Client, message: Message, *, gpt_stream: bool = T
context_type = get_context_type(conversations) # {"type": "text", "error": None} # text, image
model_id, resp_modality, sdk = get_model_id(info["text"], info["reply_text"], context_type)
if "gemini" in model_id.lower() and sdk == "gemini":
- return await gemini_response(client, message, conversations, resp_modality, **kwargs)
-
+ return await gemini_response(
+ client,
+ message,
+ conversations,
+ resp_modality,
+ system_prompt=system_prompt,
+ enable_gemini_tools=enable_gemini_tools,
+ **kwargs,
+ )
+
+ # GPT models
config = get_gpt_config(model_id)
if not config["client"]["api_key"].strip():
await send2tg(client, message, texts=f"⚠️**{config['friendly_name']}** 未配置API Key, 请尝试其他命令\n\n{HELP}", **kwargs)
@@ -137,29 +158,31 @@ async def gpt_response(client: Client, message: Message, *, gpt_stream: bool = T
msg = f"🤖**{config['friendly_name']}**: 思考中...\n👤**[{info['full_name'] or info['ctitle']}](tg://user?id={info['uid']})**: “{real_prompt}”"[:TEXT_LENGTH]
status_msg = (await send2tg(client, message, texts=msg, **kwargs))[0]
kwargs["progress"] = status_msg
- config, response = await merge_tools_response(config, **kwargs)
- # skip send a new request if tool_model is the same as the current model
- if response and config["completions"]["model"] == GPT.TOOLS_MODEL and response.get("content"):
- texts = f"🤖**{config['friendly_name']}**:{BOT_TIPS}\n{response['content']}"
- length = await count_without_entities(texts)
- if length <= TEXT_LENGTH:
- await modify_progress(message=status_msg, text=texts, force_update=True, **kwargs)
- final = {
- "texts": response["content"],
- "prefix": f"🤖**{config['friendly_name']}**:{BOT_TIPS}\n",
- "model_name": config["friendly_name"],
- "sent_messages": [status_msg],
- }
- else:
- final = {
- "texts": response["content"],
- "prefix": f"🤖**{config['friendly_name']}**:{BOT_TIPS}\n",
- "model_name": config["friendly_name"],
- "sent_messages": await send2tg(client, message, texts=texts, **kwargs),
- }
- await modify_progress(message=status_msg, del_status=True, **kwargs)
- llm_cleanup_files(config["completions"]["messages"])
- return final
+
+ if enable_gpt_tools:
+ config, response = await merge_tools_response(config, **kwargs)
+ # skip send a new request if tool_model is the same as the current model
+ if response and config["completions"]["model"] == GPT.TOOLS_MODEL and response.get("content"):
+ texts = f"🤖**{config['friendly_name']}**:{BOT_TIPS}\n{response['content']}"
+ length = await count_without_entities(texts)
+ if length <= TEXT_LENGTH:
+ await modify_progress(message=status_msg, text=texts, force_update=True, **kwargs)
+ final = {
+ "texts": response["content"],
+ "prefix": f"🤖**{config['friendly_name']}**:{BOT_TIPS}\n",
+ "model_name": config["friendly_name"],
+ "sent_messages": [status_msg],
+ }
+ else:
+ final = {
+ "texts": response["content"],
+ "prefix": f"🤖**{config['friendly_name']}**:{BOT_TIPS}\n",
+ "model_name": config["friendly_name"],
+ "sent_messages": await send2tg(client, message, texts=texts, **kwargs),
+ }
+ await modify_progress(message=status_msg, del_status=True, **kwargs)
+ llm_cleanup_files(config["completions"]["messages"])
+ return final
final = {}
if not gpt_stream:
response = await send_to_gpt(config, **kwargs)
@@ -179,6 +202,6 @@ async def gpt_response(client: Client, message: Message, *, gpt_stream: bool = T
}
await modify_progress(message=status_msg, del_status=True, **kwargs)
else:
- final = await send_to_gpt_stream(client, status_msg, config, **kwargs) # type: ignore
+ final = await send_to_gpt_stream(client, status_msg, config, system_prompt=system_prompt, **kwargs) # type: ignore
llm_cleanup_files(config["completions"]["messages"])
return final
src/llm/hooks.py
@@ -9,10 +9,12 @@ from messages.parser import parse_msg
from utils import unicode_to_ascii
-def pre_hooks(client: dict, completions: dict, message_info: dict | None = None):
+def pre_hooks(client: dict, completions: dict, message_info: dict | None = None, system_prompt: str | None = None):
pre_openrouter_hook(client, completions)
pre_helicone_hook(client, message_info)
- if GEMINI.PREFER_LANG and "gemini" in completions["model"].lower():
+ if system_prompt is not None:
+ modify_prompts(completions["messages"], prompt=system_prompt, role="system", method="overwrite")
+ elif GEMINI.PREFER_LANG and "gemini" in completions["model"].lower():
modify_prompts(completions["messages"], prompt=f"请使用{GEMINI.PREFER_LANG}回复。", role="system", method="append")
completions["messages"] = refine_prompts(completions["messages"])
src/llm/response_stream.py
@@ -18,7 +18,15 @@ from messages.progress import modify_progress
from messages.utils import blockquote, count_without_entities, smart_split
-async def send_to_gpt_stream(client: Client, status: Message, config: dict, retry: int = 0, **kwargs) -> dict:
+async def send_to_gpt_stream(
+ client: Client,
+ status: Message,
+ config: dict,
+ *,
+ retry: int = 0,
+ system_prompt: str | None = None,
+ **kwargs,
+) -> dict:
"""Get GPT response in stream mode.
Returns:
@@ -28,7 +36,7 @@ async def send_to_gpt_stream(client: Client, status: Message, config: dict, retr
prefix = f"🤖**{config['friendly_name']}**:{BOT_TIPS}\n"
final = {"prefix": prefix, "model_name": config["friendly_name"], "sent_messages": [status]}
try:
- pre_hooks(config["client"], config["completions"], message_info=kwargs.get("message_info"))
+ pre_hooks(config["client"], config["completions"], message_info=kwargs.get("message_info"), system_prompt=system_prompt)
openai = AsyncOpenAI(**config["client"])
logger.trace(config)
answers = prefix