Commit 426942b

benny-dou <60535774+benny-dou@users.noreply.github.com>
2025-04-28 04:35:48
feat(aigc): add `/gen` command for AIGC image generation
1 parent b66f56d
src/llm/aigc.py
@@ -0,0 +1,159 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+import base64
+import contextlib
+import random
+from io import BytesIO
+from pathlib import Path
+
+from glom import glom
+from google import genai
+from google.genai.types import ContentUnionDict, GenerateContentConfig, HttpOptions, Part
+from loguru import logger
+from PIL import Image
+from pyrogram.client import Client
+from pyrogram.types import Message
+
+from config import AIGC, DOWNLOAD_DIR, PREFIX
+from llm.utils import BOT_TIPS, beautify_llm_response, clean_source_marks
+from messages.parser import parse_msg
+from messages.progress import modify_progress
+from messages.sender import send2tg
+from utils import rand_string
+
+HELP = f"""🌠**AIGC**
+`{PREFIX.GENIMG}` 后接提示词即可生成
+回复消息可继续对话重新修改生成结果
+
+⚙️模型配置:
+🏞生图模型: **{AIGC.IMG_MODEL}
+
+⚠️目前只支持生成图片
+"""
+
+
+async def aigc(client: Client, message: Message, contexts: list[dict], modality: str = "image", **kwargs):
+    r"""Get AIGC response.
+
+    contexts: [
+                {
+                "role": role,  # assistant or user
+                "content": [
+                        {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,encoding"}},
+                        {"type": "text", "text": "[username]: Bob\n[filename]: sample.txt\n[file content]:\nhello"}
+                    ]
+                }
+            ]
+
+    Args:
+        client (Client): The Pyrogram client.
+        message (Message): The trigger message object.
+        contexts (list[dict]): Parsed from chat history.
+        modality (str): response modality
+    """
+    # ruff: noqa: RET502, RET503
+    info = parse_msg(message)
+    api_keys = [x.strip() for x in AIGC.IMG_API_KEY.split(",") if x.strip()]
+    random.choice(api_keys)
+    response_modalities = ["TEXT", "IMAGE"] if modality == "image" else ["TEXT"]
+    res = {}
+    try:
+        app = genai.Client(api_key=random.choice(api_keys), http_options=HttpOptions(base_url=AIGC.IMG_BASR_URL, async_client_args={"proxy": AIGC.IMG_PROXY}))
+        count_tokens = await app.aio.models.count_tokens(model=AIGC.IMG_MODEL, contents=info["text"])
+        num_token = count_tokens.total_tokens or 0
+        if num_token > AIGC.IMG_MAX_PROMPT_TOKEN:
+            await send2tg(client, message, texts=f"当前提示词过长: {num_token} Tokens\n提示词Token不得超过: {AIGC.IMG_MAX_PROMPT_TOKEN}", **kwargs)
+            return
+
+        msg = f"🌠**{AIGC.IMG_MODEL_NAME}**: 正在生成..."
+        status_msg = (await send2tg(client, message, texts=msg, **kwargs))[0]
+        kwargs["progress"] = status_msg
+        gemini_contexts = [openai_context_to_gemini(context) for context in contexts]
+        gemini_logging(gemini_contexts)
+        response = await app.aio.models.generate_content(
+            model=AIGC.IMG_MODEL,
+            contents=gemini_contexts,
+            config=GenerateContentConfig(response_modalities=response_modalities),
+        )
+        res = parse_response(glom(response.model_dump(), "candidates.0.content.parts"), model_name=AIGC.IMG_MODEL_NAME)
+    except Exception as e:
+        logger.error(e)
+        error = str(e)
+        if "res" in locals():
+            error += f"\n{res}"
+        if "response" in locals():
+            error += f"\n{response}"
+        await modify_progress(text=error, force_update=True, **kwargs)
+    return await send2tg(client, message, **res, **kwargs)
+
+
+def parse_response(data: list[dict], model_name: str) -> dict:
+    gemini_logging(data)
+    texts = f"🌠**{model_name}**: ({BOT_TIPS})\n"
+    media = []
+    for item in data:
+        if item.get("text") is not None:
+            texts += f"{item['text'].strip()}\n"
+        if item.get("inline_data") is not None:
+            image = Image.open(BytesIO(item["inline_data"]["data"]))
+            mime = item["inline_data"]["mime_type"]
+            ext = mime.split("/")[-1]
+            save_path = Path(DOWNLOAD_DIR) / f"{rand_string()}.{ext}"
+            image.save(save_path)
+            media.append({"photo": save_path})
+    return {"texts": beautify_llm_response(texts, newline_level=2), "media": media}
+
+
+def openai_context_to_gemini(context: dict) -> ContentUnionDict:
+    r"""Convert OpenAI context to Gemini format.
+
+    Args:
+        context (dict): {
+                "role": role,  # assistant or user
+                "content": [
+                        {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,encoding"}},
+                        {"type": "text", "text": "[username]: Bob\n[filename]: sample.txt\n[file content]:\nhello"}
+                    ]
+                }
+
+    Returns:
+        dict: {
+            "role": role,  # model or user
+            "parts: [
+                {"inlineData": {"mimeType": "image/jpeg", "data": "base64-encoded string"}},
+                {"text": "hello"}
+            ]
+        }
+    """
+    parts: list[Part] = []
+    role = "model" if context["role"] == "assistant" else "user"
+    for item in context["content"]:
+        if item["type"] == "text":
+            parts.append(Part.from_text(text=clean_source_marks(item["text"])))
+        elif item["type"] == "image_url":
+            data = item["image_url"]["url"].split(";base64,")
+            mime = data[0].removeprefix("data:")
+            parts.append(Part.from_bytes(mime_type=mime, data=data[1]))
+
+    return {"role": role, "parts": parts}  # type: ignore
+
+
+def gemini_logging(contexts: list):
+    """Log a compact one-line summary of Gemini request or response contents.
+
+    Handles both shapes seen in this module: request contexts (dicts holding
+    ``Part`` objects under "parts") and response part dicts (plain dicts with
+    "text"/"inline_data" keys). Logging must never break the request flow,
+    so every error here is deliberately suppressed.
+    """
+    msg = ""
+    with contextlib.suppress(Exception):
+        for item in contexts:
+            # Request contexts carry a role; response part dicts do not, so default to MODEL.
+            role = item.get("role", "").upper() or "MODEL"
+
+            # Request: "parts" holds Part objects — attribute access, not dict access.
+            for part in item.get("parts", []):
+                if part.inline_data:
+                    msg += f"[{role}]: Blob_Data  "  # don't dump raw image bytes into the log
+                if part.text:
+                    msg += f"[{role}]: {part.text}  "
+            # Response: part dicts expose "text"/"inline_data" as plain keys.
+            if item.get("text", ""):
+                msg += f"[{role}]: {item['text']}  "
+            if item.get("inline_data", ""):
+                msg += f"[{role}]: Blob_Data  "
+
+        logger.debug(f"{msg!r}")
src/llm/contexts.py
@@ -69,7 +69,7 @@ async def single_context(client: Client, message: Message) -> dict:
     def clean_text(text: str) -> str:
         if not text:
             return ""
-        for prefix in [PREFIX.GPT, "/gpt", "/gemini", "/ds", "/qwen", "/grok", "/doubao"]:
+        for prefix in [PREFIX.GPT, PREFIX.GENIMG, "/gpt", "/gemini", "/ds", "/qwen", "/grok", "/doubao"]:
             text = text.removeprefix(prefix).strip()
         # remove bot tips
         text = re.sub(rf"(.*?){BOT_TIPS}\)", "", text, flags=re.DOTALL).strip()
src/llm/gpt.py
@@ -5,7 +5,11 @@ from loguru import logger
 from pyrogram.client import Client
 from pyrogram.types import Message
 
-from config import GPT, PREFIX, TEXT_LENGTH, cache
+from config import AIGC, GPT, PREFIX, TEXT_LENGTH, cache
+from llm.aigc import HELP as AIGC_HELP
+from llm.aigc import aigc
+
 from llm.contexts import get_conversation_contexts, get_conversations
 from llm.models import get_context_type, get_gpt_config, parse_force_model
 from llm.response import send_to_gpt
@@ -18,10 +22,9 @@ from messages.sender import send2tg
 from messages.utils import count_without_entities, equal_prefix, startswith_prefix
 
 HELP = f"""🤖**GPT对话**
-使用说明:
-1. `{PREFIX.GPT}` 后接提示词即可与GPT对话
-2. 以 `{PREFIX.GPT}` 回复消息可将其加入上下文
-3. 暂不支持视频/音频, 可先用`{PREFIX.ASR}`命令转为文字后再调用`{PREFIX.GPT}`
+`{PREFIX.GPT}` 后接提示词即可与GPT对话
+以 `{PREFIX.GPT}` 回复消息可将其加入上下文
+暂不支持视频/音频, 可先用`{PREFIX.ASR}`命令转为文字后再调用`{PREFIX.GPT}`
 
 ⚙️模型配置:
 `{PREFIX.GPT}`默认模型: **{GPT.TEXT_MODEL_NAME}**
@@ -45,7 +48,7 @@ def is_gpt_conversation(message: Message) -> bool:
     info = parse_msg(message)
     if info["is_bot"]:  # do not process bot message
         return False
-    if startswith_prefix(info["text"], prefix=[PREFIX.GPT, "/gpt", "/gemini", "/ds", "/qwen", "/doubao", "/grok"]):
+    if startswith_prefix(info["text"], prefix=[PREFIX.GPT, PREFIX.GENIMG, "/gpt", "/gemini", "/ds", "/qwen", "/doubao", "/grok"]):
         return True
     # is replying to gpt-bot response message?
     if not message.reply_to_message:
@@ -62,7 +65,8 @@ def is_gpt_conversation(message: Message) -> bool:
         GPT.TEXT_MODEL_NAME,
         GPT.IMAGE_MODEL_NAME,
     ]
-    return startswith_prefix(reply_info["text"], prefix=[f"🤖{x}".lower() for x in model_names])
+    aigc_names = [AIGC.IMG_MODEL_NAME]
+    return startswith_prefix(reply_info["text"], prefix=[f"🤖{x}".lower() for x in model_names] + [f"🌠{x}".lower() for x in aigc_names])
 
 
 async def gpt_response(client: Client, message: Message, *, gpt_stream: bool = GPT.STREAM_MODE, **kwargs):
@@ -79,15 +83,17 @@ async def gpt_response(client: Client, message: Message, *, gpt_stream: bool = G
     if info["mtype"] == "text" and equal_prefix(info["text"], prefix=[PREFIX.GPT, "/gpt", "/gemini", "/ds", "/qwen", "/grok", "/doubao"]) and not message.reply_to_message:
         await send2tg(client, message, texts=HELP, **kwargs)
         return
+    if info["mtype"] == "text" and equal_prefix(info["text"], prefix=[PREFIX.GENIMG]) and not message.reply_to_message:
+        await send2tg(client, message, texts=AIGC_HELP, **kwargs)
+        return
     if not is_gpt_conversation(message):
         return
-
     reply_text = ""
     if message.reply_to_message:
         reply_info = parse_msg(message.reply_to_message, silent=True)
         reply_text = reply_info["text"]
 
-    force_model = parse_force_model(info["text"], reply_text)
+    force_model, modality = parse_force_model(info["text"], reply_text)
 
     # cache media_group message, only process once
     if media_group_id := message.media_group_id:
@@ -98,6 +104,8 @@ async def gpt_response(client: Client, message: Message, *, gpt_stream: bool = G
     conversations = get_conversations(message)
     context_type = get_context_type(conversations)
     contexts = await get_conversation_contexts(client, conversations)
+    if equal_prefix(info["text"], prefix=[PREFIX.GENIMG]) or modality != "text":
+        return await aigc(client, message, contexts, modality, **kwargs)
     config = get_gpt_config(context_type["type"], contexts, force_model)
     if not config["client"]["api_key"]:
         logger.error(f"⚠️**{config['friendly_name']}** 未配置API Key")
src/llm/models.py
@@ -4,7 +4,7 @@
 from openai import DefaultAsyncHttpxClient
 from pyrogram.types import Message
 
-from config import GPT, PREFIX, PROXY
+from config import AIGC, GPT, PREFIX, PROXY
 from messages.parser import parse_msg
 from messages.utils import startswith_prefix
 
@@ -23,16 +23,17 @@ def get_context_type(conversations: list[Message]) -> dict:
         if info["mtype"] == "audio":
             has_audio = True
     if has_audio or has_video:
-        res["error"] = f"⚠️已忽略上下文中的视频/音频消息\n可以先用 `{PREFIX.ASR}` 命令转为文字后再使用 `{PREFIX.GPT}`"
+        res["error"] = f"⚠️已忽略上下文中的视频/音频消息\n可以先用 `{PREFIX.ASR}` 命令转为文字后再使用AI功能"
     return res
 
 
-def parse_force_model(text: str, reply_text: str) -> str:
+def parse_force_model(text: str, reply_text: str) -> tuple[str, str]:
     """Parse the force model from the text or reply text.
 
     /gpt = OpenAI, /gemini = Gemini, /ds = DeepSeek, /qwen = Qwen, /doubao = Doubao, /grok = Grok
     """
     force_model = ""
+    modality = "text"
     # parse from bot reply
     if reply_text.startswith(f"🤖{GPT.OPENAI_MODEL_NAME}"):
         force_model = GPT.OPENAI_MODEL
@@ -46,6 +47,9 @@ def parse_force_model(text: str, reply_text: str) -> str:
         force_model = GPT.DOUBAO_MODEL
     elif reply_text.startswith(f"🤖{GPT.GROK_MODEL_NAME}"):
         force_model = GPT.GROK_MODEL
+    elif reply_text.startswith(f"🌠{AIGC.IMG_MODEL_NAME}"):
+        force_model = AIGC.IMG_MODEL
+        modality = "image"
     # parse from command prefix
     if startswith_prefix(text, prefix=["/gpt"]):
         force_model = GPT.OPENAI_MODEL
@@ -59,7 +63,10 @@ def parse_force_model(text: str, reply_text: str) -> str:
         force_model = GPT.DOUBAO_MODEL
     elif startswith_prefix(text, prefix=["/grok"]):
         force_model = GPT.GROK_MODEL
-    return force_model
+    elif startswith_prefix(text, prefix=[PREFIX.GENIMG]):
+        force_model = AIGC.IMG_MODEL
+        modality = "image"
+    return force_model, modality
 
 
 def get_gpt_config(model_type: str, contexts: list[dict], force_model: str = "") -> dict:
src/llm/utils.py
@@ -87,7 +87,7 @@ def beautify_model_name(name: str) -> str:
     return name.replace("gpt", "GPT").replace("gemini", "Gemini").replace("deepseek", "DeepSeek")  # GPT-4o
 
 
-def beautify_llm_response(text: str) -> str:
+def beautify_llm_response(text: str, newline_level: int = 3) -> str:
     """Beautify LLM response.
 
     Args:
@@ -97,7 +97,19 @@ def beautify_llm_response(text: str) -> str:
     """
     if not text:
         return text
-    # remove tags. should align with the tags in `contexts.py`
+    clean_text = clean_source_marks(text)
+    clean_text = remove_pound(clean_text)
+    clean_text = remove_dash(clean_text)
+    return remove_consecutive_newlines(clean_text, newline_level)
+
+
+def clean_source_marks(text: str) -> str:
+    """Remove [username], [message], ... marks.
+
+    Should align with the tags in `contexts.py`
+    """
+    if not text:
+        return text
     clean_text = ""
     for line in text.split("\n"):
         if line.strip().startswith(("[username]:", "[filename]:")):
@@ -105,10 +117,7 @@ def beautify_llm_response(text: str) -> str:
         if line.strip() in ["[message]:", "[file content]:"]:
             continue
         clean_text += line + "\n"
-    clean_text = clean_text.removesuffix("\n")  # remove the last newline
-    clean_text = remove_pound(clean_text)
-    clean_text = remove_dash(clean_text)
-    return remove_consecutive_newlines(clean_text)
+    return clean_text.removesuffix("\n")  # remove the last newline
 
 
 def extract_reasoning(text: str) -> tuple[str, str]:
src/config.py
@@ -34,7 +34,6 @@ GOOGLE_SEARCH_GL = os.getenv("GOOGLE_SEARCH_GL", "cn")  # "gl" parameter (Geoloc
 
 
 class ENABLE:  # see fine-grained permission in `src/permission.py`
-    AI_SUMMARY = os.getenv("ENABLE_AI_SUMMARY", "1").lower() in ["1", "y", "yes", "t", "true", "on"]
     ASR = os.getenv("ENABLE_ASR", "1").lower() in ["1", "y", "yes", "t", "true", "on"]
     AUDIO = os.getenv("ENABLE_AUDIO", "1").lower() in ["1", "y", "yes", "t", "true", "on"]
     CRONTAB = os.getenv("ENABLE_CRONTAB", "1").lower() in ["1", "y", "yes", "t", "true", "on"]
@@ -78,6 +77,7 @@ class PREFIX:
     VOICE = os.getenv("PREFIX_VOICE", "/voice").lower()
     SEARCH_YOUTUBE = os.getenv("PREFIX_SEARCH_YOUTUBE", "/ytb").lower()
     SEARCH_GOOGLE = os.getenv("PREFIX_SEARCH_GOOGLE", "/google").lower()
+    GENIMG = os.getenv("PREFIX_GENIMG", "/gen").lower()
 
 
 class API:
@@ -142,6 +142,7 @@ class COOKIE:  # See: https://github.com/easychen/CookieCloud
 
 
 class GPT:  # see `llm/README.md`
+    # See class AIGC for the AIGC configurations
     STREAM_MODE = os.getenv("GPT_STREAM_MODE", "1").lower() in ["1", "y", "yes", "t", "true", "on"]
     TEXT_MODEL = os.getenv("GPT_TEXT_MODEL", "gpt-4o")
     IMAGE_MODEL = os.getenv("GPT_IMAGE_MODEL", "gpt-4o")
@@ -263,3 +264,13 @@ class ASR:
     TENCENT_PROXY = os.getenv("ASR_TENCENT_PROXY", None)  # Banned oversea IP, need a back to China proxy
     TENCENT_SECRET_ID = os.getenv("ASR_TENCENT_SECRET_ID", "")
     TENCENT_SECRET_KEY = os.getenv("ASR_TENCENT_SECRET_KEY", "")
+
+
+class AIGC:
+    # https://ai.google.dev/gemini-api/docs/image-generation
+    IMG_BASR_URL = os.getenv("AIGC_IMG_BASR_URL", "https://generativelanguage.googleapis.com/")
+    IMG_API_KEY = os.getenv("AIGC_IMG_API_KEY", "")  # comma separated keys for load balance. e.g. "key1,key2,key3"
+    IMG_MODEL = os.getenv("AIGC_IMG_MODEL", "gemini-2.0-flash-exp")
+    IMG_MODEL_NAME = os.getenv("AIGC_IMG_MODEL_NAME", "Gemini-2.0-Flash")
+    IMG_PROXY = os.getenv("AIGC_IMG_PROXY", None)
+    IMG_MAX_PROMPT_TOKEN = int(os.getenv("AIGC_IMG_MAX_PROMPT_TOKEN", "480"))
src/handler.py
@@ -43,13 +43,13 @@ async def handle_utilities(
     ai: bool = True,
     asr: bool = True,
     audio: bool = True,
-    ytb: bool = True,
     google: bool = True,
-    subtitle: bool = True,
-    wget: bool = True,
     ocr: bool = True,
     price: bool = True,
+    subtitle: bool = True,
     summary: bool = True,
+    wget: bool = True,
+    ytb: bool = True,
     raw_img: bool = True,
     show_progress: bool = True,
     detail_progress: bool = False,
@@ -149,7 +149,32 @@ async def handle_social_media(
         cmd_prefix.extend(PREFIX.MAIN)
     ignore_prefix = ignore_prefix or ["/dl4dw"]
     # these commands are handled in `handle_utilities`
-    ignore_prefix.extend(["/ai", "/asr", "/audio", "/combine", "/doubao", "/ds", "/gemini", "/gpt", "/ocr", "/price", "/qwen", "/grok", "/subtitle", "/summary", "/voice", "/wget"])
+    ignore_prefix.extend(
+        [
+            "/doubao",
+            "/ds",
+            "/gemini",
+            "/gpt",
+            "/grok",
+            "/qwen",
+            PREFIX.ASR,
+            PREFIX.AI_SUMMARY,
+            PREFIX.AUDIO,
+            PREFIX.COMBINATION,
+            PREFIX.CONVERT,
+            PREFIX.CRYPTO,
+            PREFIX.GENIMG,
+            PREFIX.GPT,
+            PREFIX.OCR,
+            PREFIX.PRICE,
+            PREFIX.SEARCH_GOOGLE,
+            PREFIX.SEARCH_YOUTUBE,
+            PREFIX.STOCK,
+            PREFIX.SUBTITLE,
+            PREFIX.VOICE,
+            PREFIX.WGET,
+        ]
+    )
     info = parse_msg(message)
     this_texts = info["text"]  # texts of the trigger message
     if startswith_prefix(this_texts, prefix=ignore_prefix):
@@ -279,7 +304,9 @@ def get_social_media_help(chat_id: int | str, ctype: str, prefixes: list[str] |
         msg += "\n🅱️哔哩哔哩"
         msg += "\n🆕和所有yt-dlp支持的链接\n"
     if permission["ai"]:
-        msg += f"\n🤖**GPT对话**: `{PREFIX.GPT} /gpt /gemini /ds /qwen /doubao /grok` + 提示词"
+        msg += f"\n🤖**AI对话**: `{PREFIX.GPT} /gpt /gemini /ds /qwen /doubao /grok`"
+        msg += f"\n📖**AI总结**: `{PREFIX.AI_SUMMARY}` 总结历史聊天记录"
+        msg += f"\n🌠**AIGC**: `{PREFIX.GENIMG}`"
     if permission["asr"]:
         msg += f"\n🗣**语音转文字**: `{PREFIX.ASR}` 回复语音消息"
     if permission["audio"]:
@@ -290,8 +317,6 @@ def get_social_media_help(chat_id: int | str, ctype: str, prefixes: list[str] |
         msg += f"\n💵**查询价格**: `{PREFIX.PRICE}` + Symbol"
     if permission["subtitle"]:
         msg += f"\n📃**提取字幕**: `{PREFIX.SUBTITLE}` + 油管链接 (或回复油管链接)"
-    if permission["summary"] and permission["ai"]:  # summary depends on ai
-        msg += f"\n🤖**总结历史**: `{PREFIX.AI_SUMMARY}` AI总结历史聊天记录"
     if permission["wget"]:
         msg += f"\n⏬**下载文件**: `{PREFIX.WGET}` + URL"
     if permission["ytb"]:
src/permission.py
@@ -102,7 +102,6 @@ def check_service(cid: int | str, ctype: str) -> dict:
         "ocr": True,
         "price": True,
         "raw_img": True,
-        "summary": True,
         "ytb": True,
         "google": True,
         "show_progress": True,
@@ -153,8 +152,6 @@ def check_service(cid: int | str, ctype: str) -> dict:
         permission["ocr"] = False
     if not ENABLE.PRICE:
         permission["price"] = False
-    if not ENABLE.AI_SUMMARY:
-        permission["summary"] = False
     if not ENABLE.RAW_IMG_CONVERT:
         permission["raw_img"] = False
 
src/utils.py
@@ -259,11 +259,14 @@ def remove_pound(text: str) -> str:
     return text
 
 
-def remove_consecutive_newlines(text: str) -> str:
+def remove_consecutive_newlines(text: str, newline_level: int = 3) -> str:
     if not text:
         return ""
     while "\n\n\n" in text:
         text = text.replace("\n\n\n", "\n\n")
+    if newline_level == 2:
+        while "\n\n" in text:
+            text = text.replace("\n\n", "\n")
     return text