Commit 6629619

benny-dou <60535774+benny-dou@users.noreply.github.com>
2025-04-28 05:41:09
chore(aigc): show caption above the media
1 parent 426942b
src/llm/aigc.py
@@ -8,18 +8,18 @@ from pathlib import Path
 
 from glom import glom
 from google import genai
-from google.genai.types import ContentUnionDict, GenerateContentConfig, HttpOptions, Part
+from google.genai.types import ContentUnionDict, GenerateContentConfig, GoogleSearch, HttpOptions, Part, Tool
 from loguru import logger
 from PIL import Image
 from pyrogram.client import Client
 from pyrogram.types import Message
 
 from config import AIGC, DOWNLOAD_DIR, PREFIX
-from llm.utils import BOT_TIPS, beautify_llm_response, clean_source_marks
+from llm.utils import BOT_TIPS, beautify_llm_response, clean_prefix, clean_source_marks
 from messages.parser import parse_msg
 from messages.progress import modify_progress
 from messages.sender import send2tg
-from utils import rand_string
+from utils import number_to_emoji, rand_string
 
 HELP = f"""🌠**AIGC**
 `{PREFIX.GENIMG}` 后接提示词即可生成
@@ -56,6 +56,7 @@ async def aigc(client: Client, message: Message, contexts: list[dict], modality:
     api_keys = [x.strip() for x in AIGC.IMG_API_KEY.split(",") if x.strip()]
     random.choice(api_keys)
     response_modalities = ["TEXT", "IMAGE"] if modality == "image" else ["TEXT"]
+    tools = [Tool(google_search=GoogleSearch())] if modality == "text" else None
     res = {}
     try:
         app = genai.Client(api_key=random.choice(api_keys), http_options=HttpOptions(base_url=AIGC.IMG_BASR_URL, async_client_args={"proxy": AIGC.IMG_PROXY}))
@@ -65,7 +66,7 @@ async def aigc(client: Client, message: Message, contexts: list[dict], modality:
             await send2tg(client, message, texts=f"当前提示词过长: {num_token} Tokens\n提示词Token不得超过: {AIGC.IMG_MAX_PROMPT_TOKEN}", **kwargs)
             return
 
-        msg = f"🌠**{AIGC.IMG_MODEL_NAME}**: 正在生成..."
+        msg = f"🌠**{AIGC.IMG_MODEL_NAME}**: 思考中...\n{clean_prefix(info['text'])}"
         status_msg = (await send2tg(client, message, texts=msg, **kwargs))[0]
         kwargs["progress"] = status_msg
         gemini_contexts = [openai_context_to_gemini(context) for context in contexts]
@@ -73,9 +74,12 @@ async def aigc(client: Client, message: Message, contexts: list[dict], modality:
         response = await app.aio.models.generate_content(
             model=AIGC.IMG_MODEL,
             contents=gemini_contexts,
-            config=GenerateContentConfig(response_modalities=response_modalities),
+            config=GenerateContentConfig(
+                response_modalities=response_modalities,
+                tools=tools,  # type: ignore
+            ),
         )
-        res = parse_response(glom(response.model_dump(), "candidates.0.content.parts"), model_name=AIGC.IMG_MODEL_NAME)
+        res = parse_response(glom(response.model_dump(), "candidates.0"), model_name=AIGC.IMG_MODEL_NAME)
     except Exception as e:
         logger.error(e)
         error = str(e)
@@ -83,17 +87,21 @@ async def aigc(client: Client, message: Message, contexts: list[dict], modality:
             error += f"\n{res}"
         if "response" in locals():
             error += f"\n{response}"
-        await modify_progress(text=error, force_update=True, **kwargs)
-    return await send2tg(client, message, **res, **kwargs)
+        return await modify_progress(text=error, force_update=True, **kwargs)
+    await send2tg(client, message, caption_above=True, **res, **kwargs)
+    await modify_progress(del_status=True, **kwargs)
 
 
-def parse_response(data: list[dict], model_name: str) -> dict:
-    gemini_logging(data)
-    texts = f"🌠**{model_name}**: ({BOT_TIPS})\n"
+def parse_response(data: dict, model_name: str) -> dict:
+    parts = glom(data, "content.parts", default=[]) or []
+    gemini_logging(parts)
+    grounding_chunks = glom(data, "grounding_metadata.grounding_chunks", default=[]) or []
+    texts = ""
+    prefix = f"🌠**{model_name}**: ({BOT_TIPS})\n"
     media = []
-    for item in data:
+    for item in parts:
         if item.get("text") is not None:
-            texts += f"{item['text'].strip()}\n"
+            texts += item["text"]
         if item.get("inline_data") is not None:
             image = Image.open(BytesIO(item["inline_data"]["data"]))
             mime = item["inline_data"]["mime_type"]
@@ -101,7 +109,11 @@ def parse_response(data: list[dict], model_name: str) -> dict:
             save_path = Path(DOWNLOAD_DIR) / f"{rand_string()}.{ext}"
             image.save(save_path)
             media.append({"photo": save_path})
-    return {"texts": beautify_llm_response(texts, newline_level=2), "media": media}
+    for idx, grounding in enumerate(grounding_chunks):
+        title = glom(grounding, "web.title", default="Web")
+        url = glom(grounding, "web.uri", default="https://www.google.com")
+        texts += f"\n{number_to_emoji(idx + 1)}[{title}]({url})"
+    return {"texts": prefix + beautify_llm_response(texts, newline_level=2), "media": media}
 
 
 def openai_context_to_gemini(context: dict) -> ContentUnionDict:
src/llm/contexts.py
@@ -10,8 +10,8 @@ from loguru import logger
 from pyrogram.client import Client
 from pyrogram.types import Message
 
-from config import FILE_SERVER, GPT, PREFIX
-from llm.utils import BOT_TIPS
+from config import FILE_SERVER, GPT
+from llm.utils import BOT_TIPS, clean_prefix
 from messages.parser import parse_msg
 
 if TYPE_CHECKING:
@@ -69,8 +69,7 @@ async def single_context(client: Client, message: Message) -> dict:
     def clean_text(text: str) -> str:
         if not text:
             return ""
-        for prefix in [PREFIX.GPT, PREFIX.GENIMG, "/gpt", "/gemini", "/ds", "/qwen", "/grok", "/doubao"]:
-            text = text.removeprefix(prefix).strip()
+        text = clean_prefix(text)
         # remove bot tips
         text = re.sub(rf"(.*?){BOT_TIPS}\)", "", text, flags=re.DOTALL).strip()
         # remove reasoning
src/llm/gpt.py
@@ -15,7 +15,7 @@ from llm.models import get_context_type, get_gpt_config, parse_force_model
 from llm.response import send_to_gpt
 from llm.response_stream import send_to_gpt_stream
 from llm.tools import merge_tools_response
-from llm.utils import BOT_TIPS, image_emoji, llm_cleanup_files
+from llm.utils import BOT_TIPS, clean_prefix, image_emoji, llm_cleanup_files
 from messages.parser import parse_msg
 from messages.progress import modify_progress
 from messages.sender import send2tg
@@ -110,7 +110,7 @@ async def gpt_response(client: Client, message: Message, *, gpt_stream: bool = G
     if not config["client"]["api_key"]:
         logger.error(f"⚠️**{config['friendly_name']}** 未配置API Key")
         return await send2tg(client, message, texts=f"⚠️**{config['friendly_name']}** 未配置API Key, 请尝试其他命令\n\n{HELP}", **kwargs)
-    msg = f"🤖**{config['friendly_name']}**: 思考中..."
+    msg = f"🤖**{config['friendly_name']}**: 思考中...\n{clean_prefix(info['text'])}"
     status_msg = (await send2tg(client, message, texts=msg, **kwargs))[0]
     kwargs["progress"] = status_msg
     if context_type.get("error"):
src/llm/utils.py
@@ -6,7 +6,7 @@ from pathlib import Path
 import tiktoken
 from loguru import logger
 
-from config import DOWNLOAD_DIR, GPT
+from config import DOWNLOAD_DIR, GPT, PREFIX
 from utils import number_to_emoji, remove_consecutive_newlines, remove_dash, remove_pound
 
 BOT_TIPS = "回复以继续"
@@ -164,3 +164,9 @@ def add_search_results_to_response(search_results: list[dict], response: str) ->
 def image_emoji(capability: bool) -> str:  # noqa: FBT001
     """Get image capability emoji."""
     return "(🏞)" if capability else ""
+
+
+def clean_prefix(text: str) -> str:
+    for prefix in [PREFIX.GPT, PREFIX.GENIMG, "/gpt", "/gemini", "/ds", "/qwen", "/grok", "/doubao"]:
+        text = text.removeprefix(prefix).lstrip()
+    return text
src/messages/preprocess.py
@@ -113,7 +113,7 @@ def preprocess_media(media: list[dict]) -> list[dict]:
     return done_audios
 
 
-async def warp_media_group(media: list[dict], caption: str = "") -> list:
+async def warp_media_group(media: list[dict], caption: str = "", *, caption_above: bool = False) -> list:
     """Warp media files into a list of media group objects.
 
     item in media:
@@ -148,10 +148,10 @@ async def warp_media_group(media: list[dict], caption: str = "") -> list:
         media = media[:10]
     # add caption to the first item
     if media[0].get("photo"):
-        group.append(InputMediaPhoto(media[0]["photo"], caption=caption))
+        group.append(InputMediaPhoto(media[0]["photo"], caption=caption, show_caption_above_media=caption_above))
     elif media[0].get("video"):
         media[0]["media"] = media[0].pop("video")
-        group.append(InputMediaVideo(caption=caption, **media[0]))
+        group.append(InputMediaVideo(caption=caption, show_caption_above_media=caption_above, **media[0]))
     elif media[0].get("audio"):
         media[0]["media"] = media[0].pop("audio")
         group.append(InputMediaAudio(caption=caption, **media[0]))
src/messages/sender.py
@@ -28,6 +28,7 @@ async def send2tg(
     comments: list[str] | None = None,  # append after texts
     send_from_user: str | None = None,
     cooldown: float = 0,
+    caption_above: bool = False,
     **kwargs,
 ) -> list[Message | None]:
     """Send unlimited number of texts and media to Telegram.
@@ -49,6 +50,7 @@ async def send2tg(
         comments (list[str], optional): The comments to append after texts.
         send_from_user (str, optional): The user name to prefix the texts.
         cooldown (float, optional): The interval between each media message. Defaults to 0.
+        caption_above (bool, optional): Show caption above the message media.
         kwargs: Other keyword arguments. In this function, we use:
             show_progress (bool, optional): Show a progress message on Telegram. Defaults to True.
             detail_progress (bool, optional): Show detailed progress (Only if show_progress is set to True). Defaults to False.
@@ -82,13 +84,13 @@ async def send2tg(
     if len(media) == 0:
         return await send_texts(client, target_chat, reply_parameters, texts=texts, cooldown=cooldown)
     if len(media) == 1:
-        return await send_single_media(client, target_chat, reply_parameters, media=media[0], texts=texts, cooldown=cooldown, **kwargs)
+        return await send_single_media(client, target_chat, reply_parameters, media=media[0], texts=texts, cooldown=cooldown, caption_above=caption_above, **kwargs)
 
     caption = (await smart_split(texts, CAPTION_LENGTH))[0]
     remaining_texts = texts.removeprefix(caption)
     caption = warp_comments(caption)
     if 1 < len(media) <= 10:
-        group = await warp_media_group(media, caption=caption)
+        group = await warp_media_group(media, caption=caption, caption_above=caption_above)
         sent_messages.extend(await send_media_group(client, target_chat, group, reply_parameters))
     else:  # media > 10
         media_chunks = [media[i : i + 10] for i in range(0, len(media), 10)]
@@ -102,7 +104,7 @@ async def send2tg(
                 group = await warp_media_group(batch)
                 sent_messages.extend(await send_media_group(client, target_chat, group, ReplyParameters()))
             else:  # last chunk:  media <= 10, add caption here
-                sent_messages.extend(await send2tg(client, message, target_chat, reply_msg_id=-1, texts=caption, media=batch, cooldown=cooldown, **kwargs))
+                sent_messages.extend(await send2tg(client, message, target_chat, reply_msg_id=-1, texts=caption, media=batch, caption_above=caption_above, cooldown=cooldown, **kwargs))
             await asyncio.sleep(cooldown)
     if remaining_texts:
         sent_messages.extend(await send_texts(client, target_chat, ReplyParameters(), texts=remaining_texts, cooldown=cooldown))
@@ -159,6 +161,7 @@ async def send_single_media(
     media: dict,
     texts: str = "",
     cooldown: float = 0,
+    caption_above: bool = False,
     **kwargs,
 ) -> list[Message | None]:
     sent_messages: list[Message | None] = []
@@ -169,12 +172,13 @@ async def send_single_media(
     message = None
     try:
         if photo := media.get("photo"):
-            message = await client.send_photo(chat_id=target_chat, photo=photo, caption=caption, reply_parameters=reply_parameters)
+            message = await client.send_photo(chat_id=target_chat, photo=photo, caption=caption, show_caption_above_media=caption_above, reply_parameters=reply_parameters)
         elif video := media.get("video"):
             message = await client.send_video(
                 chat_id=target_chat,
                 reply_parameters=reply_parameters,
                 caption=caption,
+                show_caption_above_media=caption_above,
                 progress=telegram_uploading,
                 progress_args=(kwargs.get("progress", False), video, kwargs.get("detail_progress", True)),
                 **media,