Commit 5c00853

benny-dou <60535774+benny-dou@users.noreply.github.com>
2025-05-24 07:40:22
refactor(gpt): return details from `gpt_response`
1 parent 3208742
src/llm/gemini.py
@@ -54,7 +54,7 @@ async def gemini_response(
     disable_thinking: bool = False,
     include_thoughts: bool = True,
     **kwargs,
-):
+) -> dict:
     r"""Get Gemini response.
 
     Args:
@@ -100,6 +100,7 @@ async def gemini_response(
         return await gemini_stream(client, message, model_name, params, append_grounding=append_grounding, **kwargs)
     except Exception as e:
         logger.error(e)
+    return {}
 
 
 async def gemini_stream(
@@ -116,6 +117,12 @@ async def gemini_stream(
     append_grounding: bool = True,
     **kwargs,
 ) -> dict:
+    """Gemini stream response.
+
+    Returns:
+        dict: {"texts": str, "thoughts": str, "prefix": str, "model_name": str, "sent_messages": list[Message]}
+    """
+    # ruff: noqa: RUF001, RUF003
     if prefix is None:
         prefix = f"🤖**{model_name}**:{BOT_TIPS}\n"
     answers = ""  # all model responses
@@ -237,7 +244,13 @@ async def gemini_nonstream(
     clean_marks: bool = False,  # useful in image generation
     append_grounding: bool = True,
     **kwargs,
-):
+) -> dict:
+    """Gemini non-stream response.
+
+    Returns:
+        dict: {"texts": str, "thoughts": str, "prefix": str, "model_name": str, "sent_messages": list[Message]}
+    """
+    results = {}
     try:
         if clean_marks:
             clean_gemini_sourcemarks(params["contents"])
@@ -245,7 +258,7 @@ async def gemini_nonstream(
         if kwargs.get("gemini_api_keys"):
             api_keys = [x.strip() for x in kwargs["gemini_api_keys"].split(",") if x.strip()]
         if retry > len(api_keys) - 1:
-            return None
+            return {}
         api_key = kwargs.get("gemini_api_key", api_keys[retry])
         http_options = HttpOptions(base_url=GEMINI.BASR_URL, async_client_args={"proxy": GEMINI.PROXY})
         http_options = hook_gemini_httpoptions(http_options, message)
@@ -255,21 +268,22 @@ async def gemini_nonstream(
         res = parse_response(response.model_dump(), append_grounding=append_grounding)
         texts = res.get("texts", "")
         thoughts = res.get("thoughts", "")
+        results |= {"prefix": prefix, "model_name": model_name, "texts": texts, "thoughts": thoughts}
         media = res.get("media", [])
         total = prefix + blockquote(REASONING_BEGIN + thoughts.strip() + REASONING_END) + "\n" + texts.strip() if thoughts.strip() else prefix + texts.strip()
         length = await count_without_entities(total)
         single_msg_length = CAPTION_LENGTH if media else TEXT_LENGTH
         if length <= GPT.COLLAPSE_LENGTH:
-            await send2tg(client, message, caption_above=True, texts=total, media=media, **kwargs)
+            results["sent_message"] = await send2tg(client, message, caption_above=True, texts=total, media=media, **kwargs)
         elif GPT.COLLAPSE_LENGTH < length <= single_msg_length:
             final = prefix + blockquote(REASONING_BEGIN + thoughts.strip() + REASONING_END + "\n\n" + texts.strip()) if thoughts.strip() else prefix + blockquote(texts.strip())
-            await send2tg(client, message, caption_above=True, texts=final, media=media, **kwargs)
+            results["sent_message"] = await send2tg(client, message, caption_above=True, texts=final, media=media, **kwargs)
         else:  # multiple messages
             for idx, txt in await smart_split(total, single_msg_length):
                 if idx == 0:
-                    await send2tg(client, message, caption_above=True, texts=txt, media=media, **kwargs)
+                    results["sent_message"] = await send2tg(client, message, caption_above=True, texts=txt, media=media, **kwargs)
                 else:
-                    await send2tg(client, message, texts=txt, **kwargs)
+                    results["sent_message"] = await send2tg(client, message, texts=txt, **kwargs)
         await modify_progress(del_status=True, **kwargs)
     except Exception as e:
         logger.error(e)
@@ -280,6 +294,7 @@ async def gemini_nonstream(
             error += f"\n{response}"
         await modify_progress(text=error, force_update=True, **kwargs)
         return await gemini_nonstream(client, message, model_name, params, retry + 1, clean_marks=clean_marks, append_grounding=append_grounding, **kwargs)  # type: ignore
+    return results
 
 
 def parse_response(data: dict, *, append_grounding: bool = True) -> dict:
src/llm/gpt.py
@@ -13,7 +13,7 @@ from llm.models import get_context_type, get_gpt_config, get_model_id
 from llm.response import send_to_gpt
 from llm.response_stream import send_to_gpt_stream
 from llm.tools import merge_tools_response
-from llm.utils import BOT_TIPS, clean_cmd_prefix, image_emoji, llm_cleanup_files
+from llm.utils import BOT_TIPS, clean_cmd_prefix, image_emoji, llm_cleanup_files, raw_reasoning
 from messages.parser import parse_msg
 from messages.progress import modify_progress
 from messages.sender import send2tg
@@ -69,25 +69,28 @@ def is_gpt_conversation(message: Message) -> bool:
     return startswith_prefix(reply_info["text"], prefix=[f"🤖{x}".lower() for x in model_names])
 
 
-async def gpt_response(client: Client, message: Message, *, gpt_stream: bool = True, **kwargs):
+async def gpt_response(client: Client, message: Message, *, gpt_stream: bool = True, **kwargs) -> dict:
     """Get GPT response from Various API.
 
     Args:
         client (Client): The Pyrogram client.
         message (Message): The trigger message object.
         gpt_stream (bool): Whether to use stream mode.
+
+    Returns:
+        dict: {"texts": str, "thoughts": str, "prefix": str, "model_name": str, "sent_messages": list[Message]}
     """
     # ruff: noqa: RET502, RET503
     info = parse_msg(message)
     # send docs if message == "/ai", without reply
     if info["mtype"] == "text" and equal_prefix(info["text"], prefix=[PREFIX.GPT, "/gpt", "/gemini", "/ds", "/qwen", "/grok", "/doubao"]) and not message.reply_to_message:
         await send2tg(client, message, texts=HELP, **kwargs)
-        return
+        return {}
     if info["mtype"] == "text" and equal_prefix(info["text"], prefix=[PREFIX.GENIMG]) and not message.reply_to_message:
         await send2tg(client, message, texts=AIGC_HELP, **kwargs)
-        return
+        return {}
     if not is_gpt_conversation(message):
-        return
+        return {}
     reply_text = ""
     if message.reply_to_message:
         reply_info = parse_msg(message.reply_to_message, silent=True)
@@ -96,7 +99,7 @@ async def gpt_response(client: Client, message: Message, *, gpt_stream: bool = T
     # cache media_group message, only process once
     if media_group_id := message.media_group_id:
         if cache.get(f"gpt-{info['cid']}-{media_group_id}"):
-            return
+            return {}
         cache.set(f"gpt-{info['cid']}-{media_group_id}", "1", ttl=120)
     kwargs["message_info"] = info  # save trigger message info
     conversations = get_conversations(message)
@@ -107,9 +110,11 @@ async def gpt_response(client: Client, message: Message, *, gpt_stream: bool = T
 
     config = get_gpt_config(model_id)
     if not config["client"]["api_key"].strip():
-        return await send2tg(client, message, texts=f"⚠️**{config['friendly_name']}** 未配置API Key, 请尝试其他命令\n\n{HELP}", **kwargs)
+        await send2tg(client, message, texts=f"⚠️**{config['friendly_name']}** 未配置API Key, 请尝试其他命令\n\n{HELP}", **kwargs)
+        return {}
     if not config["completions"]["model"].strip():
-        return await send2tg(client, message, texts=f"⚠️**{config['friendly_name']}** 未配置模型ID, 请尝试其他命令\n\n{HELP}", **kwargs)
+        await send2tg(client, message, texts=f"⚠️**{config['friendly_name']}** 未配置模型ID, 请尝试其他命令\n\n{HELP}", **kwargs)
+        return {}
 
     config["completions"]["messages"] = await get_conversation_contexts(client, conversations, ctx_format="openai")
     msg = f"🤖**{config['friendly_name']}**: 思考中...\n👤**[{info['full_name'] or info['ctitle']}](tg://user?id={info['uid']})**: “{clean_cmd_prefix(info['text'])}”"[:TEXT_LENGTH]
@@ -121,27 +126,45 @@ async def gpt_response(client: Client, message: Message, *, gpt_stream: bool = T
     config, response = await merge_tools_response(config, **kwargs)
     # skip send a new request if tool_model is the same as the current model
     if response and config["completions"]["model"] == GPT.TOOLS_MODEL and response.get("content"):
-        texts = f"🤖**{config['friendly_name']}**:{BOT_TIPS}\n\n{response['content']}"
+        texts = f"🤖**{config['friendly_name']}**:{BOT_TIPS}\n{response['content']}"
         length = await count_without_entities(texts)
         if length <= TEXT_LENGTH:
             await modify_progress(message=status_msg, text=texts, force_update=True, **kwargs)
+            final = {
+                "texts": response["content"],
+                "prefix": f"🤖**{config['friendly_name']}**:{BOT_TIPS}\n",
+                "model_name": config["friendly_name"],
+                "sent_messages": [status_msg],
+            }
         else:
-            await send2tg(client, message, texts=texts, **kwargs)
+            final = {
+                "texts": response["content"],
+                "prefix": f"🤖**{config['friendly_name']}**:{BOT_TIPS}\n",
+                "model_name": config["friendly_name"],
+                "sent_messages": await send2tg(client, message, texts=texts, **kwargs),
+            }
             await modify_progress(message=status_msg, del_status=True, **kwargs)
         llm_cleanup_files(config["completions"]["messages"])
-        return
-
+        return final
+    final = {}
     if not gpt_stream:
         response = await send_to_gpt(config, **kwargs)
         if content := response.get("content"):
             if reasoning := response.get("reasoning"):
+                final["thoughts"] = raw_reasoning(reasoning)
                 content = f"{reasoning}\n{content}"
                 texts = f"🤖**{response['model']}**:{BOT_TIPS}\n{content}"
             else:
                 texts = f"🤖**{response['model']}**:{BOT_TIPS}\n\n{content}"
             logger.debug(texts)
-            await send2tg(client, message, texts=texts, **kwargs)
+            final |= {
+                "texts": content,
+                "prefix": f"🤖**{response['model']}**:{BOT_TIPS}\n",
+                "model_name": config["friendly_name"],
+                "sent_messages": await send2tg(client, message, texts=texts, **kwargs),
+            }
             await modify_progress(message=status_msg, del_status=True, **kwargs)
     else:
-        return await send_to_gpt_stream(client, status_msg, config, **kwargs)  # type: ignore
+        final = await send_to_gpt_stream(client, status_msg, config, **kwargs)  # type: ignore
     llm_cleanup_files(config["completions"]["messages"])
+    return final
src/llm/response_stream.py
@@ -13,7 +13,7 @@ from pyrogram.types import Message, ReplyParameters
 
 from config import GPT, TEXT_LENGTH
 from llm.hooks import pre_hooks
-from llm.utils import BOT_TIPS, REASONING_BEGIN, REASONING_END, add_search_results_to_response, beautify_llm_response, split_reasoning
+from llm.utils import BOT_TIPS, REASONING_BEGIN, REASONING_END, add_search_results_to_response, beautify_llm_response, raw_reasoning, split_reasoning
 from messages.progress import modify_progress
 from messages.utils import blockquote, count_without_entities, smart_split
 
@@ -22,10 +22,11 @@ async def send_to_gpt_stream(client: Client, status: Message, config: dict, retr
     """Get GPT response in stream mode.
 
     Returns:
-        {"content": str, "reasoning": str, "model": str}
+        dict: {"texts": str, "thoughts": str, "prefix": str, "model_name": str, "sent_messages": list[Message]}
     """
     # ruff: noqa: RUF001, RUF003
     prefix = f"🤖**{config['friendly_name']}**:{BOT_TIPS}\n"
+    final = {"prefix": prefix, "model_name": config["friendly_name"], "sent_messages": [status]}
     try:
         pre_hooks(config["client"], config["completions"], message_info=kwargs.get("message_info"))
         openai = AsyncOpenAI(**config["client"])
@@ -87,8 +88,13 @@ async def send_to_gpt_stream(client: Client, status: Message, config: dict, retr
                 if is_reasoning:
                     answers = f"{BLOCKQUOTE_EXPANDABLE_DELIM}{answers.lstrip()}"
                 status = await client.send_message(status.chat.id, text=prefix + answers, reply_parameters=ReplyParameters(message_id=status.id))
+                final["sent_messages"].append(status)
         # all chunks are processed
         all_answers += answers
+
+        all_reasoning, all_texts = split_reasoning(answers)
+        final |= {"thoughts": raw_reasoning(all_reasoning), "texts": all_texts}
+
         all_answers = add_search_results_to_response(config.get("search_results", []), all_answers)
         length = await count_without_entities(answers)
 
@@ -112,7 +118,7 @@ async def send_to_gpt_stream(client: Client, status: Message, config: dict, retr
         await modify_progress(text=error, force_update=True, **kwargs)
         if retry < GPT.MAX_RETRY:
             return await send_to_gpt_stream(client, status, config, retry=retry + 1, **kwargs)
-    return {}
+    return final
 
 
 async def parse_error(resp: dict, retry: int, **kwargs) -> dict:
src/llm/utils.py
@@ -245,6 +245,13 @@ def split_reasoning(text: str) -> tuple[str, str]:
     return reasoning.strip(), content.strip()
 
 
+def raw_reasoning(text: str) -> str:
+    """Extract raw reasoning from text."""
+    if matched := re.search(rf"{REASONING_BEGIN}(.*?){REASONING_END}", text, flags=re.DOTALL):
+        return matched.group(1)
+    return text
+
+
 def shuffle_keys(keys: str | list[str]) -> str:
     """Shuffle comma speparated string."""
     if isinstance(keys, str):