Commit c8a0957
Changed files (6)
src/ai/texts/claude.py
@@ -13,7 +13,7 @@ from pyrogram.parser.markdown import BLOCKQUOTE_DELIM
from pyrogram.types import Message, ReplyParameters
from ai.texts.contexts import get_anthropic_contexts
-from ai.utils import BOT_TIPS, EMOJI_REASONING_BEGIN, EMOJI_TEXT_BOT, beautify_llm_response, literal_eval, trim_none
+from ai.utils import BOT_TIPS, EMOJI_REASONING_BEGIN, EMOJI_TEXT_BOT, beautify_llm_response, literal_eval, load_skills, trim_none
from config import AI, PROXY, TEXT_LENGTH
from messages.progress import modify_progress
from messages.utils import blockquote, count_without_entities, delete_message, quote, smart_split
@@ -36,6 +36,7 @@ async def anthropic_responses(
    cache_response_ttl: int = 0,
    anthropic_media_send_as: Literal["base64", "file_id"] = "file_id",
    anthropic_append_citation: bool = True,
+    skills: str = "",
    silent: bool = False,
    max_retries: int = 3,
    **kwargs,
@@ -84,6 +85,8 @@ async def anthropic_responses(
    }
    if literal_eval(anthropic_responses_config):
        params |= literal_eval(anthropic_responses_config)
+    if skills:
+        params |= {"system": await load_skills(skills)}
    logger.debug(f"anthropic.messages.create(**{params})")
    resp = await single_api_response(
        client,
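
Review note: the skills merge runs after the anthropic_responses_config merge, so when both supply a "system" value the skill text wins. A minimal sketch of that dict-merge ordering (values are placeholders):

# Sketch of the merge order above; the last |= wins for duplicate keys.
params: dict = {"model": "claude-example"}      # placeholder model id
params |= {"system": "from responses config"}   # anthropic_responses_config branch
params |= {"system": "skill text"}              # skills branch runs last, so it wins
assert params["system"] == "skill text"
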
src/ai/texts/gemini.py
@@ -12,7 +12,7 @@ from pyrogram.parser.markdown import BLOCKQUOTE_DELIM
from pyrogram.types import Message, ReplyParameters
from ai.texts.contexts import get_gemini_contexts
-from ai.utils import BOT_TIPS, EMOJI_REASONING_BEGIN, EMOJI_TEXT_BOT, beautify_llm_response, literal_eval, trim_none
+from ai.utils import BOT_TIPS, EMOJI_REASONING_BEGIN, EMOJI_TEXT_BOT, beautify_llm_response, literal_eval, load_skills, trim_none
from config import AI, PROXY, TEXT_LENGTH
from messages.progress import modify_progress
from messages.utils import blockquote, count_without_entities, quote, smart_split
@@ -33,6 +33,7 @@ async def gemini_chat_completion(
    gemini_generate_content_config: str | dict = "",
    gemini_proxy: str | None = PROXY.GOOGLE,
    gemini_append_grounding: bool = True,
+    skills: str = "",
    silent: bool = False,
    max_retries: int = 3,
    **kwargs,
@@ -56,6 +57,8 @@ async def gemini_chat_completion(
    http_options = types.HttpOptions(base_url=gemini_base_url, async_client_args={"proxy": gemini_proxy}, headers=literal_eval(gemini_default_headers))
    gemini = genai.Client(api_key=api_key, http_options=http_options)
    params: dict = {"model": model_id, "contents": await get_gemini_contexts(client, message, gemini)}
+    if skills:
+        gemini_generate_content_config = literal_eval(gemini_generate_content_config) | {"system_instruction": await load_skills(skills)}
    if conf := literal_eval(gemini_generate_content_config):
        params["config"] = conf
    logger.debug(f"genai.Client().models.generate_content_stream(**{params})")
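
Review note: here the skill text is folded into gemini_generate_content_config itself before the walrus check, so the config is always truthy when skills is set and system_instruction rides along. A rough sketch, assuming literal_eval passes dicts through unchanged:

# Sketch of the Gemini branch; config values are placeholders.
params: dict = {"model": "gemini-example"}
config: dict = {"temperature": 0.7}             # parsed gemini_generate_content_config
config |= {"system_instruction": "skill text"}  # skills branch
if conf := config:                              # always truthy once skills are merged
    params["config"] = conf
assert params["config"]["system_instruction"] == "skill text"
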
src/ai/texts/openai_chat.py
@@ -10,7 +10,7 @@ from pyrogram.parser.markdown import BLOCKQUOTE_DELIM
from pyrogram.types import Message, ReplyParameters
from ai.texts.contexts import get_openai_completion_contexts
-from ai.utils import BOT_TIPS, EMOJI_REASONING_BEGIN, EMOJI_TEXT_BOT, beautify_llm_response, literal_eval, split_reasoning, trim_none
+from ai.utils import BOT_TIPS, EMOJI_REASONING_BEGIN, EMOJI_TEXT_BOT, beautify_llm_response, literal_eval, load_skills, split_reasoning, trim_none
from config import AI, PROXY, TEXT_LENGTH
from messages.progress import modify_progress
from messages.utils import blockquote, count_without_entities, delete_message, quote, smart_split
@@ -33,6 +33,7 @@ async def openai_chat_completions(
    openai_system_prompt: str = "",
    openai_contexts: list[dict] | None = None,
    openai_tools: list[dict] | None = None,
+    skills: str = "",
    silent: bool = False,
    max_retries: int = 3,
    **kwargs,
@@ -62,6 +63,8 @@ async def openai_chat_completions(
    contexts = openai_contexts or await get_openai_completion_contexts(client, message)
    if openai_system_prompt and glom(contexts, "0.role", default="") != "system":
        contexts.insert(0, {"role": "system", "content": openai_system_prompt})
+    if skills:
+        contexts = inject_skills(contexts, skills=await load_skills(skills))
    params = {"model": model_id, "messages": contexts, "stream": True}
    if literal_eval(openai_completions_config):
        params |= literal_eval(openai_completions_config)
@@ -223,3 +226,18 @@ async def single_api_chat_completions(
        **kwargs,
    )
    return {"texts": answers, "thoughts": thoughts, "tool_name": tool_name.strip(), "tool_args": tool_args.strip(), "sent_messages": sent_messages}
+
+
+def inject_skills(contexts: list[dict], skills: str) -> list[dict]:
+    if not skills:
+        return contexts
+    if glom(contexts, "0.role", default="") != "system":
+        contexts.insert(0, {"role": "system", "content": skills})
+        return contexts
+    system_prompt = contexts[0]["content"]
+    if isinstance(system_prompt, str) and skills not in system_prompt:
+        system_prompt = f"{system_prompt}\n{skills}"
+    if isinstance(system_prompt, list) and {"type": "text", "text": skills} not in system_prompt:
+        system_prompt.append({"type": "text", "text": skills})
+    contexts[0] = {"role": "system", "content": system_prompt}
+    return contexts
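
Review note: inject_skills covers three cases: no leading system message (insert one and return early), a string system prompt (append the skill text once), and a content-parts list (append a text part once). A hypothetical walk-through:

# Hypothetical usage of inject_skills; message contents are placeholders.
msgs = [{"role": "user", "content": "hi"}]
assert inject_skills(msgs, "S")[0] == {"role": "system", "content": "S"}  # inserted

msgs = [{"role": "system", "content": "base"}]
assert inject_skills(msgs, "S")[0]["content"] == "base\nS"  # appended to string prompt

msgs = [{"role": "system", "content": [{"type": "text", "text": "base"}]}]
assert inject_skills(msgs, "S")[0]["content"][-1] == {"type": "text", "text": "S"}  # appended as a part
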
src/ai/texts/openai_response.py
@@ -12,7 +12,7 @@ from pyrogram.parser.markdown import BLOCKQUOTE_DELIM
from pyrogram.types import Message, ReplyParameters
from ai.texts.contexts import get_openai_response_contexts
-from ai.utils import BOT_TIPS, EMOJI_REASONING_BEGIN, EMOJI_TEXT_BOT, beautify_llm_response, literal_eval, trim_none
+from ai.utils import BOT_TIPS, EMOJI_REASONING_BEGIN, EMOJI_TEXT_BOT, beautify_llm_response, literal_eval, load_skills, trim_none
from config import AI, PROXY, TEXT_LENGTH
from database.r2 import set_cf_r2
from messages.parser import get_thread_id
@@ -39,6 +39,7 @@ async def openai_responses_api(
    openai_allow_video: bool = False,  # whether to allow video in input modalities
    openai_allow_file: bool = False,  # whether to allow file in input modalities
    openai_media_send_as: Literal["base64", "file_id"] = "file_id",
+    skills: str = "",
    silent: bool = False,
    max_retries: int = 3,
    **kwargs,
@@ -91,6 +92,8 @@ async def openai_responses_api(
    )
    params = {}
    params |= {"model": model_id, "stream": True, "input": contexts}
+    if skills:
+        params |= {"instructions": await load_skills(skills)}
    if literal_eval(openai_responses_config):
        params |= literal_eval(openai_responses_config)
    if previous_response_id:
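
Review note: the Responses API takes the system prompt as a top-level instructions field rather than a system message, and the merge order here is the reverse of claude.py: openai_responses_config is applied after the skills key, so a config-supplied instructions value overrides the skill text. Sketch:

# Sketch of the merge order above; values are placeholders.
params: dict = {"model": "gpt-example", "stream": True}
params |= {"instructions": "skill text"}   # skills branch
params |= {"instructions": "from config"}  # openai_responses_config runs last here
assert params["instructions"] == "from config"
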
src/ai/texts/tool_call.py
@@ -119,7 +119,7 @@ Use the `web_search` tool to access up-to-date information from the web or when
"success": True,
"openai_system_prompt": texts, # add tool results to system prompt
"openai_tools": None, # disable tools after tool call
- "progress": kwargs["progress"],
+ "progress": kwargs.get("progress") or resp.get("progress"),
}
status_msg = kwargs.get("progress") or resp.get("progress")
return {"progress": status_msg} if isinstance(status_msg, Message) else {}
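
Review note: the one-line fix swaps a hard kwargs["progress"] lookup, which raised KeyError whenever no progress message was threaded through, for the same .get() fallback chain already used two lines below. Sketch of the difference:

# Sketch of the fix; kwargs/resp contents are placeholders.
kwargs: dict = {}                                # no progress message passed in
resp: dict = {"progress": "status placeholder"}
# old: kwargs["progress"] raised KeyError here
progress = kwargs.get("progress") or resp.get("progress")  # new: falls back to resp
assert progress == "status placeholder"
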
src/ai/utils.py
@@ -14,7 +14,7 @@ from google.genai.types import HttpOptions
from loguru import logger
from pyrogram.parser.markdown import BLOCKQUOTE_EXPANDABLE_DELIM
-from config import AI, PREFIX, PROXY
+from config import AI, PREFIX, PROXY, cache
from database.kv import get_cf_kv
from utils import nowdt, remove_consecutive_newlines, remove_dash, remove_pound, strings_list, zhcn
@@ -185,6 +185,17 @@ async def clean_gemini_files():
        await app.aio.files.delete(name=f.name)


+@cache.memoize(ttl=300)
+async def load_skills(skill_name: str) -> str:
+    skills = await get_cf_kv(skill_name)
+    skill_str = ""
+    if "SKILL.md" in skills:
+        skill_str = skills.pop("SKILL.md")
+    for fname, content in sorted(skills.items()):
+        skill_str += f"\n\nReference: {fname}\n{content}"
+    return skill_str
+
+
async def clean_anthropic_files():
"""Clean Anthropic files.