Commit 393c510

benny-dou <60535774+benny-dou@users.noreply.github.com>
2025-05-21 20:08:12
feat(gemini): support reasoning content in Gemini responses
1 parent 18327d6
Changed files (3)
src/llm/gemini.py
@@ -11,12 +11,22 @@ from google.genai.types import ContentUnionDict, GenerateContentConfig, GoogleSe
 from loguru import logger
 from PIL import Image
 from pyrogram.client import Client
+from pyrogram.parser.markdown import BLOCKQUOTE_EXPANDABLE_DELIM, BLOCKQUOTE_EXPANDABLE_END_DELIM
 from pyrogram.types import Message, ReplyParameters
 
 from config import CAPTION_LENGTH, DOWNLOAD_DIR, GEMINI, GPT, PREFIX, TEXT_LENGTH
 from llm.contexts import get_conversation_contexts
 from llm.hooks import hook_gemini_httpoptions
-from llm.utils import BOT_TIPS, beautify_llm_response, clean_cmd_prefix, clean_gemini_sourcemarks, clean_source_marks, shuffle_keys
+from llm.utils import (
+    BOT_TIPS,
+    REASONING_BEGIN,
+    REASONING_END,
+    beautify_llm_response,
+    clean_cmd_prefix,
+    clean_gemini_sourcemarks,
+    clean_source_marks,
+    shuffle_keys,
+)
 from messages.parser import parse_msg
 from messages.progress import modify_progress
 from messages.sender import send2tg
@@ -67,10 +77,10 @@ async def gemini_response(client: Client, message: Message, conversations: list[
         if tools:
             genconfig |= {"tools": tools}
         if GEMINI.PREFER_LANG and modality == "text":
-            genconfig |= {"system_instruction": f"请优先使用{GEMINI.PREFER_LANG}回复"}
+            genconfig |= {"system_instruction": f"请优先使用{GEMINI.PREFER_LANG}思考和回复"}
         if thinking_budget is not None:
             thinking_budget = min(round(float(thinking_budget)), GEMINI.MAX_THINKING_BUDGET)
-            genconfig |= {"thinking_config": ThinkingConfig(thinking_budget=thinking_budget)}
+            genconfig |= {"thinking_config": ThinkingConfig(include_thoughts=True, thinking_budget=thinking_budget)}
         params = {"model": model, "contents": contexts, "config": GenerateContentConfig(**genconfig)}
         logger.trace(params)
         if modality == "image":
@@ -97,6 +107,7 @@ async def gemini_stream(
     if prefix is None:
         prefix = f"🤖**{model_name}**:{BOT_TIPS}\n"
     answers = ""  # all model responses
+    thoughts = ""  # all model thoughts
     runtime_texts = ""  # for a single telegram message
     init_status_msg = None if silent else kwargs.get("progress")
     status_msg = init_status_msg
@@ -115,10 +126,26 @@ async def gemini_stream(
         http_options = hook_gemini_httpoptions(http_options, message)
         app = genai.Client(api_key=api_key, http_options=http_options)
         sent_messages = []
+        is_reasoning = False
+        is_reasoning_conversation = None  # to indicate whether it is a reasoning conversation
         async for chunk in await app.aio.models.generate_content_stream(**params):
             resp = parse_response(chunk.model_dump(), append_grounding=append_grounding)
             answer = resp.get("texts", "")
-            runtime_texts += answer
+            thinking = resp.get("thinking", "")
+            if is_reasoning_conversation is None and thinking:
+                is_reasoning_conversation = True
+            if thinking and not is_reasoning:  # First time receiving reasoning content
+                is_reasoning = True
+                runtime_texts += f"{BLOCKQUOTE_EXPANDABLE_DELIM}{REASONING_BEGIN}{thinking.lstrip()}"
+            elif thinking and is_reasoning:  # Receiving reasoning content and is reasoning
+                runtime_texts += thinking
+            elif is_reasoning_conversation is True and is_reasoning:  # Receiving response, close reasoning flag
+                is_reasoning = False
+                runtime_texts = f"{runtime_texts.rstrip()}{REASONING_END}\n{BLOCKQUOTE_EXPANDABLE_END_DELIM}\n" + answer.lstrip()
+            else:
+                runtime_texts += answer
+
+            thoughts += thinking
             answers += answer
             runtime_texts = beautify_llm_response(runtime_texts)
             length = await count_without_entities(prefix + runtime_texts)
@@ -131,13 +158,15 @@ async def gemini_stream(
                     continue
                 await modify_progress(message=status_msg, text=blockquote(parts[0]), force_update=True)  # force send the first part
                 runtime_texts = parts[-1]  # keep the last part
+                if is_reasoning:
+                    runtime_texts = f"{BLOCKQUOTE_EXPANDABLE_DELIM}{REASONING_BEGIN}{runtime_texts.lstrip()}"
                 if not silent:
                     status_msg = await client.send_message(message.chat.id, text=prefix + runtime_texts, reply_parameters=ReplyParameters(message_id=status_mid))  # the new message
                     sent_messages.append(status_msg)
                     status_mid = status_msg.id
 
         # all chunks are processed
-        if not answers.strip():  # empty response
+        if not answers.strip() and not thoughts.strip():  # empty response
             return await gemini_stream(
                 client,
                 message,
@@ -151,11 +180,13 @@ async def gemini_stream(
                 **kwargs,
             )
 
-        if await count_without_entities(prefix + answers) <= TEXT_LENGTH:  # short answer in single msg
+        if await count_without_entities(prefix + thoughts + answers) <= TEXT_LENGTH - 10:  # short answer in single msg
             if length > GPT.COLLAPSE_LENGTH:  # collapse the response if the answer is too long
-                await modify_progress(message=status_msg, text=f"{prefix}{blockquote(runtime_texts)}", force_update=True)
+                quoted = REASONING_BEGIN + thoughts.strip() + REASONING_END + "\n\n" + answers.strip() if thoughts.strip() else answers.strip()
+                await modify_progress(message=status_msg, text=f"{prefix}{blockquote(quoted)}", force_update=True)
             else:
-                await modify_progress(message=status_msg, text=f"{prefix}{runtime_texts}", force_update=True)
+                quoted = blockquote(REASONING_BEGIN + thoughts.strip() + REASONING_END) + "\n" if thoughts.strip() else ""
+                await modify_progress(message=status_msg, text=f"{prefix}{quoted}{answers}", force_update=True)
         elif length > GPT.COLLAPSE_LENGTH:
             await modify_progress(message=status_msg, text=prefix + blockquote(runtime_texts), force_update=True)
         else:
@@ -181,7 +212,7 @@ async def gemini_stream(
             append_grounding=append_grounding,
             **kwargs,
         )
-    return {"texts": answers, "prefix": prefix, "model_name": model_name, "sent_messages": sent_messages}
+    return {"texts": answers, "thoughts": thoughts, "prefix": prefix, "model_name": model_name, "sent_messages": sent_messages}
 
 
 async def gemini_nonstream(
@@ -211,11 +242,22 @@ async def gemini_nonstream(
         prefix = f"🤖**{model_name}**:{BOT_TIPS}\n"
         res = parse_response(response.model_dump(), append_grounding=append_grounding)
         texts = res.get("texts", "")
+        thoughts = res.get("thoughts", "")
         media = res.get("media", [])
-        length = await count_without_entities(prefix + texts)
+        total = prefix + blockquote(REASONING_BEGIN + thoughts.strip() + REASONING_END) + "\n" + texts.strip() if thoughts.strip() else prefix + texts.strip()
+        length = await count_without_entities(total)
         single_msg_length = CAPTION_LENGTH if media else TEXT_LENGTH
-        texts = f"{prefix}{blockquote(texts)}" if GPT.COLLAPSE_LENGTH < length <= single_msg_length else f"{prefix}{texts}"
-        await send2tg(client, message, caption_above=True, texts=texts, media=media, **kwargs)
+        if length <= GPT.COLLAPSE_LENGTH:
+            await send2tg(client, message, caption_above=True, texts=total, media=media, **kwargs)
+        elif GPT.COLLAPSE_LENGTH < length <= single_msg_length:
+            final = prefix + blockquote(REASONING_BEGIN + thoughts.strip() + REASONING_END + "\n\n" + texts.strip()) if thoughts.strip() else prefix + blockquote(texts.strip())
+            await send2tg(client, message, caption_above=True, texts=final, media=media, **kwargs)
+        else:  # multiple messages
+            for idx, txt in enumerate(await smart_split(total, single_msg_length)):
+                if idx == 0:
+                    await send2tg(client, message, caption_above=True, texts=txt, media=media, **kwargs)
+                else:
+                    await send2tg(client, message, texts=txt, **kwargs)
         await modify_progress(del_status=True, **kwargs)
     except Exception as e:
         logger.error(e)
@@ -237,10 +279,14 @@ def parse_response(data: dict, *, append_grounding: bool = True) -> dict:
     if not append_grounding:
         grounding_chunks = []
     texts = ""
+    thinking = ""
     media = []
     for item in parts:
         if item.get("text") is not None:
-            texts += item["text"]
+            if item.get("thought"):
+                thinking += item["text"]
+            else:
+                texts += item["text"]
         if item.get("inline_data") is not None:
             image = Image.open(BytesIO(item["inline_data"]["data"]))
             mime = item["inline_data"]["mime_type"]
@@ -254,7 +300,7 @@ def parse_response(data: dict, *, append_grounding: bool = True) -> dict:
         title = glom(grounding, "web.title", default="Web")
         url = glom(grounding, "web.uri", default="https://www.google.com")
         texts += f"\n{number_to_emoji(idx + 1)}[{title}]({url})"
-    return {"texts": beautify_llm_response(texts, newline_level=2), "media": media}
+    return {"texts": beautify_llm_response(texts, newline_level=2), "thinking": beautify_llm_response(thinking, newline_level=2), "media": media}
 
 
 def gemini_logging(contexts: list):
src/llm/response_stream.py
@@ -33,7 +33,7 @@ async def send_to_gpt_stream(client: Client, status: Message, config: dict, retr
         answers = prefix
         all_answers = ""
         is_reasoning = False
-        reasoning_in_response = None
+        is_reasoning_conversation = None  # 用于指示是否是推理对话
         gen = await openai.chat.completions.create(**config["completions"], stream=True)
         async for chunk in gen:
             resp = chunk.model_dump()
@@ -46,14 +46,14 @@ async def send_to_gpt_stream(client: Client, status: Message, config: dict, retr
                 return {}
             answer = glom(resp, "choices.0.delta.content", default="") or ""
             reasoning_content = glom(resp, "choices.0.delta.reasoning_content", default="") or ""
-            if reasoning_in_response is None and reasoning_content:
-                reasoning_in_response = True
+            if is_reasoning_conversation is None and reasoning_content:
+                is_reasoning_conversation = True
             if reasoning_content and not is_reasoning:  # 首次收到推理内容
                 is_reasoning = True
                 answers += f"{BLOCKQUOTE_EXPANDABLE_DELIM}{REASONING_BEGIN}{reasoning_content.lstrip()}"
             elif reasoning_content and is_reasoning:  # 收到推理内容且正在思考
                 answers += reasoning_content
-            elif reasoning_in_response is True and is_reasoning:  # 收到回答, 关闭推理标志
+            elif is_reasoning_conversation is True and is_reasoning:  # 收到回答, 关闭推理标志
                 is_reasoning = False
                 answers = f"{answers.rstrip()}{REASONING_END}\n{BLOCKQUOTE_EXPANDABLE_END_DELIM}\n" + answer.lstrip()
             else:
src/messages/utils.py
@@ -132,6 +132,7 @@ async def smart_split(text: str, chars_per_string: int = TEXT_LENGTH, mode: Pars
 
 def blockquote(s: str) -> str:
     """Block quote texts."""
+    s = s.replace(BLOCKQUOTE_EXPANDABLE_DELIM, "").replace(BLOCKQUOTE_EXPANDABLE_END_DELIM, "")
     return BLOCKQUOTE_EXPANDABLE_DELIM + s + "\n" + BLOCKQUOTE_EXPANDABLE_END_DELIM