Commit b346a67
Changed files (2)
src/llm/gemini.py
@@ -123,10 +123,16 @@ async def gemini_stream(
*,
silent: bool = False,
append_grounding: bool = True,
+ single_thinking_msg: bool = True,
+ remove_thinking: bool = True,
**kwargs,
) -> dict:
"""Gemini stream response.
+ Args:
+ single_thinking_msg (bool, optional): Display thinking in a single message only; when the output is split, earlier thinking parts are not repeated in the new message.
+ remove_thinking (bool, optional): Strip the thinking parts from the response once generation finishes.
+
Returns:
dict: {"texts": str, "thoughts": str, "prefix": str, "model_name": str, "sent_messages": list[Message]}
"""
@@ -164,13 +170,14 @@ async def gemini_stream(
params["config"].thinking_config = None
sent_messages = []
is_reasoning = False
- is_reasoning_conversation = None # to indicate whether it is a reasoning conversation
+ is_reasoning_conversation = None # to indicate whether it is a reasoning conversation
async for chunk in await app.aio.models.generate_content_stream(**params):
- resp = parse_response(chunk.model_dump(), append_grounding=append_grounding)
+ resp = parse_response(chunk.model_dump())
answer = resp.get("texts", "")
thinking = resp.get("thinking", "")
if is_reasoning_conversation is None and thinking:
is_reasoning_conversation = True
+
if thinking and not is_reasoning: # First time receiving reasoning content
is_reasoning = True
runtime_texts += f"{BLOCKQUOTE_EXPANDABLE_DELIM}{REASONING_BEGIN}{thinking.lstrip()}"
@@ -178,7 +185,7 @@ async def gemini_stream(
runtime_texts += thinking
elif is_reasoning_conversation is True and is_reasoning: # Receiving response, close reasoning flag
is_reasoning = False
- runtime_texts = f"{runtime_texts.rstrip()}{REASONING_END}\n{BLOCKQUOTE_EXPANDABLE_END_DELIM}\n" + answer.lstrip()
+ runtime_texts = answer.lstrip() if remove_thinking else f"{runtime_texts.rstrip()}{REASONING_END}\n{BLOCKQUOTE_EXPANDABLE_END_DELIM}\n" + answer.lstrip()
else:
runtime_texts += answer
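Note: the branch above is a small state machine over (thinking, answer) chunks. A standalone sketch of the same transitions, with "<think>"/"</think>" as placeholders for REASONING_BEGIN/REASONING_END and the blockquote delimiters:

    # standalone sketch; delimiters are placeholders, not the module's real constants
    def fold(chunks, remove_thinking=True):
        out, in_reasoning, seen_reasoning = "", False, False
        for thinking, answer in chunks:
            if thinking and not in_reasoning:      # first reasoning chunk
                in_reasoning = seen_reasoning = True
                out += "<think>" + thinking.lstrip()
            elif thinking:                         # reasoning continues
                out += thinking
            elif seen_reasoning and in_reasoning:  # reasoning -> answer transition
                in_reasoning = False
                out = answer.lstrip() if remove_thinking else f"{out.rstrip()}</think>\n{answer.lstrip()}"
            else:                                  # plain answer chunk
                out += answer
        return out

    fold([("plan steps", ""), ("", "Hello")])  # -> "Hello" with remove_thinking=True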
@@ -193,14 +200,18 @@ async def gemini_stream(
parts = await smart_split(prefix + runtime_texts)
if len(parts) == 1:
continue
- await modify_progress(message=status_msg, text=blockquote(parts[0]), force_update=True) # force send the first part
- runtime_texts = parts[-1] # keep the last part
- if is_reasoning:
- runtime_texts = f"{BLOCKQUOTE_EXPANDABLE_DELIM}{REASONING_BEGIN}{runtime_texts.lstrip()}"
- if not silent:
- status_msg = await client.send_message(message.chat.id, text=prefix + runtime_texts, reply_parameters=ReplyParameters(message_id=status_mid)) # the new message
- sent_messages.append(status_msg)
- status_mid = status_msg.id
+ if is_reasoning and single_thinking_msg:
+ runtime_texts = f"{BLOCKQUOTE_EXPANDABLE_DELIM}{REASONING_BEGIN}{parts[-1].lstrip()}" # remove previous thinking
+ await modify_progress(message=status_msg, text=parts[0], force_update=True) # force send the first part
+ else:
+ await modify_progress(message=status_msg, text=blockquote(parts[0]), force_update=True) # force send the first part
+ runtime_texts = parts[-1] # keep the last part
+ if is_reasoning:
+ runtime_texts = f"{BLOCKQUOTE_EXPANDABLE_DELIM}{REASONING_BEGIN}{runtime_texts.lstrip()}"
+ if not silent:
+ status_msg = await client.send_message(message.chat.id, text=prefix + runtime_texts, reply_parameters=ReplyParameters(message_id=status_mid)) # the new message
+ sent_messages.append(status_msg)
+ status_mid = status_msg.id
# all chunks are processed
if not answers.strip() and not thoughts.strip(): # empty response
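Note: a compact sketch of the overflow branch above. With single_thinking_msg, the first part is flushed without a blockquote and only the tail keeps the thinking wrapper, so earlier thinking is not duplicated across messages; the delimiter and the blockquote marker here are stand-ins:

    # overflow sketch; "<think>" and "> " stand in for the real delimiters
    parts = ["first screen of thinking ...", "tail of thinking"]
    is_reasoning = single_thinking_msg = True
    if is_reasoning and single_thinking_msg:
        first = parts[0]                                 # flushed as-is into the status message
        runtime_texts = "<think>" + parts[-1].lstrip()   # only the tail keeps the wrapper
    else:
        first = "> " + parts[0]                          # blockquoted; a new reply message follows
        runtime_texts = parts[-1]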
@@ -216,14 +227,18 @@ async def gemini_stream(
append_grounding=append_grounding,
**kwargs,
)
-
- if await count_without_entities(prefix + thoughts + answers) <= TEXT_LENGTH - 10: # short answer in single msg
+ if append_grounding: # add grounding to the response
+ answers = add_grounding_results(answers, resp["grounding_chunks"], resp["grounding_supports"])
+ runtime_texts = add_grounding_results(runtime_texts, resp["grounding_chunks"], resp["grounding_supports"])
+ final_thoughts = "" if remove_thinking else thoughts
+ if await count_without_entities(prefix + final_thoughts + answers) <= TEXT_LENGTH - 10: # short answer in single msg
if length > GPT.COLLAPSE_LENGTH: # collapse the response if the answer is too long
- quoted = REASONING_BEGIN + thoughts.strip() + REASONING_END + "\n\n" + answers.strip() if thoughts.strip() else answers.strip()
+ quoted = REASONING_BEGIN + final_thoughts.strip() + REASONING_END + "\n\n" + answers.strip() if final_thoughts.strip() else answers.strip()
await modify_progress(message=status_msg, text=f"{prefix}{blockquote(quoted)}", force_update=True)
else:
- quoted = blockquote(REASONING_BEGIN + thoughts.strip() + REASONING_END) + "\n" if thoughts.strip() else ""
+ quoted = blockquote(REASONING_BEGIN + final_thoughts.strip() + REASONING_END) + "\n" if final_thoughts.strip() else ""
await modify_progress(message=status_msg, text=f"{prefix}{quoted}{answers}", force_update=True)
+ # total length is too long, answers are split into multiple messages
elif length > GPT.COLLAPSE_LENGTH:
await modify_progress(message=status_msg, text=prefix + blockquote(runtime_texts), force_update=True)
else:
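Note: a decision-table sketch of the finalization above, where visible_len corresponds to count_without_entities(prefix + final_thoughts + answers); the TEXT_LENGTH and COLLAPSE_LENGTH values are illustrative, not the module's real constants:

    # layout decision sketch; threshold values are illustrative
    TEXT_LENGTH, COLLAPSE_LENGTH = 4096, 1024
    def layout(visible_len: int, length: int) -> str:
        if visible_len <= TEXT_LENGTH - 10:  # fits in a single message
            return "single message, collapsed" if length > COLLAPSE_LENGTH else "single message, thoughts quoted"
        return "multi-message, collapsed" if length > COLLAPSE_LENGTH else "multi-message"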
@@ -293,9 +308,11 @@ async def gemini_nonstream(
params["config"].response_modalities = ["TEXT"]
response = await app.aio.models.generate_content(**params)
prefix = f"🤖**{model_name}**:{BOT_TIPS}\n"
- res = parse_response(response.model_dump(), append_grounding=append_grounding)
+ res = parse_response(response.model_dump())
texts = res.get("texts", "")
thoughts = res.get("thinking", "")
+ if append_grounding: # add grounding to the response
+ texts = add_grounding_results(texts, res["grounding_chunks"], res["grounding_supports"])
results |= {"prefix": prefix, "model_name": model_name, "texts": texts, "thoughts": thoughts}
media = res.get("media", [])
total = prefix + blockquote(REASONING_BEGIN + thoughts.strip() + REASONING_END) + "\n" + texts.strip() if thoughts.strip() else prefix + texts.strip()
@@ -325,13 +342,12 @@ async def gemini_nonstream(
return results
-def parse_response(data: dict, *, append_grounding: bool = True) -> dict:
+def parse_response(data: dict) -> dict:
"""Parse gemini response, includes texts, image and websearch."""
parts = glom(data, "candidates.0.content.parts", default=[]) or []
gemini_logging(parts)
grounding_chunks = glom(data, "candidates.0.grounding_metadata.grounding_chunks", default=[]) or []
- if not append_grounding:
- grounding_chunks = []
+ grounding_supports = glom(data, "candidates.0.grounding_metadata.grounding_supports", default=[]) or []
texts = ""
thinking = ""
media = []
@@ -348,13 +364,30 @@ def parse_response(data: dict, *, append_grounding: bool = True) -> dict:
save_path = Path(DOWNLOAD_DIR) / f"{rand_string()}.{ext}"
image.save(save_path)
media.append({"photo": save_path})
+ return {
+ "texts": beautify_llm_response(texts, newline_level=2),
+ "thinking": beautify_llm_response(thinking, newline_level=2),
+ "media": media,
+ "grounding_chunks": grounding_chunks,
+ "grounding_supports": grounding_supports,
+ }
+
+
+def add_grounding_results(answers: str, grounding_chunks: list[dict], grounding_supports: list[dict]) -> str:
+ index2url = {idx + 1: glom(chunk, "web.uri", default="https://www.google.com") for idx, chunk in enumerate(grounding_chunks)}
+ for support in grounding_supports:
+ indices: list[int] = support.get("grounding_chunk_indices", [])
+ indices_with_url = " ".join([f"[[{idx + 1}]]({index2url[idx + 1]})" for idx in indices])
+ if segment := glom(support, "segment.text", default=""):
+ answers = answers.replace(segment, f"{segment}{indices_with_url}", 1)
for idx, grounding in enumerate(grounding_chunks):
if idx > 9:
break
title = glom(grounding, "web.title", default="Web")
url = glom(grounding, "web.uri", default="https://www.google.com")
- texts += f"\n{number_to_emoji(idx + 1)}[{title}]({url})"
- return {"texts": beautify_llm_response(texts, newline_level=2), "thinking": beautify_llm_response(thinking, newline_level=2), "media": media}
+ if url in answers:
+ answers += f"\n{number_to_emoji(idx + 1)}[{title}]({url})"
+ return answers
def gemini_logging(contexts: list):
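Note: a worked usage example for the new add_grounding_results. The payload shapes follow the glom paths used in the function (web.uri / web.title for chunks, segment.text and grounding_chunk_indices for supports); the emoji rendering of number_to_emoji(1) is assumed:

    # usage example; payload shapes mirror the glom paths in the function
    chunks = [{"web": {"uri": "https://example.com/a", "title": "Example A"}}]
    supports = [{"grounding_chunk_indices": [0], "segment": {"text": "the sky is blue"}}]
    out = add_grounding_results("Because the sky is blue.", chunks, supports)
    # out == "Because the sky is blue[[1]](https://example.com/a).\n1️⃣[Example A](https://example.com/a)"
    # the source line is appended because its URI now appears inline (the `if url in answers` check)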
src/messages/progress.py
@@ -58,7 +58,7 @@ async def modify_progress(
if not detail_progress:
return
if len(text) > TEXT_LENGTH:
- text = text[: 2 * TEXT_LENGTH] # trim the very long texts
+ text = text[: 10 * TEXT_LENGTH] # trim the very long texts
text = (await smart_split(text))[0]
logger.trace(f"Progress: {text!r}")
await message.edit_text(text)
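Note: the pre-trim cap now allows up to ten message-lengths of text before smart_split picks the first chunk that fits; a tiny sketch of the effect (the TEXT_LENGTH value is illustrative):

    # trim sketch; TEXT_LENGTH value is illustrative
    TEXT_LENGTH = 4096
    text = "x" * (20 * TEXT_LENGTH)
    text = text[: 10 * TEXT_LENGTH]  # previously capped at 2 * TEXT_LENGTH
    # (await smart_split(text))[0] then becomes the progress text actually shown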