Commit 6629619
Changed files (6)
src
src/llm/aigc.py
@@ -8,18 +8,18 @@ from pathlib import Path
from glom import glom
from google import genai
-from google.genai.types import ContentUnionDict, GenerateContentConfig, HttpOptions, Part
+from google.genai.types import ContentUnionDict, GenerateContentConfig, GoogleSearch, HttpOptions, Part, Tool
from loguru import logger
from PIL import Image
from pyrogram.client import Client
from pyrogram.types import Message
from config import AIGC, DOWNLOAD_DIR, PREFIX
-from llm.utils import BOT_TIPS, beautify_llm_response, clean_source_marks
+from llm.utils import BOT_TIPS, beautify_llm_response, clean_prefix, clean_source_marks
from messages.parser import parse_msg
from messages.progress import modify_progress
from messages.sender import send2tg
-from utils import rand_string
+from utils import number_to_emoji, rand_string
HELP = f"""🌠**AIGC**
`{PREFIX.GENIMG}` 后接提示词即可生成
@@ -56,6 +56,7 @@ async def aigc(client: Client, message: Message, contexts: list[dict], modality:
api_keys = [x.strip() for x in AIGC.IMG_API_KEY.split(",") if x.strip()]
random.choice(api_keys)
response_modalities = ["TEXT", "IMAGE"] if modality == "image" else ["TEXT"]
+ tools = [Tool(google_search=GoogleSearch())] if modality == "text" else None
res = {}
try:
app = genai.Client(api_key=random.choice(api_keys), http_options=HttpOptions(base_url=AIGC.IMG_BASR_URL, async_client_args={"proxy": AIGC.IMG_PROXY}))
@@ -65,7 +66,7 @@ async def aigc(client: Client, message: Message, contexts: list[dict], modality:
await send2tg(client, message, texts=f"当前提示词过长: {num_token} Tokens\n提示词Token不得超过: {AIGC.IMG_MAX_PROMPT_TOKEN}", **kwargs)
return
- msg = f"🌠**{AIGC.IMG_MODEL_NAME}**: 正在生成..."
+ msg = f"🌠**{AIGC.IMG_MODEL_NAME}**: 思考中...\n{clean_prefix(info['text'])}"
status_msg = (await send2tg(client, message, texts=msg, **kwargs))[0]
kwargs["progress"] = status_msg
gemini_contexts = [openai_context_to_gemini(context) for context in contexts]
@@ -73,9 +74,12 @@ async def aigc(client: Client, message: Message, contexts: list[dict], modality:
response = await app.aio.models.generate_content(
model=AIGC.IMG_MODEL,
contents=gemini_contexts,
- config=GenerateContentConfig(response_modalities=response_modalities),
+ config=GenerateContentConfig(
+ response_modalities=response_modalities,
+ tools=tools, # type: ignore
+ ),
)
- res = parse_response(glom(response.model_dump(), "candidates.0.content.parts"), model_name=AIGC.IMG_MODEL_NAME)
+ res = parse_response(glom(response.model_dump(), "candidates.0"), model_name=AIGC.IMG_MODEL_NAME)
except Exception as e:
logger.error(e)
error = str(e)
@@ -83,17 +87,21 @@ async def aigc(client: Client, message: Message, contexts: list[dict], modality:
error += f"\n{res}"
if "response" in locals():
error += f"\n{response}"
- await modify_progress(text=error, force_update=True, **kwargs)
- return await send2tg(client, message, **res, **kwargs)
+ return await modify_progress(text=error, force_update=True, **kwargs)
+ await send2tg(client, message, caption_above=True, **res, **kwargs)
+ await modify_progress(del_status=True, **kwargs)
-def parse_response(data: list[dict], model_name: str) -> dict:
- gemini_logging(data)
- texts = f"🌠**{model_name}**: ({BOT_TIPS})\n"
+def parse_response(data: dict, model_name: str) -> dict:
+ parts = glom(data, "content.parts", default=[]) or []
+ gemini_logging(parts)
+ grounding_chunks = glom(data, "grounding_metadata.grounding_chunks", default=[]) or []
+ texts = ""
+ prefix = f"🌠**{model_name}**: ({BOT_TIPS})\n"
media = []
- for item in data:
+ for item in parts:
if item.get("text") is not None:
- texts += f"{item['text'].strip()}\n"
+ texts += item["text"]
if item.get("inline_data") is not None:
image = Image.open(BytesIO(item["inline_data"]["data"]))
mime = item["inline_data"]["mime_type"]
@@ -101,7 +109,11 @@ def parse_response(data: list[dict], model_name: str) -> dict:
save_path = Path(DOWNLOAD_DIR) / f"{rand_string()}.{ext}"
image.save(save_path)
media.append({"photo": save_path})
- return {"texts": beautify_llm_response(texts, newline_level=2), "media": media}
+ for idx, grounding in enumerate(grounding_chunks):
+ title = glom(grounding, "web.title", default="Web")
+ url = glom(grounding, "web.uri", default="https://www.google.com")
+ texts += f"\n{number_to_emoji(idx + 1)}[{title}]({url})"
+ return {"texts": prefix + beautify_llm_response(texts, newline_level=2), "media": media}
def openai_context_to_gemini(context: dict) -> ContentUnionDict:
src/llm/contexts.py
@@ -10,8 +10,8 @@ from loguru import logger
from pyrogram.client import Client
from pyrogram.types import Message
-from config import FILE_SERVER, GPT, PREFIX
-from llm.utils import BOT_TIPS
+from config import FILE_SERVER, GPT
+from llm.utils import BOT_TIPS, clean_prefix
from messages.parser import parse_msg
if TYPE_CHECKING:
@@ -69,8 +69,7 @@ async def single_context(client: Client, message: Message) -> dict:
def clean_text(text: str) -> str:
if not text:
return ""
- for prefix in [PREFIX.GPT, PREFIX.GENIMG, "/gpt", "/gemini", "/ds", "/qwen", "/grok", "/doubao"]:
- text = text.removeprefix(prefix).strip()
+ text = clean_prefix(text)
# remove bot tips
text = re.sub(rf"(.*?){BOT_TIPS}\)", "", text, flags=re.DOTALL).strip()
# remove reasoning
src/llm/gpt.py
@@ -15,7 +15,7 @@ from llm.models import get_context_type, get_gpt_config, parse_force_model
from llm.response import send_to_gpt
from llm.response_stream import send_to_gpt_stream
from llm.tools import merge_tools_response
-from llm.utils import BOT_TIPS, image_emoji, llm_cleanup_files
+from llm.utils import BOT_TIPS, clean_prefix, image_emoji, llm_cleanup_files
from messages.parser import parse_msg
from messages.progress import modify_progress
from messages.sender import send2tg
@@ -110,7 +110,7 @@ async def gpt_response(client: Client, message: Message, *, gpt_stream: bool = G
if not config["client"]["api_key"]:
logger.error(f"⚠️**{config['friendly_name']}** 未配置API Key")
return await send2tg(client, message, texts=f"⚠️**{config['friendly_name']}** 未配置API Key, 请尝试其他命令\n\n{HELP}", **kwargs)
- msg = f"🤖**{config['friendly_name']}**: 思考中..."
+ msg = f"🤖**{config['friendly_name']}**: 思考中...\n{clean_prefix(info['text'])}"
status_msg = (await send2tg(client, message, texts=msg, **kwargs))[0]
kwargs["progress"] = status_msg
if context_type.get("error"):
src/llm/utils.py
@@ -6,7 +6,7 @@ from pathlib import Path
import tiktoken
from loguru import logger
-from config import DOWNLOAD_DIR, GPT
+from config import DOWNLOAD_DIR, GPT, PREFIX
from utils import number_to_emoji, remove_consecutive_newlines, remove_dash, remove_pound
BOT_TIPS = "回复以继续"
@@ -164,3 +164,9 @@ def add_search_results_to_response(search_results: list[dict], response: str) ->
def image_emoji(capability: bool) -> str: # noqa: FBT001
"""Get image capability emoji."""
return "(🏞)" if capability else ""
+
+
+def clean_prefix(text: str) -> str:
+ for prefix in [PREFIX.GPT, PREFIX.GENIMG, "/gpt", "/gemini", "/ds", "/qwen", "/grok", "/doubao"]:
+ text = text.removeprefix(prefix).lstrip()
+ return text
src/messages/preprocess.py
@@ -113,7 +113,7 @@ def preprocess_media(media: list[dict]) -> list[dict]:
return done_audios
-async def warp_media_group(media: list[dict], caption: str = "") -> list:
+async def warp_media_group(media: list[dict], caption: str = "", *, caption_above: bool = False) -> list:
"""Wrap media files into a list of media group objects.
item in media:
@@ -148,10 +148,10 @@ async def warp_media_group(media: list[dict], caption: str = "") -> list:
media = media[:10]
# add caption to the first item
if media[0].get("photo"):
- group.append(InputMediaPhoto(media[0]["photo"], caption=caption))
+ group.append(InputMediaPhoto(media[0]["photo"], caption=caption, show_caption_above_media=caption_above))
elif media[0].get("video"):
media[0]["media"] = media[0].pop("video")
- group.append(InputMediaVideo(caption=caption, **media[0]))
+ group.append(InputMediaVideo(caption=caption, show_caption_above_media=caption_above, **media[0]))
elif media[0].get("audio"):
media[0]["media"] = media[0].pop("audio")
group.append(InputMediaAudio(caption=caption, **media[0]))
src/messages/sender.py
@@ -28,6 +28,7 @@ async def send2tg(
comments: list[str] | None = None, # append after texts
send_from_user: str | None = None,
cooldown: float = 0,
+ caption_above: bool = False,
**kwargs,
) -> list[Message | None]:
"""Send unlimited number of texts and media to Telegram.
@@ -49,6 +50,7 @@ async def send2tg(
comments (list[str], optional): The comments to append after texts.
send_from_user (str, optional): The user name to prefix the texts.
cooldown (float, optional): The interval between each media message. Defaults to 0.
+ caption_above (bool, optional): Show caption above the message media.
kwargs: Other keyword arguments. In this function, we use:
show_progress (bool, optional): Show a progress message on Telegram. Defaults to True.
detail_progress (bool, optional): Show detailed progress (Only if show_progress is set to True). Defaults to False.
@@ -82,13 +84,13 @@ async def send2tg(
if len(media) == 0:
return await send_texts(client, target_chat, reply_parameters, texts=texts, cooldown=cooldown)
if len(media) == 1:
- return await send_single_media(client, target_chat, reply_parameters, media=media[0], texts=texts, cooldown=cooldown, **kwargs)
+ return await send_single_media(client, target_chat, reply_parameters, media=media[0], texts=texts, cooldown=cooldown, caption_above=caption_above, **kwargs)
caption = (await smart_split(texts, CAPTION_LENGTH))[0]
remaining_texts = texts.removeprefix(caption)
caption = warp_comments(caption)
if 1 < len(media) <= 10:
- group = await warp_media_group(media, caption=caption)
+ group = await warp_media_group(media, caption=caption, caption_above=caption_above)
sent_messages.extend(await send_media_group(client, target_chat, group, reply_parameters))
else: # media > 10
media_chunks = [media[i : i + 10] for i in range(0, len(media), 10)]
@@ -102,7 +104,7 @@ async def send2tg(
group = await warp_media_group(batch)
sent_messages.extend(await send_media_group(client, target_chat, group, ReplyParameters()))
else: # last chunk: media <= 10, add caption here
- sent_messages.extend(await send2tg(client, message, target_chat, reply_msg_id=-1, texts=caption, media=batch, cooldown=cooldown, **kwargs))
+ sent_messages.extend(await send2tg(client, message, target_chat, reply_msg_id=-1, texts=caption, media=batch, caption_above=caption_above, cooldown=cooldown, **kwargs))
await asyncio.sleep(cooldown)
if remaining_texts:
sent_messages.extend(await send_texts(client, target_chat, ReplyParameters(), texts=remaining_texts, cooldown=cooldown))
@@ -159,6 +161,7 @@ async def send_single_media(
media: dict,
texts: str = "",
cooldown: float = 0,
+ caption_above: bool = False,
**kwargs,
) -> list[Message | None]:
sent_messages: list[Message | None] = []
@@ -169,12 +172,13 @@ async def send_single_media(
message = None
try:
if photo := media.get("photo"):
- message = await client.send_photo(chat_id=target_chat, photo=photo, caption=caption, reply_parameters=reply_parameters)
+ message = await client.send_photo(chat_id=target_chat, photo=photo, caption=caption, show_caption_above_media=caption_above, reply_parameters=reply_parameters)
elif video := media.get("video"):
message = await client.send_video(
chat_id=target_chat,
reply_parameters=reply_parameters,
caption=caption,
+ show_caption_above_media=caption_above,
progress=telegram_uploading,
progress_args=(kwargs.get("progress", False), video, kwargs.get("detail_progress", True)),
**media,