Commit cc41967
src/llm/aigc.py
@@ -1,171 +0,0 @@
-#!/usr/bin/env python
-# -*- coding: utf-8 -*-
-
-import contextlib
-import random
-from io import BytesIO
-from pathlib import Path
-
-from glom import glom
-from google import genai
-from google.genai.types import ContentUnionDict, GenerateContentConfig, GoogleSearch, HttpOptions, Part, Tool
-from loguru import logger
-from PIL import Image
-from pyrogram.client import Client
-from pyrogram.types import Message
-
-from config import AIGC, DOWNLOAD_DIR, PREFIX
-from llm.utils import BOT_TIPS, beautify_llm_response, clean_prefix, clean_source_marks
-from messages.parser import parse_msg
-from messages.progress import modify_progress
-from messages.sender import send2tg
-from utils import number_to_emoji, rand_string
-
-HELP = f"""🌠**AIGC**
-`{PREFIX.GENIMG}` 后接提示词即可生成
-回复消息可继续对话重新修改生成结果
-
-⚙️模型配置:
-🏞生图模型: **{AIGC.IMG_MODEL}
-
-⚠️目前只支持生成图片
-"""
-
-
-async def aigc(client: Client, message: Message, contexts: list[dict], modality: str = "image", **kwargs):
- r"""Get AIGC response.
-
- contexts: [
- {
- "role": role, # assistant or user
- "content": [
- {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,encoding"}},
- {"type": "text", "text": "[username]: Bob\n[filename]: sample.txt\n[file content]:\nhello"}
- ]
- }
- ]
-
- Args:
- client (Client): The Pyrogram client.
- message (Message): The trigger message object.
- contexts (list[dict]): Parsed from chat history.
- modality (str): response modality
- """
- # ruff: noqa: RET502, RET503
- info = parse_msg(message)
- api_keys = [x.strip() for x in AIGC.IMG_API_KEY.split(",") if x.strip()]
- random.choice(api_keys)
- response_modalities = ["TEXT", "IMAGE"] if modality == "image" else ["TEXT"]
- tools = [Tool(google_search=GoogleSearch())] if modality == "text" else None
- res = {}
- try:
- app = genai.Client(api_key=random.choice(api_keys), http_options=HttpOptions(base_url=AIGC.IMG_BASR_URL, async_client_args={"proxy": AIGC.IMG_PROXY}))
- count_tokens = await app.aio.models.count_tokens(model=AIGC.IMG_MODEL, contents=info["text"])
- num_token = count_tokens.total_tokens or 0
- if num_token > AIGC.IMG_MAX_PROMPT_TOKEN:
- await send2tg(client, message, texts=f"当前提示词过长: {num_token} Tokens\n提示词Token不得超过: {AIGC.IMG_MAX_PROMPT_TOKEN}", **kwargs)
- return
-
- msg = f"🌠**{AIGC.IMG_MODEL_NAME}**: 思考中...\n{clean_prefix(info['text'])}"
- status_msg = (await send2tg(client, message, texts=msg, **kwargs))[0]
- kwargs["progress"] = status_msg
- gemini_contexts = [openai_context_to_gemini(context) for context in contexts]
- gemini_logging(gemini_contexts)
- response = await app.aio.models.generate_content(
- model=AIGC.IMG_MODEL,
- contents=gemini_contexts,
- config=GenerateContentConfig(
- response_modalities=response_modalities,
- tools=tools, # type: ignore
- ),
- )
- res = parse_response(glom(response.model_dump(), "candidates.0"), model_name=AIGC.IMG_MODEL_NAME)
- except Exception as e:
- logger.error(e)
- error = str(e)
- if "res" in locals():
- error += f"\n{res}"
- if "response" in locals():
- error += f"\n{response}"
- return await modify_progress(text=error, force_update=True, **kwargs)
- await send2tg(client, message, caption_above=True, **res, **kwargs)
- await modify_progress(del_status=True, **kwargs)
-
-
-def parse_response(data: dict, model_name: str) -> dict:
- parts = glom(data, "content.parts", default=[]) or []
- gemini_logging(parts)
- grounding_chunks = glom(data, "grounding_metadata.grounding_chunks", default=[]) or []
- texts = ""
- prefix = f"🌠**{model_name}**: ({BOT_TIPS})\n"
- media = []
- for item in parts:
- if item.get("text") is not None:
- texts += item["text"]
- if item.get("inline_data") is not None:
- image = Image.open(BytesIO(item["inline_data"]["data"]))
- mime = item["inline_data"]["mime_type"]
- ext = mime.split("/")[-1]
- save_path = Path(DOWNLOAD_DIR) / f"{rand_string()}.{ext}"
- image.save(save_path)
- media.append({"photo": save_path})
- for idx, grounding in enumerate(grounding_chunks):
- title = glom(grounding, "web.title", default="Web")
- url = glom(grounding, "web.uri", default="https://www.google.com")
- texts += f"\n{number_to_emoji(idx + 1)}[{title}]({url})"
- return {"texts": prefix + beautify_llm_response(texts, newline_level=2), "media": media}
-
-
-def openai_context_to_gemini(context: dict) -> ContentUnionDict:
- r"""Convert OpenAI context to Gemini format.
-
- Args:
- context (dict): {
- "role": role, # assistant or user
- "content": [
- {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,encoding"}},
- {"type": "text", "text": "[username]: Bob\n[filename]: sample.txt\n[file content]:\nhello"}
- ]
- }
-
- Returns:
- dict: {
- "role": role, # model or user
- "parts: [
- {"inlineData": {"mimeType": "image/jpeg", "data": "base64-encoded string"}},
- {"text": "hello"}
- ]
- }
- """
- parts: list[Part] = []
- role = "model" if context["role"] == "assistant" else "user"
- for item in context["content"]:
- if item["type"] == "text":
- parts.append(Part.from_text(text=clean_source_marks(item["text"])))
- elif item["type"] == "image_url":
- data = item["image_url"]["url"].split(";base64,")
- mime = data[0].removeprefix("data:")
- parts.append(Part.from_bytes(mime_type=mime, data=data[1]))
-
- return {"role": role, "parts": parts} # type: ignore
-
-
-def gemini_logging(contexts: list):
- msg = ""
- with contextlib.suppress(Exception):
- for item in contexts:
- role = item.get("role", "").upper() or "MODEL"
-
- # Request
- for part in item.get("parts", []):
- if part.inline_data:
- msg += f"[{role}]: Blob_Data "
- if part.text:
- msg += f"[{role}]: {part.text} "
- # Response
- if item.get("text", ""):
- msg += f"[{role}]: {item['text']} "
- if item.get("inline_data", ""):
- msg += f"[{role}]: Blob_Data "
-
- logger.debug(f"{msg!r}")
src/llm/gemini.py
@@ -0,0 +1,271 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+import contextlib
+import random
+from io import BytesIO
+from pathlib import Path
+
+from glom import glom
+from google import genai
+from google.genai.types import ContentUnionDict, GenerateContentConfig, GoogleSearch, HttpOptions, Part, Tool
+from loguru import logger
+from PIL import Image
+from pyrogram.client import Client
+from pyrogram.types import Message
+
+from config import AIGC, DOWNLOAD_DIR, PREFIX, TEXT_LENGTH
+from llm.utils import BOT_TIPS, beautify_llm_response, clean_prefix, clean_source_marks
+from messages.parser import parse_msg
+from messages.progress import modify_progress
+from messages.sender import send2tg
+from messages.utils import count_without_entities, smart_split
+from utils import number_to_emoji, rand_string
+
+HELP = f"""🌠**AIGC**
+`{PREFIX.GENIMG}` 后接提示词即可生成
+回复消息可继续对话重新修改生成结果
+
+⚙️模型配置:
+🏞生图模型: **{AIGC.IMG_MODEL}**
+
+⚠️目前只支持生成图片
+"""
+
+
+async def gemini_response(client: Client, message: Message, gpt_contexts: list[dict], model: str = "", model_name: str = "", modality: str = "image", **kwargs):
+ r"""Get Gemini response.
+
+ gpt_contexts: [
+ {
+ "role": role, # assistant or user
+ "content": [
+ {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,encoding"}},
+ {"type": "text", "text": "[username]: Bob\n[filename]: sample.txt\n[file content]:\nhello"}
+ ]
+ }
+ ]
+
+ Args:
+ client (Client): The Pyrogram client.
+ message (Message): The trigger message object.
+ gpt_contexts (list[dict]): OpenAI context format parsed from chat history.
+ model (str): Model ID.
+ model_name (str): Friendly model name.
+ modality (str): Response modality, "image" or "text".
+ """
+ # ruff: noqa: RET502, RET503
+ info = parse_msg(message)
+ api_keys = [x.strip() for x in AIGC.IMG_API_KEY.split(",") if x.strip()]
+ response_modalities = ["TEXT", "IMAGE"] if modality == "image" else ["TEXT"]
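+ # Google Search grounding is only attached for text responses; image
+ # generation runs without tools.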
+ tools = [Tool(google_search=GoogleSearch())] if modality == "text" else None
+ keep_marks = modality == "text" # keep source marks for text response
+
+ try:
+ app = genai.Client(api_key=random.choice(api_keys), http_options=HttpOptions(base_url=AIGC.IMG_BASR_URL, async_client_args={"proxy": AIGC.IMG_PROXY}))
+ count_tokens = await app.aio.models.count_tokens(model=model, contents=info["text"])
+ num_token = count_tokens.total_tokens or 0
+ if modality == "image" and num_token > AIGC.IMG_MAX_PROMPT_TOKEN:
+ await send2tg(client, message, texts=f"生成{modality.upper()}时提示词Token不得超过: {AIGC.IMG_MAX_PROMPT_TOKEN}\n当前提示词: {num_token} Tokens", **kwargs)
+ return
+ msg = f"🌠**{model_name}**: 思考中...\n{clean_prefix(info['text'])}"
+ status_msg = (await send2tg(client, message, texts=msg, **kwargs))[0]
+ kwargs["progress"] = status_msg
+ contexts = [openai_context_to_gemini(context, keep_marks=keep_marks) for context in gpt_contexts]
+ gemini_logging(contexts)
+ if modality == "image":
+ return await gemini_nonstream(client, message, contexts, model, model_name, response_modalities, tools, **kwargs)
+ return await gemini_stream(client, message, contexts, model, model_name, response_modalities, tools, **kwargs)
+ except Exception as e:
+ logger.error(e)
+
+
+def openai_context_to_gemini(context: dict, *, keep_marks: bool = True) -> ContentUnionDict:
+ r"""Convert OpenAI context to Gemini format.
+
+ Args:
+ context (dict): {
+ "role": role, # assistant or user
+ "content": [
+ {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,encoding"}},
+ {"type": "text", "text": "[username]: Bob\n[filename]: sample.txt\n[file content]:\nhello"}
+ ]
+ }
+
+ Returns:
+ dict: {
+ "role": role, # model or user
+ "parts": [
+ {"inlineData": {"mimeType": "image/jpeg", "data": "base64-encoded string"}},
+ {"text": "hello"}
+ ]
+ }
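+
+ Example (a minimal sketch; only the role mapping is shown):
+ >>> ctx = {"role": "assistant", "content": [{"type": "text", "text": "hi"}]}
+ >>> openai_context_to_gemini(ctx)["role"]
+ 'model'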
+ """
+ parts: list[Part] = []
+ role = "model" if context["role"] == "assistant" else "user"
+ for item in context["content"]:
+ if item["type"] == "text":
+ if keep_marks:
+ parts.append(Part.from_text(text=item["text"]))
+ else:
+ parts.append(Part.from_text(text=clean_source_marks(item["text"])))
+ elif item["type"] == "image_url":
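+ # "data:image/jpeg;base64,<payload>" -> mime "image/jpeg", data "<payload>"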
+ data = item["image_url"]["url"].split(";base64,")
+ mime = data[0].removeprefix("data:")
+ parts.append(Part.from_bytes(mime_type=mime, data=data[1]))
+
+ return {"role": role, "parts": parts} # type: ignore
+
+
+def gemini_logging(contexts: list):
+ msg = ""
+ with contextlib.suppress(Exception):
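+ # Best-effort logging: requests pass content dicts holding Part objects,
+ # responses pass plain part dicts from model_dump, so any access error is swallowed.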
+ for item in contexts:
+ role = item.get("role", "").upper() or "MODEL"
+
+ # Request
+ for part in item.get("parts", []):
+ if part.inline_data:
+ msg += f"[{role}]: Blob_Data "
+ if part.text:
+ msg += f"[{role}]: {part.text} "
+ # Response
+ if item.get("text", ""):
+ msg += f"[{role}]: {item['text']} "
+ if item.get("inline_data", ""):
+ msg += f"[{role}]: Blob_Data "
+
+ logger.debug(f"{msg!r}")
+
+
+async def gemini_nonstream(
+ client: Client,
+ message: Message,
+ contexts: list[ContentUnionDict],
+ model: str,
+ model_name: str,
+ response_modalities: list[str],
+ tools: list | None = None,
+ retry: int = 0,
+ **kwargs,
+):
+ try:
+ api_keys = [x.strip() for x in AIGC.IMG_API_KEY.split(",") if x.strip()]
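+ # Rotate through the configured API keys: attempt N uses api_keys[N],
+ # giving up once every key has been tried.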
+ if retry > len(api_keys) - 1:
+ return
+ app = genai.Client(api_key=api_keys[retry], http_options=HttpOptions(base_url=AIGC.IMG_BASR_URL, async_client_args={"proxy": AIGC.IMG_PROXY}))
+ response = await app.aio.models.generate_content(
+ model=model,
+ contents=contexts,
+ config=GenerateContentConfig(
+ response_modalities=response_modalities,
+ tools=tools,
+ ),
+ )
+ prefix = f"🌠**{model_name}**: ({BOT_TIPS})\n"
+ res = parse_response(response.model_dump(), prefix=prefix)
+ await send2tg(client, message, caption_above=True, **res, **kwargs)
+ await modify_progress(del_status=True, **kwargs)
+ except Exception as e:
+ logger.error(e)
+ error = str(e)
+ if "res" in locals():
+ error += f"\n{res}"
+ if "response" in locals():
+ error += f"\n{response}"
+ await modify_progress(text=error, force_update=True, **kwargs)
+ return await gemini_nonstream(client, message, contexts, model, model_name, response_modalities, tools, retry + 1, **kwargs) # type: ignore
+
+
+def parse_response(data: dict, prefix: str = "") -> dict:
+ logger.trace(data)
+ parts = glom(data, "candidates.0.content.parts", default=[]) or []
+ gemini_logging(parts)
+ grounding_chunks = glom(data, "candidates.0.grounding_metadata.grounding_chunks", default=[]) or []
+ texts = ""
+ media = []
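+ # Collect text from every part and persist inline images to DOWNLOAD_DIR
+ # so send2tg can deliver them as photos.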
+ for item in parts:
+ if item.get("text") is not None:
+ texts += item["text"]
+ if item.get("inline_data") is not None:
+ image = Image.open(BytesIO(item["inline_data"]["data"]))
+ mime = item["inline_data"]["mime_type"]
+ ext = mime.split("/")[-1]
+ save_path = Path(DOWNLOAD_DIR) / f"{rand_string()}.{ext}"
+ image.save(save_path)
+ media.append({"photo": save_path})
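+ # Append up to 10 grounding sources as emoji-numbered markdown links
+ # (number_to_emoji presumably maps 1 -> 1️⃣, etc.).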
+ for idx, grounding in enumerate(grounding_chunks):
+ if idx > 9:
+ break
+ title = glom(grounding, "web.title", default="Web")
+ url = glom(grounding, "web.uri", default="https://www.google.com")
+ texts += f"\n{number_to_emoji(idx + 1)}[{title}]({url})"
+ return {"texts": prefix + beautify_llm_response(texts, newline_level=2), "media": media}
+
+
+async def gemini_stream(
+ client: Client,
+ message: Message,
+ contexts: list[ContentUnionDict],
+ model: str,
+ model_name: str,
+ response_modalities: list[str],
+ tools: list | None = None,
+ retry: int = 0,
+ **kwargs,
+):
+ prefix = f"🌠**{model_name}**: ({BOT_TIPS})\n"
+ answers = prefix
+ try:
+ status = kwargs.get("progress")
+ api_keys = [x.strip() for x in AIGC.IMG_API_KEY.split(",") if x.strip()]
+ if retry > len(api_keys) - 1:
+ return
+ app = genai.Client(api_key=api_keys[retry], http_options=HttpOptions(base_url=AIGC.IMG_BASR_URL, async_client_args={"proxy": AIGC.IMG_PROXY}))
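+ # Accumulate streamed text into `answers`, editing the status message in
+ # place; once it outgrows TEXT_LENGTH, flush the head as its own message
+ # and keep streaming into a fresh one.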
+ async for chunk in await app.aio.models.generate_content_stream(
+ model=model,
+ contents=contexts,
+ config=GenerateContentConfig(response_modalities=response_modalities, tools=tools),
+ ):
+ resp = parse_response(chunk.model_dump())
+ answer = resp.get("texts", "")
+ answers += answer
+ answers = beautify_llm_response(answers)
+ if await count_without_entities(answers) <= TEXT_LENGTH:
+ if len(answers.removeprefix(prefix)) > 3: # start editing once a few characters have arrived
+ await modify_progress(message=status, text=answers, detail_progress=True)
+ else: # answers is too long, split it into multiple messages
+ parts = await smart_split(answers)
+ await modify_progress(message=status, text=parts[0], force_update=True) # force send the first part
+ answers = parts[-1] # keep the last part
+ status = await client.send_message(message.chat.id, answers) # the new message
+
+ # all chunks are processed
+ await modify_progress(message=status, text=beautify_llm_response(answers), force_update=True)
+
+ except Exception as e:
+ logger.error(e)
+ error = str(e)
+ if "resp" in locals():
+ error += f"\n{resp}"
+ await modify_progress(text=error, force_update=True, **kwargs)
+ return await gemini_stream(client, message, contexts, model, model_name, response_modalities, tools, retry + 1, **kwargs) # type: ignore
src/llm/gpt.py
@@ -6,11 +6,9 @@ from pyrogram.client import Client
from pyrogram.types import Message
from config import AIGC, GPT, PREFIX, TEXT_LENGTH, cache
-from llm.aigc import HELP as AIGC_HELP
-from llm.aigc import aigc
-
-# from llm.aigc import HELP as AIGC_HELP
from llm.contexts import get_conversation_contexts, get_conversations
+from llm.gemini import HELP as AIGC_HELP
+from llm.gemini import gemini_response
from llm.models import get_context_type, get_gpt_config, parse_force_model
from llm.response import send_to_gpt
from llm.response_stream import send_to_gpt_stream
@@ -64,9 +62,9 @@ def is_gpt_conversation(message: Message) -> bool:
GPT.GROK_MODEL_NAME,
GPT.TEXT_MODEL_NAME,
GPT.IMAGE_MODEL_NAME,
+ AIGC.IMG_MODEL_NAME,
]
- aigc_names = [AIGC.IMG_MODEL_NAME]
- return startswith_prefix(reply_info["text"], prefix=[f"🤖{x}".lower() for x in model_names] + [f"🌠{x}".lower() for x in aigc_names])
+ return startswith_prefix(reply_info["text"], prefix=[f"🤖{x}".lower() for x in model_names] + [f"🌠{x}".lower() for x in model_names])
async def gpt_response(client: Client, message: Message, *, gpt_stream: bool = GPT.STREAM_MODE, **kwargs):
@@ -104,12 +102,14 @@ async def gpt_response(client: Client, message: Message, *, gpt_stream: bool = G
conversations = get_conversations(message)
context_type = get_context_type(conversations)
contexts = await get_conversation_contexts(client, conversations)
- if equal_prefix(info["text"], prefix=[PREFIX.GENIMG]) or modality != "text":
- return await aigc(client, message, contexts, modality, **kwargs)
config = get_gpt_config(context_type["type"], contexts, force_model)
if not config["client"]["api_key"]:
logger.error(f"⚠️**{config['friendly_name']}** 未配置API Key")
return await send2tg(client, message, texts=f"⚠️**{config['friendly_name']}** 未配置API Key, 请尝试其他命令\n\n{HELP}", **kwargs)
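+ # Route Gemini-family models to the native google-genai client; when the
+ # message starts with the GENIMG prefix, show the image-model display name.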
+ if "gemini" in config["completions"]["model"].lower():
+ model_name = AIGC.IMG_MODEL_NAME if startswith_prefix(info["text"], prefix=[PREFIX.GENIMG]) else config["friendly_name"]
+ return await gemini_response(client, message, contexts, config["completions"]["model"], model_name, modality, **kwargs)
+
msg = f"🤖**{config['friendly_name']}**: 思考中...\n{clean_prefix(info['text'])}"
status_msg = (await send2tg(client, message, texts=msg, **kwargs))[0]
kwargs["progress"] = status_msg