Commit 426942b
Changed files (9)
src/llm/aigc.py
@@ -0,0 +1,159 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+import contextlib
+import random
+from io import BytesIO
+from pathlib import Path
+
+from glom import glom
+from google import genai
+from google.genai.types import ContentUnionDict, GenerateContentConfig, HttpOptions, Part
+from loguru import logger
+from PIL import Image
+from pyrogram.client import Client
+from pyrogram.types import Message
+
+from config import AIGC, DOWNLOAD_DIR, PREFIX
+from llm.utils import BOT_TIPS, beautify_llm_response, clean_source_marks
+from messages.parser import parse_msg
+from messages.progress import modify_progress
+from messages.sender import send2tg
+from utils import rand_string
+
+HELP = f"""🌠**AIGC**
+`{PREFIX.GENIMG}` 后接提示词即可生成
+回复消息可继续对话重新修改生成结果
+
+⚙️模型配置:
+🏞生图模型: **{AIGC.IMG_MODEL}
+
+⚠️目前只支持生成图片
+"""
+
+
+async def aigc(client: Client, message: Message, contexts: list[dict], modality: str = "image", **kwargs):
+ r"""Get AIGC response.
+
+ contexts: [
+ {
+ "role": role, # assistant or user
+ "content": [
+ {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,encoding"}},
+ {"type": "text", "text": "[username]: Bob\n[filename]: sample.txt\n[file content]:\nhello"}
+ ]
+ }
+ ]
+
+ Args:
+ client (Client): The Pyrogram client.
+ message (Message): The trigger message object.
+ contexts (list[dict]): Parsed from chat history.
+ modality (str): response modality
+ """
+ # ruff: noqa: RET502, RET503
+ info = parse_msg(message)
+ api_keys = [x.strip() for x in AIGC.IMG_API_KEY.split(",") if x.strip()]
+ random.choice(api_keys)
+ response_modalities = ["TEXT", "IMAGE"] if modality == "image" else ["TEXT"]
+ res = {}
+ try:
+ app = genai.Client(api_key=random.choice(api_keys), http_options=HttpOptions(base_url=AIGC.IMG_BASR_URL, async_client_args={"proxy": AIGC.IMG_PROXY}))
+ count_tokens = await app.aio.models.count_tokens(model=AIGC.IMG_MODEL, contents=info["text"])
+ num_token = count_tokens.total_tokens or 0
+ if num_token > AIGC.IMG_MAX_PROMPT_TOKEN:
+ await send2tg(client, message, texts=f"当前提示词过长: {num_token} Tokens\n提示词Token不得超过: {AIGC.IMG_MAX_PROMPT_TOKEN}", **kwargs)
+ return
+
+ msg = f"🌠**{AIGC.IMG_MODEL_NAME}**: 正在生成..."
+ status_msg = (await send2tg(client, message, texts=msg, **kwargs))[0]
+ kwargs["progress"] = status_msg
+ gemini_contexts = [openai_context_to_gemini(context) for context in contexts]
+ gemini_logging(gemini_contexts)
+ response = await app.aio.models.generate_content(
+ model=AIGC.IMG_MODEL,
+ contents=gemini_contexts,
+ config=GenerateContentConfig(response_modalities=response_modalities),
+ )
+ res = parse_response(glom(response.model_dump(), "candidates.0.content.parts"), model_name=AIGC.IMG_MODEL_NAME)
+ except Exception as e:
+ logger.error(e)
+ error = str(e)
+ if "res" in locals():
+ error += f"\n{res}"
+ if "response" in locals():
+ error += f"\n{response}"
+ await modify_progress(text=error, force_update=True, **kwargs)
+ return await send2tg(client, message, **res, **kwargs)
+
+
+def parse_response(data: list[dict], model_name: str) -> dict:
+ gemini_logging(data)
+ texts = f"🌠**{model_name}**: ({BOT_TIPS})\n"
+ media = []
+ for item in data:
+ if item.get("text") is not None:
+ texts += f"{item['text'].strip()}\n"
+ if item.get("inline_data") is not None:
+ image = Image.open(BytesIO(item["inline_data"]["data"]))
+ mime = item["inline_data"]["mime_type"]
+ ext = mime.split("/")[-1]
+ save_path = Path(DOWNLOAD_DIR) / f"{rand_string()}.{ext}"
+ image.save(save_path)
+ media.append({"photo": save_path})
+ return {"texts": beautify_llm_response(texts, newline_level=2), "media": media}
+
+
+def openai_context_to_gemini(context: dict) -> ContentUnionDict:
+ r"""Convert OpenAI context to Gemini format.
+
+ Args:
+ context (dict): {
+ "role": role, # assistant or user
+ "content": [
+ {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,encoding"}},
+ {"type": "text", "text": "[username]: Bob\n[filename]: sample.txt\n[file content]:\nhello"}
+ ]
+ }
+
+ Returns:
+ dict: {
+ "role": role, # model or user
+ "parts: [
+ {"inlineData": {"mimeType": "image/jpeg", "data": "base64-encoded string"}},
+ {"text": "hello"}
+ ]
+ }
+ """
+ parts: list[Part] = []
+ role = "model" if context["role"] == "assistant" else "user"
+ for item in context["content"]:
+ if item["type"] == "text":
+ parts.append(Part.from_text(text=clean_source_marks(item["text"])))
+ elif item["type"] == "image_url":
+ data = item["image_url"]["url"].split(";base64,")
+ mime = data[0].removeprefix("data:")
+ parts.append(Part.from_bytes(mime_type=mime, data=data[1]))
+
+ return {"role": role, "parts": parts} # type: ignore
+
+
+def gemini_logging(contexts: list):
+    """Log a compact one-line summary of Gemini request contexts or response parts.
+
+    Accepts either request contexts ({"role": ..., "parts": [Part, ...]}) or
+    response part dicts ({"text": ...} / {"inline_data": ...}). Best-effort:
+    any formatting error is swallowed so logging never breaks the caller.
+    """
+    msg = ""
+    with contextlib.suppress(Exception):
+        for item in contexts:
+            # Response part dicts carry no "role" key; default them to MODEL.
+            role = item.get("role", "").upper() or "MODEL"
+
+            # Request: parts are google.genai Part objects -> attribute access.
+            for part in item.get("parts", []):
+                if part.inline_data:
+                    msg += f"[{role}]: Blob_Data "
+                if part.text:
+                    msg += f"[{role}]: {part.text} "
+            # Response: items are plain dicts (from model_dump) -> key access.
+            if item.get("text", ""):
+                msg += f"[{role}]: {item['text']} "
+            if item.get("inline_data", ""):
+                msg += f"[{role}]: Blob_Data "
+
+    logger.debug(f"{msg!r}")
src/llm/contexts.py
@@ -69,7 +69,7 @@ async def single_context(client: Client, message: Message) -> dict:
def clean_text(text: str) -> str:
if not text:
return ""
- for prefix in [PREFIX.GPT, "/gpt", "/gemini", "/ds", "/qwen", "/grok", "/doubao"]:
+ for prefix in [PREFIX.GPT, PREFIX.GENIMG, "/gpt", "/gemini", "/ds", "/qwen", "/grok", "/doubao"]:
text = text.removeprefix(prefix).strip()
# remove bot tips
text = re.sub(rf"(.*?){BOT_TIPS}\)", "", text, flags=re.DOTALL).strip()
src/llm/gpt.py
@@ -5,7 +5,11 @@ from loguru import logger
from pyrogram.client import Client
from pyrogram.types import Message
-from config import GPT, PREFIX, TEXT_LENGTH, cache
+from config import AIGC, GPT, PREFIX, TEXT_LENGTH, cache
+from llm.aigc import HELP as AIGC_HELP
+from llm.aigc import aigc
+
+# from llm.aigc import HELP as AIGC_HELP
from llm.contexts import get_conversation_contexts, get_conversations
from llm.models import get_context_type, get_gpt_config, parse_force_model
from llm.response import send_to_gpt
@@ -18,10 +22,9 @@ from messages.sender import send2tg
from messages.utils import count_without_entities, equal_prefix, startswith_prefix
HELP = f"""🤖**GPT对话**
-使用说明:
-1. `{PREFIX.GPT}` 后接提示词即可与GPT对话
-2. 以 `{PREFIX.GPT}` 回复消息可将其加入上下文
-3. 暂不支持视频/音频, 可先用`{PREFIX.ASR}`命令转为文字后再调用`{PREFIX.GPT}`
+`{PREFIX.GPT}` 后接提示词即可与GPT对话
+以 `{PREFIX.GPT}` 回复消息可将其加入上下文
+暂不支持视频/音频, 可先用`{PREFIX.ASR}`命令转为文字后再调用`{PREFIX.GPT}`
⚙️模型配置:
`{PREFIX.GPT}`默认模型: **{GPT.TEXT_MODEL_NAME}**
@@ -45,7 +48,7 @@ def is_gpt_conversation(message: Message) -> bool:
info = parse_msg(message)
if info["is_bot"]: # do not process bot message
return False
- if startswith_prefix(info["text"], prefix=[PREFIX.GPT, "/gpt", "/gemini", "/ds", "/qwen", "/doubao", "/grok"]):
+ if startswith_prefix(info["text"], prefix=[PREFIX.GPT, PREFIX.GENIMG, "/gpt", "/gemini", "/ds", "/qwen", "/doubao", "/grok"]):
return True
# is replying to gpt-bot response message?
if not message.reply_to_message:
@@ -62,7 +65,8 @@ def is_gpt_conversation(message: Message) -> bool:
GPT.TEXT_MODEL_NAME,
GPT.IMAGE_MODEL_NAME,
]
- return startswith_prefix(reply_info["text"], prefix=[f"🤖{x}".lower() for x in model_names])
+ aigc_names = [AIGC.IMG_MODEL_NAME]
+ return startswith_prefix(reply_info["text"], prefix=[f"🤖{x}".lower() for x in model_names] + [f"🌠{x}".lower() for x in aigc_names])
async def gpt_response(client: Client, message: Message, *, gpt_stream: bool = GPT.STREAM_MODE, **kwargs):
@@ -79,15 +83,17 @@ async def gpt_response(client: Client, message: Message, *, gpt_stream: bool = G
if info["mtype"] == "text" and equal_prefix(info["text"], prefix=[PREFIX.GPT, "/gpt", "/gemini", "/ds", "/qwen", "/grok", "/doubao"]) and not message.reply_to_message:
await send2tg(client, message, texts=HELP, **kwargs)
return
+ if info["mtype"] == "text" and equal_prefix(info["text"], prefix=[PREFIX.GENIMG]) and not message.reply_to_message:
+ await send2tg(client, message, texts=AIGC_HELP, **kwargs)
+ return
if not is_gpt_conversation(message):
return
-
reply_text = ""
if message.reply_to_message:
reply_info = parse_msg(message.reply_to_message, silent=True)
reply_text = reply_info["text"]
- force_model = parse_force_model(info["text"], reply_text)
+ force_model, modality = parse_force_model(info["text"], reply_text)
# cache media_group message, only process once
if media_group_id := message.media_group_id:
@@ -98,6 +104,8 @@ async def gpt_response(client: Client, message: Message, *, gpt_stream: bool = G
conversations = get_conversations(message)
context_type = get_context_type(conversations)
contexts = await get_conversation_contexts(client, conversations)
+ if equal_prefix(info["text"], prefix=[PREFIX.GENIMG]) or modality != "text":
+ return await aigc(client, message, contexts, modality, **kwargs)
config = get_gpt_config(context_type["type"], contexts, force_model)
if not config["client"]["api_key"]:
logger.error(f"⚠️**{config['friendly_name']}** 未配置API Key")
src/llm/models.py
@@ -4,7 +4,7 @@
from openai import DefaultAsyncHttpxClient
from pyrogram.types import Message
-from config import GPT, PREFIX, PROXY
+from config import AIGC, GPT, PREFIX, PROXY
from messages.parser import parse_msg
from messages.utils import startswith_prefix
@@ -23,16 +23,17 @@ def get_context_type(conversations: list[Message]) -> dict:
if info["mtype"] == "audio":
has_audio = True
if has_audio or has_video:
- res["error"] = f"⚠️已忽略上下文中的视频/音频消息\n可以先用 `{PREFIX.ASR}` 命令转为文字后再使用 `{PREFIX.GPT}`"
+ res["error"] = f"⚠️已忽略上下文中的视频/音频消息\n可以先用 `{PREFIX.ASR}` 命令转为文字后再使用AI功能"
return res
-def parse_force_model(text: str, reply_text: str) -> str:
+def parse_force_model(text: str, reply_text: str) -> tuple[str, str]:
"""Parse the force model from the text or reply text.
/gpt = OpenAI, /gemini = Gemini, /ds = DeepSeek, /qwen = Qwen, /doubao = Doubao, /grok = Grok
"""
force_model = ""
+ modality = "text"
# parse from bot reply
if reply_text.startswith(f"🤖{GPT.OPENAI_MODEL_NAME}"):
force_model = GPT.OPENAI_MODEL
@@ -46,6 +47,9 @@ def parse_force_model(text: str, reply_text: str) -> str:
force_model = GPT.DOUBAO_MODEL
elif reply_text.startswith(f"🤖{GPT.GROK_MODEL_NAME}"):
force_model = GPT.GROK_MODEL
+ elif reply_text.startswith(f"🌠{AIGC.IMG_MODEL_NAME}"):
+ force_model = AIGC.IMG_MODEL
+ modality = "image"
# parse from command prefix
if startswith_prefix(text, prefix=["/gpt"]):
force_model = GPT.OPENAI_MODEL
@@ -59,7 +63,10 @@ def parse_force_model(text: str, reply_text: str) -> str:
force_model = GPT.DOUBAO_MODEL
elif startswith_prefix(text, prefix=["/grok"]):
force_model = GPT.GROK_MODEL
- return force_model
+ elif startswith_prefix(text, prefix=[PREFIX.GENIMG]):
+ force_model = AIGC.IMG_MODEL
+ modality = "image"
+ return force_model, modality
def get_gpt_config(model_type: str, contexts: list[dict], force_model: str = "") -> dict:
src/llm/utils.py
@@ -87,7 +87,7 @@ def beautify_model_name(name: str) -> str:
return name.replace("gpt", "GPT").replace("gemini", "Gemini").replace("deepseek", "DeepSeek") # GPT-4o
-def beautify_llm_response(text: str) -> str:
+def beautify_llm_response(text: str, newline_level: int = 3) -> str:
"""Beautify LLM response.
Args:
@@ -97,7 +97,19 @@ def beautify_llm_response(text: str) -> str:
"""
if not text:
return text
- # remove tags. should align with the tags in `contexts.py`
+ clean_text = clean_source_marks(text)
+ clean_text = remove_pound(clean_text)
+ clean_text = remove_dash(clean_text)
+ return remove_consecutive_newlines(clean_text, newline_level)
+
+
+def clean_source_marks(text: str) -> str:
+ """Remove [username], [message], ... marks.
+
+ Should align with the tags in `contexts.py`
+ """
+ if not text:
+ return text
clean_text = ""
for line in text.split("\n"):
if line.strip().startswith(("[username]:", "[filename]:")):
@@ -105,10 +117,7 @@ def beautify_llm_response(text: str) -> str:
if line.strip() in ["[message]:", "[file content]:"]:
continue
clean_text += line + "\n"
- clean_text = clean_text.removesuffix("\n") # remove the last newline
- clean_text = remove_pound(clean_text)
- clean_text = remove_dash(clean_text)
- return remove_consecutive_newlines(clean_text)
+ return clean_text.removesuffix("\n") # remove the last newline
def extract_reasoning(text: str) -> tuple[str, str]:
src/config.py
@@ -34,7 +34,6 @@ GOOGLE_SEARCH_GL = os.getenv("GOOGLE_SEARCH_GL", "cn") # "gl" parameter (Geoloc
class ENABLE: # see fine-grained permission in `src/permission.py`
- AI_SUMMARY = os.getenv("ENABLE_AI_SUMMARY", "1").lower() in ["1", "y", "yes", "t", "true", "on"]
ASR = os.getenv("ENABLE_ASR", "1").lower() in ["1", "y", "yes", "t", "true", "on"]
AUDIO = os.getenv("ENABLE_AUDIO", "1").lower() in ["1", "y", "yes", "t", "true", "on"]
CRONTAB = os.getenv("ENABLE_CRONTAB", "1").lower() in ["1", "y", "yes", "t", "true", "on"]
@@ -78,6 +77,7 @@ class PREFIX:
VOICE = os.getenv("PREFIX_VOICE", "/voice").lower()
SEARCH_YOUTUBE = os.getenv("PREFIX_SEARCH_YOUTUBE", "/ytb").lower()
SEARCH_GOOGLE = os.getenv("PREFIX_SEARCH_GOOGLE", "/google").lower()
+ GENIMG = os.getenv("PREFIX_GENIMG", "/gen").lower()
class API:
@@ -142,6 +142,7 @@ class COOKIE: # See: https://github.com/easychen/CookieCloud
class GPT: # see `llm/README.md`
+ # See class AIGC for the AIGC configurations
STREAM_MODE = os.getenv("GPT_STREAM_MODE", "1").lower() in ["1", "y", "yes", "t", "true", "on"]
TEXT_MODEL = os.getenv("GPT_TEXT_MODEL", "gpt-4o")
IMAGE_MODEL = os.getenv("GPT_IMAGE_MODEL", "gpt-4o")
@@ -263,3 +264,13 @@ class ASR:
TENCENT_PROXY = os.getenv("ASR_TENCENT_PROXY", None) # Banned oversea IP, need a back to China proxy
TENCENT_SECRET_ID = os.getenv("ASR_TENCENT_SECRET_ID", "")
TENCENT_SECRET_KEY = os.getenv("ASR_TENCENT_SECRET_KEY", "")
+
+
+class AIGC:
+ # https://ai.google.dev/gemini-api/docs/image-generation
+ IMG_BASR_URL = os.getenv("AIGC_IMG_BASR_URL", "https://generativelanguage.googleapis.com/")
+ IMG_API_KEY = os.getenv("AIGC_IMG_API_KEY", "") # comma separated keys for load balance. e.g. "key1,key2,key3"
+ IMG_MODEL = os.getenv("AIGC_IMG_MODEL", "gemini-2.0-flash-exp")
+ IMG_MODEL_NAME = os.getenv("AIGC_IMG_MODEL_NAME", "Gemini-2.0-Flash")
+ IMG_PROXY = os.getenv("AIGC_IMG_PROXY", None)
+ IMG_MAX_PROMPT_TOKEN = int(os.getenv("AIGC_IMG_MAX_PROMPT_TOKEN", "480"))
src/handler.py
@@ -43,13 +43,13 @@ async def handle_utilities(
ai: bool = True,
asr: bool = True,
audio: bool = True,
- ytb: bool = True,
google: bool = True,
- subtitle: bool = True,
- wget: bool = True,
ocr: bool = True,
price: bool = True,
+ subtitle: bool = True,
summary: bool = True,
+ wget: bool = True,
+ ytb: bool = True,
raw_img: bool = True,
show_progress: bool = True,
detail_progress: bool = False,
@@ -149,7 +149,32 @@ async def handle_social_media(
cmd_prefix.extend(PREFIX.MAIN)
ignore_prefix = ignore_prefix or ["/dl4dw"]
# these commands are handled in `handle_utilities`
- ignore_prefix.extend(["/ai", "/asr", "/audio", "/combine", "/doubao", "/ds", "/gemini", "/gpt", "/ocr", "/price", "/qwen", "/grok", "/subtitle", "/summary", "/voice", "/wget"])
+ ignore_prefix.extend(
+ [
+ "/doubao",
+ "/ds",
+ "/gemini",
+ "/gpt",
+ "/grok",
+ "/qwen",
+ PREFIX.ASR,
+ PREFIX.AI_SUMMARY,
+ PREFIX.AUDIO,
+ PREFIX.COMBINATION,
+ PREFIX.CONVERT,
+ PREFIX.CRYPTO,
+ PREFIX.GENIMG,
+ PREFIX.GPT,
+ PREFIX.OCR,
+ PREFIX.PRICE,
+ PREFIX.SEARCH_GOOGLE,
+ PREFIX.SEARCH_YOUTUBE,
+ PREFIX.STOCK,
+ PREFIX.SUBTITLE,
+ PREFIX.VOICE,
+ PREFIX.WGET,
+ ]
+ )
info = parse_msg(message)
this_texts = info["text"] # texts of the trigger message
if startswith_prefix(this_texts, prefix=ignore_prefix):
@@ -279,7 +304,9 @@ def get_social_media_help(chat_id: int | str, ctype: str, prefixes: list[str] |
msg += "\n🅱️哔哩哔哩"
msg += "\n🆕和所有yt-dlp支持的链接\n"
if permission["ai"]:
- msg += f"\n🤖**GPT对话**: `{PREFIX.GPT} /gpt /gemini /ds /qwen /doubao /grok` + 提示词"
+ msg += f"\n🤖**AI对话**: `{PREFIX.GPT} /gpt /gemini /ds /qwen /doubao /grok`"
+ msg += f"\n📖**AI总结**: `{PREFIX.AI_SUMMARY}` 总结历史聊天记录"
+ msg += f"\n🌠**AIGC**: `{PREFIX.GENIMG}`"
if permission["asr"]:
msg += f"\n🗣**语音转文字**: `{PREFIX.ASR}` 回复语音消息"
if permission["audio"]:
@@ -290,8 +317,6 @@ def get_social_media_help(chat_id: int | str, ctype: str, prefixes: list[str] |
msg += f"\n💵**查询价格**: `{PREFIX.PRICE}` + Symbol"
if permission["subtitle"]:
msg += f"\n📃**提取字幕**: `{PREFIX.SUBTITLE}` + 油管链接 (或回复油管链接)"
- if permission["summary"] and permission["ai"]: # summary depends on ai
- msg += f"\n🤖**总结历史**: `{PREFIX.AI_SUMMARY}` AI总结历史聊天记录"
if permission["wget"]:
msg += f"\n⏬**下载文件**: `{PREFIX.WGET}` + URL"
if permission["ytb"]:
src/permission.py
@@ -102,7 +102,6 @@ def check_service(cid: int | str, ctype: str) -> dict:
"ocr": True,
"price": True,
"raw_img": True,
- "summary": True,
"ytb": True,
"google": True,
"show_progress": True,
@@ -153,8 +152,6 @@ def check_service(cid: int | str, ctype: str) -> dict:
permission["ocr"] = False
if not ENABLE.PRICE:
permission["price"] = False
- if not ENABLE.AI_SUMMARY:
- permission["summary"] = False
if not ENABLE.RAW_IMG_CONVERT:
permission["raw_img"] = False
src/utils.py
@@ -259,11 +259,14 @@ def remove_pound(text: str) -> str:
return text
-def remove_consecutive_newlines(text: str) -> str:
+def remove_consecutive_newlines(text: str, newline_level: int = 3) -> str:
if not text:
return ""
while "\n\n\n" in text:
text = text.replace("\n\n\n", "\n\n")
+ if newline_level == 2:
+ while "\n\n" in text:
+ text = text.replace("\n\n", "\n")
return text