Commit cc41967

benny-dou <60535774+benny-dou@users.noreply.github.com>
2025-04-28 08:52:21
refactor(gpt): use official SDK for Gemini
1 parent 6629619
Changed files (3)
src/llm/aigc.py
@@ -1,171 +0,0 @@
-#!/usr/bin/env python
-# -*- coding: utf-8 -*-
-
-import contextlib
-import random
-from io import BytesIO
-from pathlib import Path
-
-from glom import glom
-from google import genai
-from google.genai.types import ContentUnionDict, GenerateContentConfig, GoogleSearch, HttpOptions, Part, Tool
-from loguru import logger
-from PIL import Image
-from pyrogram.client import Client
-from pyrogram.types import Message
-
-from config import AIGC, DOWNLOAD_DIR, PREFIX
-from llm.utils import BOT_TIPS, beautify_llm_response, clean_prefix, clean_source_marks
-from messages.parser import parse_msg
-from messages.progress import modify_progress
-from messages.sender import send2tg
-from utils import number_to_emoji, rand_string
-
-HELP = f"""🌠**AIGC**
-`{PREFIX.GENIMG}` 后接提示词即可生成
-回复消息可继续对话重新修改生成结果
-
-⚙️模型配置:
-🏞生图模型: **{AIGC.IMG_MODEL}
-
-⚠️目前只支持生成图片
-"""
-
-
-async def aigc(client: Client, message: Message, contexts: list[dict], modality: str = "image", **kwargs):
-    r"""Get AIGC response.
-
-    contexts: [
-                {
-                "role": role,  # assistant or user
-                "content": [
-                        {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,encoding"}},
-                        {"type": "text", "text": "[username]: Bob\n[filename]: sample.txt\n[file content]:\nhello"}
-                    ]
-                }
-            ]
-
-    Args:
-        client (Client): The Pyrogram client.
-        message (Message): The trigger message object.
-        contexts (list[dict]): Parsed from chat history.
-        modality (str): response modality
-    """
-    # ruff: noqa: RET502, RET503
-    info = parse_msg(message)
-    api_keys = [x.strip() for x in AIGC.IMG_API_KEY.split(",") if x.strip()]
-    random.choice(api_keys)
-    response_modalities = ["TEXT", "IMAGE"] if modality == "image" else ["TEXT"]
-    tools = [Tool(google_search=GoogleSearch())] if modality == "text" else None
-    res = {}
-    try:
-        app = genai.Client(api_key=random.choice(api_keys), http_options=HttpOptions(base_url=AIGC.IMG_BASR_URL, async_client_args={"proxy": AIGC.IMG_PROXY}))
-        count_tokens = await app.aio.models.count_tokens(model=AIGC.IMG_MODEL, contents=info["text"])
-        num_token = count_tokens.total_tokens or 0
-        if num_token > AIGC.IMG_MAX_PROMPT_TOKEN:
-            await send2tg(client, message, texts=f"当前提示词过长: {num_token} Tokens\n提示词Token不得超过: {AIGC.IMG_MAX_PROMPT_TOKEN}", **kwargs)
-            return
-
-        msg = f"🌠**{AIGC.IMG_MODEL_NAME}**: 思考中...\n{clean_prefix(info['text'])}"
-        status_msg = (await send2tg(client, message, texts=msg, **kwargs))[0]
-        kwargs["progress"] = status_msg
-        gemini_contexts = [openai_context_to_gemini(context) for context in contexts]
-        gemini_logging(gemini_contexts)
-        response = await app.aio.models.generate_content(
-            model=AIGC.IMG_MODEL,
-            contents=gemini_contexts,
-            config=GenerateContentConfig(
-                response_modalities=response_modalities,
-                tools=tools,  # type: ignore
-            ),
-        )
-        res = parse_response(glom(response.model_dump(), "candidates.0"), model_name=AIGC.IMG_MODEL_NAME)
-    except Exception as e:
-        logger.error(e)
-        error = str(e)
-        if "res" in locals():
-            error += f"\n{res}"
-        if "response" in locals():
-            error += f"\n{response}"
-        return await modify_progress(text=error, force_update=True, **kwargs)
-    await send2tg(client, message, caption_above=True, **res, **kwargs)
-    await modify_progress(del_status=True, **kwargs)
-
-
-def parse_response(data: dict, model_name: str) -> dict:
-    parts = glom(data, "content.parts", default=[]) or []
-    gemini_logging(parts)
-    grounding_chunks = glom(data, "grounding_metadata.grounding_chunks", default=[]) or []
-    texts = ""
-    prefix = f"🌠**{model_name}**: ({BOT_TIPS})\n"
-    media = []
-    for item in parts:
-        if item.get("text") is not None:
-            texts += item["text"]
-        if item.get("inline_data") is not None:
-            image = Image.open(BytesIO(item["inline_data"]["data"]))
-            mime = item["inline_data"]["mime_type"]
-            ext = mime.split("/")[-1]
-            save_path = Path(DOWNLOAD_DIR) / f"{rand_string()}.{ext}"
-            image.save(save_path)
-            media.append({"photo": save_path})
-    for idx, grounding in enumerate(grounding_chunks):
-        title = glom(grounding, "web.title", default="Web")
-        url = glom(grounding, "web.uri", default="https://www.google.com")
-        texts += f"\n{number_to_emoji(idx + 1)}[{title}]({url})"
-    return {"texts": prefix + beautify_llm_response(texts, newline_level=2), "media": media}
-
-
-def openai_context_to_gemini(context: dict) -> ContentUnionDict:
-    r"""Convert OpenAI context to Gemini format.
-
-    Args:
-        context (dict): {
-                "role": role,  # assistant or user
-                "content": [
-                        {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,encoding"}},
-                        {"type": "text", "text": "[username]: Bob\n[filename]: sample.txt\n[file content]:\nhello"}
-                    ]
-                }
-
-    Returns:
-        dict: {
-            "role": role,  # model or user
-            "parts: [
-                {"inlineData": {"mimeType": "image/jpeg", "data": "base64-encoded string"}},
-                {"text": "hello"}
-            ]
-        }
-    """
-    parts: list[Part] = []
-    role = "model" if context["role"] == "assistant" else "user"
-    for item in context["content"]:
-        if item["type"] == "text":
-            parts.append(Part.from_text(text=clean_source_marks(item["text"])))
-        elif item["type"] == "image_url":
-            data = item["image_url"]["url"].split(";base64,")
-            mime = data[0].removeprefix("data:")
-            parts.append(Part.from_bytes(mime_type=mime, data=data[1]))
-
-    return {"role": role, "parts": parts}  # type: ignore
-
-
-def gemini_logging(contexts: list):
-    msg = ""
-    with contextlib.suppress(Exception):
-        for item in contexts:
-            role = item.get("role", "").upper() or "MODEL"
-
-            # Request
-            for part in item.get("parts", []):
-                if part.inline_data:
-                    msg += f"[{role}]: Blob_Data  "
-                if part.text:
-                    msg += f"[{role}]: {part.text}  "
-            # Response
-            if item.get("text", ""):
-                msg += f"[{role}]: {item['text']}  "
-            if item.get("inline_data", ""):
-                msg += f"[{role}]: Blob_Data  "
-
-        logger.debug(f"{msg!r}")
src/llm/gemini.py
@@ -0,0 +1,252 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+import contextlib
+import random
+from io import BytesIO
+from pathlib import Path
+
+from glom import glom
+from google import genai
+from google.genai.types import ContentUnionDict, GenerateContentConfig, GoogleSearch, HttpOptions, Part, Tool
+from loguru import logger
+from PIL import Image
+from pyrogram.client import Client
+from pyrogram.types import Message
+
+from config import AIGC, DOWNLOAD_DIR, PREFIX, TEXT_LENGTH
+from llm.utils import BOT_TIPS, beautify_llm_response, clean_prefix, clean_source_marks
+from messages.parser import parse_msg
+from messages.progress import modify_progress
+from messages.sender import send2tg
+from messages.utils import count_without_entities, smart_split
+from utils import number_to_emoji, rand_string
+
+HELP = f"""🌠**AIGC**
+`{PREFIX.GENIMG}` 后接提示词即可生成
+回复消息可继续对话重新修改生成结果
+
+⚙️模型配置:
+🏞生图模型: **{AIGC.IMG_MODEL}
+
+⚠️目前只支持生成图片
+"""
+
+
+async def gemini_response(client: Client, message: Message, gpt_contexts: list[dict], model: str = "", model_name: str = "", modality: str = "image", **kwargs):
+    r"""Get Gemini response.
+
+    gpt_contexts: [
+                {
+                "role": role,  # assistant or user
+                "content": [
+                        {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,encoding"}},
+                        {"type": "text", "text": "[username]: Bob\n[filename]: sample.txt\n[file content]:\nhello"}
+                    ]
+                }
+            ]
+
+    Args:
+        client (Client): The Pyrogram client.
+        message (Message): The trigger message object.
+        gpt_contexts (list[dict]): OpenAI context format parsed from chat history.
+        model (str): Model ID.
+        model_name (str): Friendly model name.
+        modality (str): Response modality ("image" or "text").
+    """
+    # ruff: noqa: RET502, RET503
+    info = parse_msg(message)
+    api_keys = [x.strip() for x in AIGC.IMG_API_KEY.split(",") if x.strip()]
+    response_modalities = ["TEXT", "IMAGE"] if modality == "image" else ["TEXT"]
+    tools = [Tool(google_search=GoogleSearch())] if modality == "text" else None
+    keep_marks = modality == "text"  # keep source marks for text response
+
+    try:
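+        # NOTE: this client (random key) is only used for token counting; generation below rotates keys per retry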
+        app = genai.Client(api_key=random.choice(api_keys), http_options=HttpOptions(base_url=AIGC.IMG_BASR_URL, async_client_args={"proxy": AIGC.IMG_PROXY}))
+        count_tokens = await app.aio.models.count_tokens(model=model, contents=info["text"])
+        num_token = count_tokens.total_tokens or 0
+        if modality == "image" and num_token > AIGC.IMG_MAX_PROMPT_TOKEN:
+            await send2tg(client, message, texts=f"生成{modality.upper()}时提示词Token不得超过: {AIGC.IMG_MAX_PROMPT_TOKEN}\n当前提示词: {num_token} Tokens", **kwargs)
+            return
+        msg = f"🌠**{model_name}**: 思考中...\n{clean_prefix(info['text'])}"
+        status_msg = (await send2tg(client, message, texts=msg, **kwargs))[0]
+        kwargs["progress"] = status_msg
+        contexts = [openai_context_to_gemini(context, keep_marks=keep_marks) for context in gpt_contexts]
+        gemini_logging(contexts)
+        if modality == "image":
+            return await gemini_nonstream(client, message, contexts, model, model_name, response_modalities, tools, **kwargs)
+        return await gemini_stream(client, message, contexts, model, model_name, response_modalities, tools, **kwargs)
+    except Exception as e:
+        logger.error(e)
+
+
+def openai_context_to_gemini(context: dict, *, keep_marks: bool = True) -> ContentUnionDict:
+    r"""Convert OpenAI context to Gemini format.
+
+    Args:
+        context (dict): {
+                "role": role,  # assistant or user
+                "content": [
+                        {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,encoding"}},
+                        {"type": "text", "text": "[username]: Bob\n[filename]: sample.txt\n[file content]:\nhello"}
+                    ]
+                }
+
+    Returns:
+        dict: {
+            "role": role,  # model or user
+            "parts: [
+                {"inlineData": {"mimeType": "image/jpeg", "data": "base64-encoded string"}},
+                {"text": "hello"}
+            ]
+        }
+    """
+    parts: list[Part] = []
+    role = "model" if context["role"] == "assistant" else "user"
+    for item in context["content"]:
+        if item["type"] == "text":
+            if keep_marks:
+                parts.append(Part.from_text(text=item["text"]))
+            else:
+                parts.append(Part.from_text(text=clean_source_marks(item["text"])))
+        elif item["type"] == "image_url":
+            data = item["image_url"]["url"].split(";base64,")
+            mime = data[0].removeprefix("data:")
+            parts.append(Part.from_bytes(mime_type=mime, data=data[1]))
+
+    return {"role": role, "parts": parts}  # type: ignore
+
+
+def gemini_logging(contexts: list):
+    msg = ""
+    with contextlib.suppress(Exception):
+        for item in contexts:
+            role = item.get("role", "").upper() or "MODEL"
+
+            # Request
+            for part in item.get("parts", []):
+                if part.inline_data:
+                    msg += f"[{role}]: Blob_Data  "
+                if part.text:
+                    msg += f"[{role}]: {part.text}  "
+            # Response
+            if item.get("text", ""):
+                msg += f"[{role}]: {item['text']}  "
+            if item.get("inline_data", ""):
+                msg += f"[{role}]: Blob_Data  "
+
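+        # msg ends up like "[USER]: draw a red cat  [MODEL]: Blob_Data  " (illustrative)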
+        logger.debug(f"{msg!r}")
+
+
+async def gemini_nonstream(
+    client: Client,
+    message: Message,
+    contexts: list[ContentUnionDict],
+    model: str,
+    model_name: str,
+    response_modalities: list[str],
+    tools: list | None = None,
+    retry: int = 0,
+    **kwargs,
+):
+    try:
+        api_keys = [x.strip() for x in AIGC.IMG_API_KEY.split(",") if x.strip()]
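+        # each retry moves to the next API key; give up once every key has been tried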
+        if retry > len(api_keys) - 1:
+            return
+        app = genai.Client(api_key=api_keys[retry], http_options=HttpOptions(base_url=AIGC.IMG_BASR_URL, async_client_args={"proxy": AIGC.IMG_PROXY}))
+        response = await app.aio.models.generate_content(
+            model=model,
+            contents=contexts,
+            config=GenerateContentConfig(
+                response_modalities=response_modalities,
+                tools=tools,
+            ),
+        )
+        prefix = f"🌠**{model_name}**: ({BOT_TIPS})\n"
+        res = parse_response(response.model_dump(), prefix=prefix)
+        await send2tg(client, message, caption_above=True, **res, **kwargs)
+        await modify_progress(del_status=True, **kwargs)
+    except Exception as e:
+        logger.error(e)
+        error = str(e)
+        if "res" in locals():
+            error += f"\n{res}"
+        if "response" in locals():
+            error += f"\n{response}"
+        await modify_progress(text=error, force_update=True, **kwargs)
+        return await gemini_nonstream(client, message, contexts, model, model_name, response_modalities, tools, retry + 1, **kwargs)  # type: ignore
+
+
+def parse_response(data: dict, prefix: str = "") -> dict:
+    logger.trace(data)
+    parts = glom(data, "candidates.0.content.parts", default=[]) or []
+    gemini_logging(parts)
+    grounding_chunks = glom(data, "candidates.0.grounding_metadata.grounding_chunks", default=[]) or []
+    texts = ""
+    media = []
+    for item in parts:
+        if item.get("text") is not None:
+            texts += item["text"]
+        if item.get("inline_data") is not None:
+            image = Image.open(BytesIO(item["inline_data"]["data"]))
+            mime = item["inline_data"]["mime_type"]
+            ext = mime.split("/")[-1]
+            save_path = Path(DOWNLOAD_DIR) / f"{rand_string()}.{ext}"
+            image.save(save_path)
+            media.append({"photo": save_path})
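+    # append numbered source links from grounding metadata, capped at 10 entries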
+    for idx, grounding in enumerate(grounding_chunks):
+        if idx > 9:
+            break
+        title = glom(grounding, "web.title", default="Web")
+        url = glom(grounding, "web.uri", default="https://www.google.com")
+        texts += f"\n{number_to_emoji(idx + 1)}[{title}]({url})"
+    return {"texts": prefix + beautify_llm_response(texts, newline_level=2), "media": media}
+
+
+async def gemini_stream(
+    client: Client,
+    message: Message,
+    contexts: list[ContentUnionDict],
+    model: str,
+    model_name: str,
+    response_modalities: list[str],
+    tools: list | None = None,
+    retry: int = 0,
+    **kwargs,
+):
+    prefix = f"🌠**{model_name}**: ({BOT_TIPS})\n"
+    answers = prefix
+    try:
+        status = kwargs.get("progress")
+        api_keys = [x.strip() for x in AIGC.IMG_API_KEY.split(",") if x.strip()]
+        if retry > len(api_keys) - 1:
+            return
+        app = genai.Client(api_key=api_keys[retry], http_options=HttpOptions(base_url=AIGC.IMG_BASR_URL, async_client_args={"proxy": AIGC.IMG_PROXY}))
+        async for chunk in await app.aio.models.generate_content_stream(
+            model=model,
+            contents=contexts,
+            config=GenerateContentConfig(response_modalities=response_modalities, tools=tools),
+        ):
+            resp = parse_response(chunk.model_dump())
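+            # stream mode only uses the text of each chunk; any media in resp is ignored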
+            answer = resp.get("texts", "")
+            answers += answer
+            answers = beautify_llm_response(answers)
+            if await count_without_entities(answers) <= TEXT_LENGTH:
+                if len(answers.removeprefix(prefix)) > 3:  # only start updating once the answer has content
+                    await modify_progress(message=status, text=answers, detail_progress=True)
+            else:  # the answer is too long; split it across multiple messages
+                parts = await smart_split(answers)
+                await modify_progress(message=status, text=parts[0], force_update=True)  # force-send the first part
+                answers = parts[-1]  # keep only the last part
+                status = await client.send_message(message.chat.id, answers)  # continue in a new message
+
+        # all chunks are processed
+        await modify_progress(message=status, text=beautify_llm_response(answers), force_update=True)
+
+    except Exception as e:
+        logger.error(e)
+        error = str(e)
+        if "resp" in locals():
+            error += f"\n{resp}"
+        await modify_progress(text=error, force_update=True, **kwargs)
+        return await gemini_stream(client, message, contexts, model, model_name, response_modalities, tools, retry + 1, **kwargs)  # type: ignore
src/llm/gpt.py
@@ -6,11 +6,9 @@ from pyrogram.client import Client
 from pyrogram.types import Message
 
 from config import AIGC, GPT, PREFIX, TEXT_LENGTH, cache
-from llm.aigc import HELP as AIGC_HELP
-from llm.aigc import aigc
-
-# from llm.aigc import HELP as AIGC_HELP
 from llm.contexts import get_conversation_contexts, get_conversations
+from llm.gemini import HELP as AIGC_HELP
+from llm.gemini import gemini_response
 from llm.models import get_context_type, get_gpt_config, parse_force_model
 from llm.response import send_to_gpt
 from llm.response_stream import send_to_gpt_stream
@@ -64,9 +62,9 @@ def is_gpt_conversation(message: Message) -> bool:
         GPT.GROK_MODEL_NAME,
         GPT.TEXT_MODEL_NAME,
         GPT.IMAGE_MODEL_NAME,
+        AIGC.IMG_MODEL_NAME,
     ]
-    aigc_names = [AIGC.IMG_MODEL_NAME]
-    return startswith_prefix(reply_info["text"], prefix=[f"🤖{x}".lower() for x in model_names] + [f"🌠{x}".lower() for x in aigc_names])
+    return startswith_prefix(reply_info["text"], prefix=[f"🤖{x}".lower() for x in model_names] + [f"🌠{x}".lower() for x in model_names])
 
 
 async def gpt_response(client: Client, message: Message, *, gpt_stream: bool = GPT.STREAM_MODE, **kwargs):
@@ -104,12 +102,14 @@ async def gpt_response(client: Client, message: Message, *, gpt_stream: bool = G
     conversations = get_conversations(message)
     context_type = get_context_type(conversations)
     contexts = await get_conversation_contexts(client, conversations)
-    if equal_prefix(info["text"], prefix=[PREFIX.GENIMG]) or modality != "text":
-        return await aigc(client, message, contexts, modality, **kwargs)
     config = get_gpt_config(context_type["type"], contexts, force_model)
     if not config["client"]["api_key"]:
         logger.error(f"⚠️**{config['friendly_name']}** 未配置API Key")
         return await send2tg(client, message, texts=f"⚠️**{config['friendly_name']}** 未配置API Key, 请尝试其他命令\n\n{HELP}", **kwargs)
+    if "gemini" in config["completions"]["model"].lower():
+        model_name = AIGC.IMG_MODEL_NAME if startswith_prefix(info["text"], prefix=[PREFIX.GENIMG]) else config["friendly_name"]
+        return await gemini_response(client, message, contexts, config["completions"]["model"], model_name, modality, **kwargs)
+
     msg = f"🤖**{config['friendly_name']}**: 思考中...\n{clean_prefix(info['text'])}"
     status_msg = (await send2tg(client, message, texts=msg, **kwargs))[0]
     kwargs["progress"] = status_msg