Commit 5c00853
Changed files (4)
src/llm/gemini.py
@@ -54,7 +54,7 @@ async def gemini_response(
disable_thinking: bool = False,
include_thoughts: bool = True,
**kwargs,
-):
+) -> dict:
r"""Get Gemini response.
Args:
@@ -100,6 +100,7 @@ async def gemini_response(
return await gemini_stream(client, message, model_name, params, append_grounding=append_grounding, **kwargs)
except Exception as e:
logger.error(e)
+ return {}
async def gemini_stream(
@@ -116,6 +117,12 @@ async def gemini_stream(
append_grounding: bool = True,
**kwargs,
) -> dict:
+ """Gemini stream response.
+
+ Returns:
+ dict: {"texts": str, "thoughts": str, "prefix": str, "model_name": str, "sent_messages": list[Message]}
+ """
+ # ruff: noqa: RUF001, RUF003
if prefix is None:
prefix = f"🤖**{model_name}**:{BOT_TIPS}\n"
answers = "" # all model responses
@@ -237,7 +244,13 @@ async def gemini_nonstream(
clean_marks: bool = False, # useful in image generation
append_grounding: bool = True,
**kwargs,
-):
+) -> dict:
+ """Gemini non-stream response.
+
+ Returns:
+ dict: {"texts": str, "thoughts": str, "prefix": str, "model_name": str, "sent_messages": list[Message]}
+ """
+ results = {}
try:
if clean_marks:
clean_gemini_sourcemarks(params["contents"])
@@ -245,7 +258,7 @@ async def gemini_nonstream(
if kwargs.get("gemini_api_keys"):
api_keys = [x.strip() for x in kwargs["gemini_api_keys"].split(",") if x.strip()]
if retry > len(api_keys) - 1:
- return None
+ return {}
api_key = kwargs.get("gemini_api_key", api_keys[retry])
http_options = HttpOptions(base_url=GEMINI.BASR_URL, async_client_args={"proxy": GEMINI.PROXY})
http_options = hook_gemini_httpoptions(http_options, message)
@@ -255,21 +268,22 @@ async def gemini_nonstream(
res = parse_response(response.model_dump(), append_grounding=append_grounding)
texts = res.get("texts", "")
thoughts = res.get("thoughts", "")
+ results |= {"prefix": prefix, "model_name": model_name, "texts": texts, "thoughts": thoughts}
media = res.get("media", [])
total = prefix + blockquote(REASONING_BEGIN + thoughts.strip() + REASONING_END) + "\n" + texts.strip() if thoughts.strip() else prefix + texts.strip()
length = await count_without_entities(total)
single_msg_length = CAPTION_LENGTH if media else TEXT_LENGTH
if length <= GPT.COLLAPSE_LENGTH:
-            await send2tg(client, message, caption_above=True, texts=total, media=media, **kwargs)
+            results["sent_messages"] = await send2tg(client, message, caption_above=True, texts=total, media=media, **kwargs)
elif GPT.COLLAPSE_LENGTH < length <= single_msg_length:
final = prefix + blockquote(REASONING_BEGIN + thoughts.strip() + REASONING_END + "\n\n" + texts.strip()) if thoughts.strip() else prefix + blockquote(texts.strip())
-            await send2tg(client, message, caption_above=True, texts=final, media=media, **kwargs)
+            results["sent_messages"] = await send2tg(client, message, caption_above=True, texts=final, media=media, **kwargs)
else: # multiple messages
for idx, txt in await smart_split(total, single_msg_length):
if idx == 0:
-                    await send2tg(client, message, caption_above=True, texts=txt, media=media, **kwargs)
+                    results["sent_messages"] = await send2tg(client, message, caption_above=True, texts=txt, media=media, **kwargs)
else:
-                    await send2tg(client, message, texts=txt, **kwargs)
+                    results["sent_messages"] += await send2tg(client, message, texts=txt, **kwargs)
await modify_progress(del_status=True, **kwargs)
except Exception as e:
logger.error(e)
@@ -280,6 +294,7 @@ async def gemini_nonstream(
error += f"\n{response}"
await modify_progress(text=error, force_update=True, **kwargs)
return await gemini_nonstream(client, message, model_name, params, retry + 1, clean_marks=clean_marks, append_grounding=append_grounding, **kwargs) # type: ignore
+ return results
def parse_response(data: dict, *, append_grounding: bool = True) -> dict:
src/llm/gpt.py
@@ -13,7 +13,7 @@ from llm.models import get_context_type, get_gpt_config, get_model_id
from llm.response import send_to_gpt
from llm.response_stream import send_to_gpt_stream
from llm.tools import merge_tools_response
-from llm.utils import BOT_TIPS, clean_cmd_prefix, image_emoji, llm_cleanup_files
+from llm.utils import BOT_TIPS, clean_cmd_prefix, image_emoji, llm_cleanup_files, raw_reasoning
from messages.parser import parse_msg
from messages.progress import modify_progress
from messages.sender import send2tg
@@ -69,25 +69,28 @@ def is_gpt_conversation(message: Message) -> bool:
return startswith_prefix(reply_info["text"], prefix=[f"🤖{x}".lower() for x in model_names])
-async def gpt_response(client: Client, message: Message, *, gpt_stream: bool = True, **kwargs):
+async def gpt_response(client: Client, message: Message, *, gpt_stream: bool = True, **kwargs) -> dict:
"""Get GPT response from Various API.
Args:
client (Client): The Pyrogram client.
message (Message): The trigger message object.
gpt_stream (bool): Whether to use stream mode.
+
+ Returns:
+ dict: {"texts": str, "thoughts": str, "prefix": str, "model_name": str, "sent_messages": list[Message]}
"""
# ruff: noqa: RET502, RET503
info = parse_msg(message)
# send docs if message == "/ai", without reply
if info["mtype"] == "text" and equal_prefix(info["text"], prefix=[PREFIX.GPT, "/gpt", "/gemini", "/ds", "/qwen", "/grok", "/doubao"]) and not message.reply_to_message:
await send2tg(client, message, texts=HELP, **kwargs)
- return
+ return {}
if info["mtype"] == "text" and equal_prefix(info["text"], prefix=[PREFIX.GENIMG]) and not message.reply_to_message:
await send2tg(client, message, texts=AIGC_HELP, **kwargs)
- return
+ return {}
if not is_gpt_conversation(message):
- return
+ return {}
reply_text = ""
if message.reply_to_message:
reply_info = parse_msg(message.reply_to_message, silent=True)
@@ -96,7 +99,7 @@ async def gpt_response(client: Client, message: Message, *, gpt_stream: bool = T
# cache media_group message, only process once
if media_group_id := message.media_group_id:
if cache.get(f"gpt-{info['cid']}-{media_group_id}"):
- return
+ return {}
cache.set(f"gpt-{info['cid']}-{media_group_id}", "1", ttl=120)
kwargs["message_info"] = info # save trigger message info
conversations = get_conversations(message)
@@ -107,9 +110,11 @@ async def gpt_response(client: Client, message: Message, *, gpt_stream: bool = T
config = get_gpt_config(model_id)
if not config["client"]["api_key"].strip():
- return await send2tg(client, message, texts=f"⚠️**{config['friendly_name']}** 未配置API Key, 请尝试其他命令\n\n{HELP}", **kwargs)
+ await send2tg(client, message, texts=f"⚠️**{config['friendly_name']}** 未配置API Key, 请尝试其他命令\n\n{HELP}", **kwargs)
+ return {}
if not config["completions"]["model"].strip():
- return await send2tg(client, message, texts=f"⚠️**{config['friendly_name']}** 未配置模型ID, 请尝试其他命令\n\n{HELP}", **kwargs)
+ await send2tg(client, message, texts=f"⚠️**{config['friendly_name']}** 未配置模型ID, 请尝试其他命令\n\n{HELP}", **kwargs)
+ return {}
config["completions"]["messages"] = await get_conversation_contexts(client, conversations, ctx_format="openai")
msg = f"🤖**{config['friendly_name']}**: 思考中...\n👤**[{info['full_name'] or info['ctitle']}](tg://user?id={info['uid']})**: “{clean_cmd_prefix(info['text'])}”"[:TEXT_LENGTH]
@@ -121,27 +126,45 @@ async def gpt_response(client: Client, message: Message, *, gpt_stream: bool = T
config, response = await merge_tools_response(config, **kwargs)
# skip send a new request if tool_model is the same as the current model
if response and config["completions"]["model"] == GPT.TOOLS_MODEL and response.get("content"):
- texts = f"🤖**{config['friendly_name']}**:{BOT_TIPS}\n\n{response['content']}"
+ texts = f"🤖**{config['friendly_name']}**:{BOT_TIPS}\n{response['content']}"
length = await count_without_entities(texts)
if length <= TEXT_LENGTH:
await modify_progress(message=status_msg, text=texts, force_update=True, **kwargs)
+ final = {
+ "texts": response["content"],
+ "prefix": f"🤖**{config['friendly_name']}**:{BOT_TIPS}\n",
+ "model_name": config["friendly_name"],
+ "sent_messages": [status_msg],
+ }
else:
- await send2tg(client, message, texts=texts, **kwargs)
+ final = {
+ "texts": response["content"],
+ "prefix": f"🤖**{config['friendly_name']}**:{BOT_TIPS}\n",
+ "model_name": config["friendly_name"],
+ "sent_messages": await send2tg(client, message, texts=texts, **kwargs),
+ }
await modify_progress(message=status_msg, del_status=True, **kwargs)
llm_cleanup_files(config["completions"]["messages"])
- return
-
+ return final
+ final = {}
if not gpt_stream:
response = await send_to_gpt(config, **kwargs)
if content := response.get("content"):
if reasoning := response.get("reasoning"):
+ final["thoughts"] = raw_reasoning(reasoning)
content = f"{reasoning}\n{content}"
texts = f"🤖**{response['model']}**:{BOT_TIPS}\n{content}"
else:
texts = f"🤖**{response['model']}**:{BOT_TIPS}\n\n{content}"
logger.debug(texts)
- await send2tg(client, message, texts=texts, **kwargs)
+ final |= {
+ "texts": content,
+ "prefix": f"🤖**{response['model']}**:{BOT_TIPS}\n",
+ "model_name": config["friendly_name"],
+ "sent_messages": await send2tg(client, message, texts=texts, **kwargs),
+ }
await modify_progress(message=status_msg, del_status=True, **kwargs)
else:
- return await send_to_gpt_stream(client, status_msg, config, **kwargs) # type: ignore
+ final = await send_to_gpt_stream(client, status_msg, config, **kwargs) # type: ignore
llm_cleanup_files(config["completions"]["messages"])
+ return final
src/llm/response_stream.py
@@ -13,7 +13,7 @@ from pyrogram.types import Message, ReplyParameters
from config import GPT, TEXT_LENGTH
from llm.hooks import pre_hooks
-from llm.utils import BOT_TIPS, REASONING_BEGIN, REASONING_END, add_search_results_to_response, beautify_llm_response, split_reasoning
+from llm.utils import BOT_TIPS, REASONING_BEGIN, REASONING_END, add_search_results_to_response, beautify_llm_response, raw_reasoning, split_reasoning
from messages.progress import modify_progress
from messages.utils import blockquote, count_without_entities, smart_split
@@ -22,10 +22,11 @@ async def send_to_gpt_stream(client: Client, status: Message, config: dict, retr
"""Get GPT response in stream mode.
Returns:
- {"content": str, "reasoning": str, "model": str}
+ dict: {"texts": str, "thoughts": str, "prefix": str, "model_name": str, "sent_messages": list[Message]}
"""
# ruff: noqa: RUF001, RUF003
prefix = f"🤖**{config['friendly_name']}**:{BOT_TIPS}\n"
+ final = {"prefix": prefix, "model_name": config["friendly_name"], "sent_messages": [status]}
try:
pre_hooks(config["client"], config["completions"], message_info=kwargs.get("message_info"))
openai = AsyncOpenAI(**config["client"])
@@ -87,8 +88,13 @@ async def send_to_gpt_stream(client: Client, status: Message, config: dict, retr
if is_reasoning:
answers = f"{BLOCKQUOTE_EXPANDABLE_DELIM}{answers.lstrip()}"
status = await client.send_message(status.chat.id, text=prefix + answers, reply_parameters=ReplyParameters(message_id=status.id))
+ final["sent_messages"].append(status)
# all chunks are processed
all_answers += answers
+
+ all_reasoning, all_texts = split_reasoning(all_answers)
+ final |= {"thoughts": raw_reasoning(all_reasoning), "texts": all_texts}
+
all_answers = add_search_results_to_response(config.get("search_results", []), all_answers)
length = await count_without_entities(answers)
@@ -112,7 +118,7 @@ async def send_to_gpt_stream(client: Client, status: Message, config: dict, retr
await modify_progress(text=error, force_update=True, **kwargs)
if retry < GPT.MAX_RETRY:
return await send_to_gpt_stream(client, status, config, retry=retry + 1, **kwargs)
- return {}
+ return final
async def parse_error(resp: dict, retry: int, **kwargs) -> dict:
src/llm/utils.py
@@ -245,6 +245,13 @@ def split_reasoning(text: str) -> tuple[str, str]:
return reasoning.strip(), content.strip()
+def raw_reasoning(text: str) -> str:
+ """Extract raw reasoning from text."""
+ if matched := re.search(rf"{re.escape(REASONING_BEGIN)}(.*?){re.escape(REASONING_END)}", text, flags=re.DOTALL):
+ return matched.group(1)
+ return text
+
+
def shuffle_keys(keys: str | list[str]) -> str:
"""Shuffle comma speparated string."""
if isinstance(keys, str):