Commit b346a67

benny-dou <60535774+benny-dou@users.noreply.github.com>
2025-06-17 13:40:04
style(gemini): better response parsing and grounding support
1 parent 4421d8f
Changed files (2)
src/llm/gemini.py
@@ -123,10 +123,16 @@ async def gemini_stream(
     *,
     silent: bool = False,
     append_grounding: bool = True,
+    single_thinking_msg: bool = True,
+    remove_thinking: bool = True,
     **kwargs,
 ) -> dict:
     """Gemini stream response.
 
+    Args:
+        single_thinking_msg (bool, optional): Only use one message for displaying thinking.
+        remove_thinking (bool, optional): Remove thinking parts once finished.
+
     Returns:
         dict: {"texts": str, "thoughts": str, "prefix": str, "model_name": str, "sent_messages": list[Message]}
     """
@@ -164,13 +170,14 @@ async def gemini_stream(
             params["config"].thinking_config = None
         sent_messages = []
         is_reasoning = False
-        is_reasoning_conversation = None  # to  indicate whether it is a reasoning conversation
+        is_reasoning_conversation = None  # to indicate whether it is a reasoning conversation
         async for chunk in await app.aio.models.generate_content_stream(**params):
-            resp = parse_response(chunk.model_dump(), append_grounding=append_grounding)
+            resp = parse_response(chunk.model_dump())
             answer = resp.get("texts", "")
             thinking = resp.get("thinking", "")
             if is_reasoning_conversation is None and thinking:
                 is_reasoning_conversation = True
+
             if thinking and not is_reasoning:  # First time receiving reasoning content
                 is_reasoning = True
                 runtime_texts += f"{BLOCKQUOTE_EXPANDABLE_DELIM}{REASONING_BEGIN}{thinking.lstrip()}"
@@ -178,7 +185,7 @@ async def gemini_stream(
                 runtime_texts += thinking
             elif is_reasoning_conversation is True and is_reasoning:  # Receiving response, close reasoning flag
                 is_reasoning = False
-                runtime_texts = f"{runtime_texts.rstrip()}{REASONING_END}\n{BLOCKQUOTE_EXPANDABLE_END_DELIM}\n" + answer.lstrip()
+                runtime_texts = answer.lstrip() if remove_thinking else f"{runtime_texts.rstrip()}{REASONING_END}\n{BLOCKQUOTE_EXPANDABLE_END_DELIM}\n" + answer.lstrip()
             else:
                 runtime_texts += answer
 
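Note: the hunk above is a small state machine — the stream toggles between reasoning and answer modes, and the new remove_thinking flag drops the accumulated reasoning text at the transition instead of closing it with delimiters. A minimal sketch of that folding logic (simplified: the delimiters are stand-ins for the module's constants, and the real code interleaves this with Telegram message updates):

    REASONING_BEGIN, REASONING_END = "<think>", "</think>"  # stand-ins

    def fold_stream(chunks: list[tuple[str, str]], remove_thinking: bool = True) -> str:
        out, in_reasoning = "", False
        for thinking, answer in chunks:  # each chunk carries one or the other
            if thinking and not in_reasoning:  # first reasoning chunk
                in_reasoning = True
                out += REASONING_BEGIN + thinking.lstrip()
            elif thinking:  # reasoning continues
                out += thinking
            elif in_reasoning:  # answer begins, close (or drop) reasoning
                in_reasoning = False
                out = answer.lstrip() if remove_thinking else out.rstrip() + REASONING_END + "\n" + answer.lstrip()
            else:  # plain answer chunk
                out += answer
        return out

    # fold_stream([("Let me think...", ""), ("", "The answer is 42.")])
    # -> "The answer is 42." with remove_thinking=True
    # -> "<think>Let me think...</think>\nThe answer is 42." otherwise
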
@@ -193,14 +200,18 @@ async def gemini_stream(
                 parts = await smart_split(prefix + runtime_texts)
                 if len(parts) == 1:
                     continue
-                await modify_progress(message=status_msg, text=blockquote(parts[0]), force_update=True)  # force send the first part
-                runtime_texts = parts[-1]  # keep the last part
-                if is_reasoning:
-                    runtime_texts = f"{BLOCKQUOTE_EXPANDABLE_DELIM}{REASONING_BEGIN}{runtime_texts.lstrip()}"
-                if not silent:
-                    status_msg = await client.send_message(message.chat.id, text=prefix + runtime_texts, reply_parameters=ReplyParameters(message_id=status_mid))  # the new message
-                    sent_messages.append(status_msg)
-                    status_mid = status_msg.id
+                if is_reasoning and single_thinking_msg:
+                    runtime_texts = f"{BLOCKQUOTE_EXPANDABLE_DELIM}{REASONING_BEGIN}{parts[-1].lstrip()}"  # remove previous thinking
+                    await modify_progress(message=status_msg, text=parts[0], force_update=True)  # force send the first part
+                else:
+                    await modify_progress(message=status_msg, text=blockquote(parts[0]), force_update=True)  # force send the first part
+                    runtime_texts = parts[-1]  # keep the last part
+                    if is_reasoning:
+                        runtime_texts = f"{BLOCKQUOTE_EXPANDABLE_DELIM}{REASONING_BEGIN}{runtime_texts.lstrip()}"
+                    if not silent:
+                        status_msg = await client.send_message(message.chat.id, text=prefix + runtime_texts, reply_parameters=ReplyParameters(message_id=status_mid))  # the new message
+                        sent_messages.append(status_msg)
+                        status_mid = status_msg.id
 
         # all chunks are processed
         if not answers.strip() and not thoughts.strip():  # empty response
@@ -216,14 +227,18 @@ async def gemini_stream(
                 append_grounding=append_grounding,
                 **kwargs,
             )
-
-        if await count_without_entities(prefix + thoughts + answers) <= TEXT_LENGTH - 10:  # short answer in single msg
+        if append_grounding:  # add grounding to the response
+            answers = add_grounding_results(answers, resp["grounding_chunks"], resp["grounding_supports"])
+            runtime_texts = add_grounding_results(runtime_texts, resp["grounding_chunks"], resp["grounding_supports"])
+        final_thoughts = "" if remove_thinking else thoughts
+        if await count_without_entities(prefix + final_thoughts + answers) <= TEXT_LENGTH - 10:  # short answer in single msg
             if length > GPT.COLLAPSE_LENGTH:  # collapse the response if the answer is too long
-                quoted = REASONING_BEGIN + thoughts.strip() + REASONING_END + "\n\n" + answers.strip() if thoughts.strip() else answers.strip()
+                quoted = REASONING_BEGIN + final_thoughts.strip() + REASONING_END + "\n\n" + answers.strip() if final_thoughts.strip() else answers.strip()
                 await modify_progress(message=status_msg, text=f"{prefix}{blockquote(quoted)}", force_update=True)
             else:
-                quoted = blockquote(REASONING_BEGIN + thoughts.strip() + REASONING_END) + "\n" if thoughts.strip() else ""
+                quoted = blockquote(REASONING_BEGIN + final_thoughts.strip() + REASONING_END) + "\n" if final_thoughts.strip() else ""
                 await modify_progress(message=status_msg, text=f"{prefix}{quoted}{answers}", force_update=True)
+        # total length is too long, so answers are split into multiple messages
         elif length > GPT.COLLAPSE_LENGTH:
             await modify_progress(message=status_msg, text=prefix + blockquote(runtime_texts), force_update=True)
         else:
@@ -293,9 +308,11 @@ async def gemini_nonstream(
             params["config"].response_modalities = ["TEXT"]
         response = await app.aio.models.generate_content(**params)
         prefix = f"🤖**{model_name}**:{BOT_TIPS}\n"
-        res = parse_response(response.model_dump(), append_grounding=append_grounding)
+        res = parse_response(response.model_dump())
         texts = res.get("texts", "")
         thoughts = res.get("thoughts", "")
+        if append_grounding:  # add grounding to the response
+            texts = add_grounding_results(texts, res["grounding_chunks"], res["grounding_supports"])
         results |= {"prefix": prefix, "model_name": model_name, "texts": texts, "thoughts": thoughts}
         media = res.get("media", [])
         total = prefix + blockquote(REASONING_BEGIN + thoughts.strip() + REASONING_END) + "\n" + texts.strip() if thoughts.strip() else prefix + texts.strip()
@@ -325,13 +342,12 @@ async def gemini_nonstream(
     return results
 
 
-def parse_response(data: dict, *, append_grounding: bool = True) -> dict:
+def parse_response(data: dict) -> dict:
     """Parse gemini response, includes texts, image and websearch."""
     parts = glom(data, "candidates.0.content.parts", default=[]) or []
     gemini_logging(parts)
     grounding_chunks = glom(data, "candidates.0.grounding_metadata.grounding_chunks", default=[]) or []
-    if not append_grounding:
-        grounding_chunks = []
+    grounding_supports = glom(data, "candidates.0.grounding_metadata.grounding_supports", default=[]) or []
     texts = ""
     thinking = ""
     media = []
@@ -348,13 +364,30 @@ def parse_response(data: dict, *, append_grounding: bool = True) -> dict:
             save_path = Path(DOWNLOAD_DIR) / f"{rand_string()}.{ext}"
             image.save(save_path)
             media.append({"photo": save_path})
+    return {
+        "texts": beautify_llm_response(texts, newline_level=2),
+        "thinking": beautify_llm_response(thinking, newline_level=2),
+        "media": media,
+        "grounding_chunks": grounding_chunks,
+        "grounding_supports": grounding_supports,
+    }
+
+
+def add_grounding_results(answers: str, grounding_chunks: list[dict], grounding_supports: list[dict]) -> str:
+    index2url = {idx + 1: glom(chunk, "web.uri", default="https://www.google.com") for idx, chunk in enumerate(grounding_chunks)}
+    for support in grounding_supports:
+        indices: list[int] = support.get("grounding_chunk_indices", [])
+        indices_with_url = " ".join([f"[[{idx + 1}]]({index2url[idx + 1]})" for idx in indices])
+        if segment := glom(support, "segment.text", default=""):
+            answers = answers.replace(segment, f"{segment}{indices_with_url}", 1)
     for idx, grounding in enumerate(grounding_chunks):
         if idx > 9:
             break
         title = glom(grounding, "web.title", default="Web")
         url = glom(grounding, "web.uri", default="https://www.google.com")
-        texts += f"\n{number_to_emoji(idx + 1)}[{title}]({url})"
-    return {"texts": beautify_llm_response(texts, newline_level=2), "thinking": beautify_llm_response(thinking, newline_level=2), "media": media}
+        if url in answers:
+            answers += f"\n{number_to_emoji(idx + 1)}[{title}]({url})"
+    return answers
 
 
 def gemini_logging(contexts: list):
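
Note: a toy run of the new add_grounding_results helper. The chunk and support dicts follow the glom paths used above; the values are invented, and number_to_emoji is assumed to render 1 as 1️⃣:

    answers = "Python was created by Guido van Rossum."
    grounding_chunks = [{"web": {"title": "python.org", "uri": "https://www.python.org/about"}}]
    grounding_supports = [{"segment": {"text": "created by Guido van Rossum"}, "grounding_chunk_indices": [0]}]

    result = add_grounding_results(answers, grounding_chunks, grounding_supports)
    # The cited segment gains an inline numbered link, and since its URL now
    # appears in the text, the numbered source line is appended as well:
    # Python was created by Guido van Rossum[[1]](https://www.python.org/about).
    # 1️⃣[python.org](https://www.python.org/about)
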
src/messages/progress.py
@@ -58,7 +58,7 @@ async def modify_progress(
         if not detail_progress:
             return
         if len(text) > TEXT_LENGTH:
-            text = text[: 2 * TEXT_LENGTH]  # trim the very long texts
+            text = text[: 10 * TEXT_LENGTH]  # trim the very long texts
             text = (await smart_split(text))[0]
         logger.trace(f"Progress: {text!r}")
         await message.edit_text(text)
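
Note: a sketch of the trimming pattern this last hunk relaxes. TEXT_LENGTH is assumed to mirror Telegram's 4096-character message cap, and split_on_boundary is a simplified stand-in for the project's smart_split:

    TEXT_LENGTH = 4096  # assumption: Telegram's message length cap

    def split_on_boundary(text: str) -> list[str]:
        # Stand-in for smart_split: cut at the last newline before the cap.
        cut = text.rfind("\n", 0, TEXT_LENGTH)
        return [text[:cut], text[cut:]] if cut != -1 else [text[:TEXT_LENGTH], text[TEXT_LENGTH:]]

    def trim_for_progress(text: str) -> str:
        if len(text) > TEXT_LENGTH:
            # Hard-trim first so the splitter stays cheap. Moving from 2x to
            # 10x gives smart_split more raw text to work with — presumably
            # useful when markdown entities inflate the raw length relative
            # to the displayed length (rationale assumed, not stated in the
            # commit).
            text = text[: 10 * TEXT_LENGTH]
            text = split_on_boundary(text)[0]
        return text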