Commit a681218
Changed files (4)
src
src/asr/gemini_asr.py
@@ -7,7 +7,7 @@ from pathlib import Path
from glom import glom
from google import genai
-from google.genai.types import GenerateContentConfig, HttpOptions, ThinkingConfig, UploadFileConfig
+from google.genai.types import GenerateContentConfig, GoogleSearch, HttpOptions, ThinkingConfig, Tool, UploadFileConfig
from loguru import logger
from pydantic import BaseModel
from pyrogram.client import Client
@@ -27,7 +27,7 @@ async def gemini_stream_asr(
voice_format: str,
prompt: str = "请转录这段音频",
*,
- slient: bool = False,
+ silent: bool = False,
**kwargs,
) -> dict:
"""Gemini stream ASR.
@@ -35,7 +35,7 @@ async def gemini_stream_asr(
https://ai.google.dev/gemini-api/docs/audio
Args:
- slient (bool, optional): If Ture, do not update the status, return all results in the end.
+        silent (bool, optional): If True, do not update the status, return all results in the end.
"""
system_instruction = """You are a transcription assistant tasked with converting audio files into text.
@@ -68,7 +68,7 @@ Notes:
- Focus on accuracy in capturing both the timing and the spoken content.
- Maintain consistent formatting to ensure clarity and readability."""
path = Path(path)
- status = None if slient else kwargs.get("progress")
+ status = None if silent else kwargs.get("progress")
api_keys = shuffle_keys(GEMINI.API_KEYS)
for api_key in api_keys.split(","):
try:
@@ -85,9 +85,22 @@ Notes:
if GEMINI.ASR_THINKING_BUDGET is not None:
thinking_budget = min(round(float(GEMINI.ASR_THINKING_BUDGET)), GEMINI.MAX_THINKING_BUDGET)
genconfig |= {"thinking_config": ThinkingConfig(include_thoughts=False, thinking_budget=thinking_budget)}
+ if GEMINI.ASR_USE_GROUNDING:
+ genconfig |= {"tools": [Tool(google_search=GoogleSearch())]}
contents = [prompt, uploaded_audio] if prompt else [uploaded_audio]
params = {"model": GEMINI.ASR_MODEL, "contents": contents, "config": GenerateContentConfig(**genconfig)}
- res = await gemini_stream(client, message, model_name="ASR", params=params, prefix="", slient=slient, max_retry=0, gemini_api_key=api_key, **kwargs)
+ res = await gemini_stream(
+ client,
+ message,
+ model_name="ASR",
+ params=params,
+ prefix="",
+ silent=silent,
+ max_retry=0,
+ gemini_api_key=api_key,
+ append_grounding=False,
+ **kwargs,
+ )
if res.get("error"):
continue
sent_messages = res.get("sent_messages", [])
src/llm/gemini.py
@@ -34,7 +34,7 @@ HELP = f"""🌠**AI生图**
"""
-async def gemini_response(client: Client, message: Message, conversations: list[Message], modality: str = "image", **kwargs):
+async def gemini_response(client: Client, message: Message, conversations: list[Message], modality: str = "image", *, append_grounding: bool = True, **kwargs):
r"""Get Gemini response.
Args:
@@ -42,6 +42,7 @@ async def gemini_response(client: Client, message: Message, conversations: list[
message (Message): The trigger message object.
conversations (list[Message]): list of chat conversations.
modality (str): response modality
+ append_grounding (bool, optional): Whether to append grounding to the response. Defaults to True.
"""
info = parse_msg(message)
model = GEMINI.TEXT_MODEL if modality == "text" else GEMINI.IMG_MODEL
@@ -73,8 +74,8 @@ async def gemini_response(client: Client, message: Message, conversations: list[
params = {"model": model, "contents": contexts, "config": GenerateContentConfig(**genconfig)}
logger.trace(params)
if modality == "image":
- return await gemini_nonstream(client, message, model_name, params, clean_marks=True, **kwargs)
- return await gemini_stream(client, message, model_name, params, **kwargs)
+ return await gemini_nonstream(client, message, model_name, params, clean_marks=True, append_grounding=append_grounding, **kwargs)
+ return await gemini_stream(client, message, model_name, params, append_grounding=append_grounding, **kwargs)
except Exception as e:
logger.error(e)
@@ -90,6 +91,7 @@ async def gemini_stream(
last_error: str = "",
*,
silent: bool = False,
+ append_grounding: bool = True,
**kwargs,
) -> dict:
if prefix is None:
@@ -114,7 +116,7 @@ async def gemini_stream(
app = genai.Client(api_key=api_key, http_options=http_options)
sent_messages = []
async for chunk in await app.aio.models.generate_content_stream(**params):
- resp = parse_response(chunk.model_dump())
+ resp = parse_response(chunk.model_dump(), append_grounding=append_grounding)
answer = resp.get("texts", "")
runtime_texts += answer
answers += answer
@@ -136,7 +138,7 @@ async def gemini_stream(
# all chunks are processed
if not answers.strip(): # empty response
- return await gemini_stream(client, message, model_name, params, prefix=prefix, retry=retry + 1, last_error=last_error, **kwargs)
+ return await gemini_stream(client, message, model_name, params, prefix=prefix, retry=retry + 1, last_error=last_error, append_grounding=append_grounding, **kwargs)
if await count_without_entities(prefix + answers) <= TEXT_LENGTH: # short answer in single msg
if length > GPT.COLLAPSE_LENGTH: # collapse the response if the answer is too long
@@ -156,7 +158,7 @@ async def gemini_stream(
with contextlib.suppress(Exception):
await modify_progress(message=init_status_msg, text=error, force_update=True)
[await modify_progress(msg, del_status=True) for msg in sent_messages]
- return await gemini_stream(client, message, model_name, params, prefix=prefix, retry=retry + 1, last_error=error, **kwargs)
+ return await gemini_stream(client, message, model_name, params, prefix=prefix, retry=retry + 1, last_error=error, append_grounding=append_grounding, **kwargs)
return {"texts": answers, "prefix": prefix, "model_name": model_name, "sent_messages": sent_messages}
@@ -168,6 +170,7 @@ async def gemini_nonstream(
retry: int = 0,
*,
clean_marks: bool = False, # useful in image generation
+ append_grounding: bool = True,
**kwargs,
):
try:
@@ -184,7 +187,7 @@ async def gemini_nonstream(
app = genai.Client(api_key=api_key, http_options=http_options)
response = await app.aio.models.generate_content(**params)
prefix = f"🤖**{model_name}**:{BOT_TIPS}\n"
- res = parse_response(response.model_dump())
+ res = parse_response(response.model_dump(), append_grounding=append_grounding)
texts = res.get("texts", "")
media = res.get("media", [])
length = await count_without_entities(prefix + texts)
@@ -200,15 +203,17 @@ async def gemini_nonstream(
if "response" in locals():
error += f"\n{response}"
await modify_progress(text=error, force_update=True, **kwargs)
- return await gemini_nonstream(client, message, model_name, params, retry + 1, **kwargs) # type: ignore
+ return await gemini_nonstream(client, message, model_name, params, retry + 1, clean_marks=clean_marks, append_grounding=append_grounding, **kwargs) # type: ignore
-def parse_response(data: dict) -> dict:
+def parse_response(data: dict, *, append_grounding: bool = True) -> dict:
"""Parse gemini response, includes texts, image and websearch."""
logger.trace(data)
parts = glom(data, "candidates.0.content.parts", default=[]) or []
gemini_logging(parts)
grounding_chunks = glom(data, "candidates.0.grounding_metadata.grounding_chunks", default=[]) or []
+ if not append_grounding:
+ grounding_chunks = []
texts = ""
media = []
for item in parts:
src/preview/ytdlp.py
@@ -261,7 +261,7 @@ async def preview_ytdlp(
subtitles = res.get("subtitles", "")
if not subtitles:
ytdlp_transcription_engine = "gemini" if "youtube" in info["extractor"] else ytdlp_transcription_engine # use gemini to bypass censorship
- res = await asr_file(audio_path, ytdlp_transcription_engine, duration, client=client, message=message, slient=True)
+ res = await asr_file(audio_path, ytdlp_transcription_engine, duration, client=client, message=message, silent=True)
subtitles = res.get("texts", "")
if subtitles:
if len(subtitles) > TEXT_LENGTH or transcription_force_file:
src/config.py
@@ -297,3 +297,4 @@ class GEMINI: # Official Gemini
ASR_MODEL = os.getenv("GEMINI_ASR_MODEL", "gemini-2.5-flash-preview-04-17")
ASR_THINKING_BUDGET = os.getenv("GEMINI_ASR_THINKING_BUDGET", None) # 0 to disable thinking. DO NOT set this if the model is not a thinking model
ASR_CONFIG = os.getenv("GEMINI_ASR_CONFIG", "{}") # default config passed to GenerateContentConfig. Should be a json string: '{"key": "value"}'
+ ASR_USE_GROUNDING = os.getenv("GEMINI_ASR_USE_GROUNDING", "1").lower() in ["1", "y", "yes", "t", "true", "on"] # Use Grounding with Google Search