Commit a681218

benny-dou <60535774+benny-dou@users.noreply.github.com>
2025-05-13 09:54:44
feat(gemini): add `append_grounding` option
1 parent 410f047
Changed files (4)
src/asr/gemini_asr.py
@@ -7,7 +7,7 @@ from pathlib import Path
 
 from glom import glom
 from google import genai
-from google.genai.types import GenerateContentConfig, HttpOptions, ThinkingConfig, UploadFileConfig
+from google.genai.types import GenerateContentConfig, GoogleSearch, HttpOptions, ThinkingConfig, Tool, UploadFileConfig
 from loguru import logger
 from pydantic import BaseModel
 from pyrogram.client import Client
@@ -27,7 +27,7 @@ async def gemini_stream_asr(
     voice_format: str,
     prompt: str = "请转录这段音频",
     *,
-    slient: bool = False,
+    silent: bool = False,
     **kwargs,
 ) -> dict:
     """Gemini stream ASR.
@@ -35,7 +35,7 @@ async def gemini_stream_asr(
     https://ai.google.dev/gemini-api/docs/audio
 
     Args:
-        slient (bool, optional): If Ture, do not update the status, return all results in the end.
+        silent (bool, optional): If True, do not update the status; return all results at the end.
     """
     system_instruction = """You are a transcription assistant tasked with converting audio files into text.
 
@@ -68,7 +68,7 @@ Notes:
 - Focus on accuracy in capturing both the timing and the spoken content.
 - Maintain consistent formatting to ensure clarity and readability."""
     path = Path(path)
-    status = None if slient else kwargs.get("progress")
+    status = None if silent else kwargs.get("progress")
     api_keys = shuffle_keys(GEMINI.API_KEYS)
     for api_key in api_keys.split(","):
         try:
@@ -85,9 +85,22 @@ Notes:
             if GEMINI.ASR_THINKING_BUDGET is not None:
                 thinking_budget = min(round(float(GEMINI.ASR_THINKING_BUDGET)), GEMINI.MAX_THINKING_BUDGET)
                 genconfig |= {"thinking_config": ThinkingConfig(include_thoughts=False, thinking_budget=thinking_budget)}
+            if GEMINI.ASR_USE_GROUNDING:
+                genconfig |= {"tools": [Tool(google_search=GoogleSearch())]}
             contents = [prompt, uploaded_audio] if prompt else [uploaded_audio]
             params = {"model": GEMINI.ASR_MODEL, "contents": contents, "config": GenerateContentConfig(**genconfig)}
-            res = await gemini_stream(client, message, model_name="ASR", params=params, prefix="", slient=slient, max_retry=0, gemini_api_key=api_key, **kwargs)
+            res = await gemini_stream(
+                client,
+                message,
+                model_name="ASR",
+                params=params,
+                prefix="",
+                silent=silent,
+                max_retry=0,
+                gemini_api_key=api_key,
+                append_grounding=False,
+                **kwargs,
+            )
             if res.get("error"):
                 continue
             sent_messages = res.get("sent_messages", [])
src/llm/gemini.py
@@ -34,7 +34,7 @@ HELP = f"""🌠**AI生图**
 """
 
 
-async def gemini_response(client: Client, message: Message, conversations: list[Message], modality: str = "image", **kwargs):
+async def gemini_response(client: Client, message: Message, conversations: list[Message], modality: str = "image", *, append_grounding: bool = True, **kwargs):
     r"""Get Gemini response.
 
     Args:
@@ -42,6 +42,7 @@ async def gemini_response(client: Client, message: Message, conversations: list[
         message (Message): The trigger message object.
         conversations (list[Message]): list of chat conversations.
         modality (str): response modality
+        append_grounding (bool, optional): Whether to append grounding to the response. Defaults to True.
     """
     info = parse_msg(message)
     model = GEMINI.TEXT_MODEL if modality == "text" else GEMINI.IMG_MODEL
@@ -73,8 +74,8 @@ async def gemini_response(client: Client, message: Message, conversations: list[
         params = {"model": model, "contents": contexts, "config": GenerateContentConfig(**genconfig)}
         logger.trace(params)
         if modality == "image":
-            return await gemini_nonstream(client, message, model_name, params, clean_marks=True, **kwargs)
-        return await gemini_stream(client, message, model_name, params, **kwargs)
+            return await gemini_nonstream(client, message, model_name, params, clean_marks=True, append_grounding=append_grounding, **kwargs)
+        return await gemini_stream(client, message, model_name, params, append_grounding=append_grounding, **kwargs)
     except Exception as e:
         logger.error(e)
 
@@ -90,6 +91,7 @@ async def gemini_stream(
     last_error: str = "",
     *,
     silent: bool = False,
+    append_grounding: bool = True,
     **kwargs,
 ) -> dict:
     if prefix is None:
@@ -114,7 +116,7 @@ async def gemini_stream(
         app = genai.Client(api_key=api_key, http_options=http_options)
         sent_messages = []
         async for chunk in await app.aio.models.generate_content_stream(**params):
-            resp = parse_response(chunk.model_dump())
+            resp = parse_response(chunk.model_dump(), append_grounding=append_grounding)
             answer = resp.get("texts", "")
             runtime_texts += answer
             answers += answer
@@ -136,7 +138,7 @@ async def gemini_stream(
 
         # all chunks are processed
         if not answers.strip():  # empty response
-            return await gemini_stream(client, message, model_name, params, prefix=prefix, retry=retry + 1, last_error=last_error, **kwargs)
+            return await gemini_stream(client, message, model_name, params, prefix=prefix, retry=retry + 1, last_error=last_error, append_grounding=append_grounding, **kwargs)
 
         if await count_without_entities(prefix + answers) <= TEXT_LENGTH:  # short answer in single msg
             if length > GPT.COLLAPSE_LENGTH:  # collapse the response if the answer is too long
@@ -156,7 +158,7 @@ async def gemini_stream(
         with contextlib.suppress(Exception):
             await modify_progress(message=init_status_msg, text=error, force_update=True)
             [await modify_progress(msg, del_status=True) for msg in sent_messages]
-        return await gemini_stream(client, message, model_name, params, prefix=prefix, retry=retry + 1, last_error=error, **kwargs)
+        return await gemini_stream(client, message, model_name, params, prefix=prefix, retry=retry + 1, last_error=error, append_grounding=append_grounding, **kwargs)
     return {"texts": answers, "prefix": prefix, "model_name": model_name, "sent_messages": sent_messages}
 
 
@@ -168,6 +170,7 @@ async def gemini_nonstream(
     retry: int = 0,
     *,
     clean_marks: bool = False,  # useful in image generation
+    append_grounding: bool = True,
     **kwargs,
 ):
     try:
@@ -184,7 +187,7 @@ async def gemini_nonstream(
         app = genai.Client(api_key=api_key, http_options=http_options)
         response = await app.aio.models.generate_content(**params)
         prefix = f"🤖**{model_name}**:{BOT_TIPS}\n"
-        res = parse_response(response.model_dump())
+        res = parse_response(response.model_dump(), append_grounding=append_grounding)
         texts = res.get("texts", "")
         media = res.get("media", [])
         length = await count_without_entities(prefix + texts)
@@ -200,15 +203,17 @@ async def gemini_nonstream(
         if "response" in locals():
             error += f"\n{response}"
         await modify_progress(text=error, force_update=True, **kwargs)
-        return await gemini_nonstream(client, message, model_name, params, retry + 1, **kwargs)  # type: ignore
+        return await gemini_nonstream(client, message, model_name, params, retry + 1, clean_marks=clean_marks, append_grounding=append_grounding, **kwargs)  # type: ignore
 
 
-def parse_response(data: dict) -> dict:
+def parse_response(data: dict, *, append_grounding: bool = True) -> dict:
     """Parse gemini response, includes texts, image and websearch."""
     logger.trace(data)
     parts = glom(data, "candidates.0.content.parts", default=[]) or []
     gemini_logging(parts)
     grounding_chunks = glom(data, "candidates.0.grounding_metadata.grounding_chunks", default=[]) or []
+    if not append_grounding:
+        grounding_chunks = []
     texts = ""
     media = []
     for item in parts:
src/preview/ytdlp.py
@@ -261,7 +261,7 @@ async def preview_ytdlp(
         subtitles = res.get("subtitles", "")
         if not subtitles:
             ytdlp_transcription_engine = "gemini" if "youtube" in info["extractor"] else ytdlp_transcription_engine  # use gemini to bypass censorship
-            res = await asr_file(audio_path, ytdlp_transcription_engine, duration, client=client, message=message, slient=True)
+            res = await asr_file(audio_path, ytdlp_transcription_engine, duration, client=client, message=message, silent=True)
             subtitles = res.get("texts", "")
         if subtitles:
             if len(subtitles) > TEXT_LENGTH or transcription_force_file:
src/config.py
@@ -297,3 +297,4 @@ class GEMINI:  # Official Gemini
     ASR_MODEL = os.getenv("GEMINI_ASR_MODEL", "gemini-2.5-flash-preview-04-17")
     ASR_THINKING_BUDGET = os.getenv("GEMINI_ASR_THINKING_BUDGET", None)  # 0 to disable thinking. DO NOT set this if the model is not a thinking model
     ASR_CONFIG = os.getenv("GEMINI_ASR_CONFIG", "{}")  # default config passed to GenerateContentConfig. Should be a json string: '{"key": "value"}'
+    ASR_USE_GROUNDING = os.getenv("GEMINI_ASR_USE_GROUNDING", "1").lower() in ["1", "y", "yes", "t", "true", "on"]  # Use Grounding with Google Search