Commit 1f8922a
Changed files (2)
src
src/llm/gemini.py
@@ -1,5 +1,6 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
+import asyncio
import contextlib
import json
from io import BytesIO
@@ -31,6 +32,7 @@ from messages.parser import parse_msg
from messages.progress import modify_progress
from messages.sender import send2tg
from messages.utils import blockquote, count_without_entities, smart_split
+from networking import flatten_rediercts
from utils import number_to_emoji, rand_string
HELP = f"""🌠**AI生图**
@@ -227,8 +229,8 @@ async def gemini_stream(
**kwargs,
)
if append_grounding: # add grounding to the response
- answers = add_grounding_results(answers, resp["grounding_chunks"], resp["grounding_supports"])
- runtime_texts = add_grounding_results(runtime_texts, resp["grounding_chunks"], resp["grounding_supports"])
+ answers = await add_grounding_results(answers, resp["grounding_chunks"], resp["grounding_supports"])
+ runtime_texts = await add_grounding_results(runtime_texts, resp["grounding_chunks"], resp["grounding_supports"])
final_thoughts = "" if remove_thinking else thoughts
if await count_without_entities(prefix + final_thoughts + answers) <= TEXT_LENGTH - 10: # short answer in single msg
if length > GPT.COLLAPSE_LENGTH: # collapse the response if the answer is too long
@@ -311,7 +313,7 @@ async def gemini_nonstream(
texts = res.get("texts", "")
thoughts = res.get("thoughts", "")
if append_grounding: # add grounding to the response
- texts = add_grounding_results(texts, res["grounding_chunks"], res["grounding_supports"])
+ texts = await add_grounding_results(texts, res["grounding_chunks"], res["grounding_supports"])
results |= {"prefix": prefix, "model_name": model_name, "texts": texts, "thoughts": thoughts}
media = res.get("media", [])
total = prefix + blockquote(REASONING_BEGIN + thoughts.strip() + REASONING_END) + "\n" + texts.strip() if thoughts.strip() else prefix + texts.strip()
@@ -372,8 +374,11 @@ def parse_response(data: dict) -> dict:
}
-def add_grounding_results(answers: str, grounding_chunks: list[dict], grounding_supports: list[dict]) -> str:
- index2url = {idx + 1: glom(chunk, "web.uri", default="https://www.google.com") for idx, chunk in enumerate(grounding_chunks)}
+async def add_grounding_results(answers: str, grounding_chunks: list[dict], grounding_supports: list[dict]) -> str:
+ urls = [glom(chunk, "web.uri", default="https://www.google.com") for chunk in grounding_chunks]
+ tasks = [flatten_rediercts(url) for url in urls]
+ flatten_urls = await asyncio.gather(*tasks)
+ index2url = {idx + 1: url for idx, url in enumerate(flatten_urls)}
for support in grounding_supports:
indices: list[int] = support.get("grounding_chunk_indices", [])
indices_with_url = " ".join([f"[[{idx + 1}]]({index2url[idx + 1]})" for idx in indices])
@@ -383,7 +388,7 @@ def add_grounding_results(answers: str, grounding_chunks: list[dict], grounding_
if idx > 9:
break
title = glom(grounding, "web.title", default="Web")
- url = glom(grounding, "web.uri", default="https://www.google.com")
+ url = flatten_urls[idx]
if url in answers:
answers += f"\n{number_to_emoji(idx + 1)}[{title}]({url})"
return answers
src/networking.py
@@ -436,6 +436,10 @@ async def flatten_rediercts(texts: str | None = None, pattern: str | None = None
# shorturl.at
if matched := re.search(r"(https?://)?shorturl\.at/([^.。,,?&/\s]+)", texts):
url = matched.group(0)
+ # vertexaisearch.cloud.google.com
+ if matched := re.search(r"(https?://)?vertexaisearch\.cloud\.google\.com/([0-9a-zA-Z\-_=+/]+)", texts):
+ url = matched.group(0)
+ proxy = PROXY.GOOGLE_SEARCH
# custom pattern
if pattern and (matched := re.search(pattern, texts)):