Commit d519f14
src/llm/gemini.py
@@ -2,7 +2,6 @@
# -*- coding: utf-8 -*-
import contextlib
-import random
from io import BytesIO
from pathlib import Path
@@ -14,7 +13,7 @@ from PIL import Image
from pyrogram.client import Client
from pyrogram.types import Message
-from config import AIGC, DOWNLOAD_DIR, PREFIX, TEXT_LENGTH
+from config import DOWNLOAD_DIR, GEMINI, PREFIX, TEXT_LENGTH
from llm.utils import BOT_TIPS, beautify_llm_response, clean_prefix, clean_source_marks
from messages.parser import parse_msg
from messages.progress import modify_progress
@@ -22,18 +21,18 @@ from messages.sender import send2tg
from messages.utils import count_without_entities, smart_split
from utils import number_to_emoji, rand_string
-HELP = f"""🌠**AIGC**
+HELP = f"""🌠**AI生图**
`{PREFIX.GENIMG}` 后接提示词即可生成
回复消息可继续对话重新修改生成结果
⚙️模型配置:
-🏞生图模型: **{AIGC.IMG_MODEL}
+🌠生图模型: **{GEMINI.IMG_MODEL}**
⚠️目前只支持生成图片
"""
-async def gemini_response(client: Client, message: Message, gpt_contexts: list[dict], model: str = "", model_name: str = "", modality: str = "image", **kwargs):
+async def gemini_response(client: Client, message: Message, gpt_contexts: list[dict], modality: str = "image", **kwargs):
r"""Get Gemini response.
gpt_contexts: [
@@ -54,28 +53,31 @@ async def gemini_response(client: Client, message: Message, gpt_contexts: list[d
model_name (str): friendly model name
modality (str): response modality
"""
- # ruff: noqa: RET502, RET503
info = parse_msg(message)
- api_keys = [x.strip() for x in AIGC.IMG_API_KEY.split(",") if x.strip()]
+ model = GEMINI.TEXT_MODEL if modality == "text" else GEMINI.IMG_MODEL
+ model_name = GEMINI.TEXT_MODEL_NAME if modality == "text" else GEMINI.IMG_MODEL_NAME
response_modalities = ["TEXT", "IMAGE"] if modality == "image" else ["TEXT"]
tools = [Tool(google_search=GoogleSearch())] if modality == "text" else None
keep_marks = modality == "text" # keep source marks for text response
-
try:
- app = genai.Client(api_key=random.choice(api_keys), http_options=HttpOptions(base_url=AIGC.IMG_BASR_URL, async_client_args={"proxy": AIGC.IMG_PROXY}))
- count_tokens = await app.aio.models.count_tokens(model=model, contents=info["text"])
- num_token = count_tokens.total_tokens or 0
- if modality == "image" and num_token > AIGC.IMG_MAX_PROMPT_TOKEN:
- await send2tg(client, message, texts=f"生成{modality.upper()}时提示词Token不得超过: {AIGC.IMG_MAX_PROMPT_TOKEN}\n当前提示词: {num_token} Tokens", **kwargs)
- return
- msg = f"🌠**{model_name}**: 思考中...\n{clean_prefix(info['text'])}"
+ msg = f"🤖**{model_name}**: 思考中...\n{clean_prefix(info['text'])}"
status_msg = (await send2tg(client, message, texts=msg, **kwargs))[0]
kwargs["progress"] = status_msg
contexts = [openai_context_to_gemini(context, keep_marks=keep_marks) for context in gpt_contexts]
gemini_logging(contexts)
+ params = {}
+ params |= {"model": model, "contents": contexts}
+ genconfig = {}
+ genconfig |= {"response_modalities": response_modalities}
+ if tools:
+ genconfig |= {"tools": tools}
+ if GEMINI.PREFER_LANG and modality == "text":
+ genconfig |= {"system_instruction": f"请优先使用{GEMINI.PREFER_LANG}回复"}
+ params |= {"config": GenerateContentConfig(**genconfig)}
+
if modality == "image":
- return await gemini_nonstream(client, message, contexts, model, model_name, response_modalities, tools, **kwargs)
- return await gemini_stream(client, message, contexts, model, model_name, response_modalities, tools, **kwargs)
+ return await gemini_nonstream(client, message, model_name, params, **kwargs)
+ return await gemini_stream(client, message, model_name, params, **kwargs)
except Exception as e:
logger.error(e)
@@ -141,28 +143,19 @@ def gemini_logging(contexts: list):
async def gemini_nonstream(
client: Client,
message: Message,
- contexts: list[ContentUnionDict],
- model: str,
model_name: str,
- response_modalities: list[str],
- tools: list | None = None,
+ params: dict,
retry: int = 0,
**kwargs,
):
try:
- api_keys = [x.strip() for x in AIGC.IMG_API_KEY.split(",") if x.strip()]
+ api_keys = [x.strip() for x in GEMINI.API_KEYS.split(",") if x.strip()]
if retry > len(api_keys) - 1:
- return
- app = genai.Client(api_key=api_keys[retry], http_options=HttpOptions(base_url=AIGC.IMG_BASR_URL, async_client_args={"proxy": AIGC.IMG_PROXY}))
- response = await app.aio.models.generate_content(
- model=model,
- contents=contexts,
- config=GenerateContentConfig(
- response_modalities=response_modalities,
- tools=tools,
- ),
- )
- prefix = f"🌠**{model_name}**: ({BOT_TIPS})\n"
+ return None
+ app = genai.Client(api_key=api_keys[retry], http_options=HttpOptions(base_url=GEMINI.BASR_URL, async_client_args={"proxy": GEMINI.PROXY}))
+
+ response = await app.aio.models.generate_content(**params)
+ prefix = f"🤖**{model_name}**: ({BOT_TIPS})\n"
res = parse_response(response.model_dump(), prefix=prefix)
await send2tg(client, message, caption_above=True, **res, **kwargs)
await modify_progress(del_status=True, **kwargs)
@@ -174,7 +167,7 @@ async def gemini_nonstream(
if "response" in locals():
error += f"\n{response}"
await modify_progress(text=error, force_update=True, **kwargs)
- return await gemini_nonstream(client, message, contexts, model, model_name, response_modalities, tools, retry + 1, **kwargs) # type: ignore
+ return await gemini_nonstream(client, message, model_name, params, retry + 1, **kwargs) # type: ignore
def parse_response(data: dict, prefix: str = "") -> dict:
@@ -206,27 +199,20 @@ def parse_response(data: dict, prefix: str = "") -> dict:
async def gemini_stream(
client: Client,
message: Message,
- contexts: list[ContentUnionDict],
- model: str,
model_name: str,
- response_modalities: list[str],
- tools: list | None = None,
+ params: dict,
retry: int = 0,
**kwargs,
):
- prefix = f"🌠**{model_name}**: ({BOT_TIPS})\n"
+ prefix = f"🤖**{model_name}**: ({BOT_TIPS})\n"
answers = prefix
try:
status = kwargs.get("progress")
- api_keys = [x.strip() for x in AIGC.IMG_API_KEY.split(",") if x.strip()]
+ api_keys = [x.strip() for x in GEMINI.API_KEYS.split(",") if x.strip()]
if retry > len(api_keys) - 1:
- return
- app = genai.Client(api_key=api_keys[retry], http_options=HttpOptions(base_url=AIGC.IMG_BASR_URL, async_client_args={"proxy": AIGC.IMG_PROXY}))
- async for chunk in await app.aio.models.generate_content_stream(
- model=model,
- contents=contexts,
- config=GenerateContentConfig(response_modalities=response_modalities, tools=tools),
- ):
+ return None
+ app = genai.Client(api_key=api_keys[retry], http_options=HttpOptions(base_url=GEMINI.BASR_URL, async_client_args={"proxy": GEMINI.PROXY}))
+ async for chunk in await app.aio.models.generate_content_stream(**params):
resp = parse_response(chunk.model_dump())
answer = resp.get("texts", "")
answers += answer
@@ -249,4 +235,4 @@ async def gemini_stream(
if "resp" in locals():
error += f"\n{resp}"
await modify_progress(text=error, force_update=True, **kwargs)
- return await gemini_stream(client, message, contexts, model, model_name, response_modalities, tools, retry + 1, **kwargs) # type: ignore
+ return await gemini_stream(client, message, model_name, params, retry + 1, **kwargs) # type: ignore
src/llm/gpt.py
@@ -5,7 +5,7 @@ from loguru import logger
from pyrogram.client import Client
from pyrogram.types import Message
-from config import AIGC, GPT, PREFIX, TEXT_LENGTH, cache
+from config import GEMINI, GPT, PREFIX, TEXT_LENGTH, cache
from llm.contexts import get_conversation_contexts, get_conversations
from llm.gemini import HELP as AIGC_HELP
from llm.gemini import gemini_response
@@ -29,7 +29,7 @@ HELP = f"""🤖**GPT对话**
🔄使用以下命令强制切换模型:
`/gpt`: **{GPT.OPENAI_MODEL_NAME}** {image_emoji(GPT.OPENAI_IMAGE_CAPABILITY)}
-`/gemini`: **{GPT.GEMINI_MODEL_NAME}** {image_emoji(GPT.GEMINI_IMAGE_CAPABILITY)}
+`/gemini`: **{GEMINI.TEXT_MODEL_NAME}** {image_emoji(capability=True)}
`/ds`: **{GPT.DEEPSEEK_MODEL_NAME}** {image_emoji(GPT.DEEPSEEK_IMAGE_CAPABILITY)}
`/qwen`: **{GPT.QWEN_MODEL_NAME}** {image_emoji(GPT.QWEN_IMAGE_CAPABILITY)}
`/doubao`: **{GPT.DOUBAO_MODEL_NAME}** {image_emoji(GPT.DOUBAO_IMAGE_CAPABILITY)}
@@ -55,16 +55,16 @@ def is_gpt_conversation(message: Message) -> bool:
reply_info = parse_msg(reply_msg, silent=True)
model_names = [
GPT.OPENAI_MODEL_NAME,
- GPT.GEMINI_MODEL_NAME,
GPT.DEEPSEEK_MODEL_NAME,
GPT.QWEN_MODEL_NAME,
GPT.DOUBAO_MODEL_NAME,
GPT.GROK_MODEL_NAME,
GPT.TEXT_MODEL_NAME,
GPT.IMAGE_MODEL_NAME,
- AIGC.IMG_MODEL_NAME,
+ GEMINI.TEXT_MODEL_NAME,
+ GEMINI.IMG_MODEL_NAME,
]
- return startswith_prefix(reply_info["text"], prefix=[f"🤖{x}".lower() for x in model_names] + [f"🌠{x}".lower() for x in model_names])
+ return startswith_prefix(reply_info["text"], prefix=[f"🤖{x}".lower() for x in model_names])
async def gpt_response(client: Client, message: Message, *, gpt_stream: bool = GPT.STREAM_MODE, **kwargs):
@@ -92,7 +92,6 @@ async def gpt_response(client: Client, message: Message, *, gpt_stream: bool = G
reply_text = reply_info["text"]
force_model, modality = parse_force_model(info["text"], reply_text)
-
# cache media_group message, only process once
if media_group_id := message.media_group_id:
if cache.get(f"gpt-{info['cid']}-{media_group_id}"):
@@ -103,12 +102,11 @@ async def gpt_response(client: Client, message: Message, *, gpt_stream: bool = G
context_type = get_context_type(conversations)
contexts = await get_conversation_contexts(client, conversations)
config = get_gpt_config(context_type["type"], contexts, force_model)
+ if any("gemini" in x.lower() for x in [config["completions"]["model"], config["friendly_name"]]):
+ return await gemini_response(client, message, contexts, modality, **kwargs)
if not config["client"]["api_key"]:
logger.error(f"⚠️**{config['friendly_name']}** 未配置API Key")
return await send2tg(client, message, texts=f"⚠️**{config['friendly_name']}** 未配置API Key, 请尝试其他命令\n\n{HELP}", **kwargs)
- if "gemini" in config["completions"]["model"].lower():
- model_name = AIGC.IMG_MODEL_NAME if startswith_prefix(info["text"], prefix=[PREFIX.GENIMG]) else config["friendly_name"]
- return await gemini_response(client, message, contexts, config["completions"]["model"], model_name, modality, **kwargs)
msg = f"🤖**{config['friendly_name']}**: 思考中...\n{clean_prefix(info['text'])}"
status_msg = (await send2tg(client, message, texts=msg, **kwargs))[0]
src/llm/hooks.py
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
-from config import GPT
+from config import GEMINI, GPT
from llm.prompts import modify_prompts, refine_prompts
from utils import unicode_to_ascii
@@ -8,9 +8,8 @@ from utils import unicode_to_ascii
def pre_hooks(client: dict, completions: dict, message_info: dict | None = None):
pre_openrouter_hook(client, completions)
pre_helicone_hook(client, message_info)
- # Gemini tends to respond in English, even when the user's query is in another language.
- if GPT.GEMINI_PREFER_LANG and "gemini" in completions["model"].lower():
- modify_prompts(completions["messages"], prompt=f"请使用{GPT.GEMINI_PREFER_LANG}回复。", role="system", method="append")
+ if GEMINI.PREFER_LANG and "gemini" in completions["model"].lower():
+ modify_prompts(completions["messages"], prompt=f"请使用{GEMINI.PREFER_LANG}回复。", role="system", method="append")
completions["messages"] = refine_prompts(completions["messages"])
src/llm/models.py
@@ -4,7 +4,7 @@
from openai import DefaultAsyncHttpxClient
from pyrogram.types import Message
-from config import AIGC, GPT, PREFIX, PROXY
+from config import GEMINI, GPT, PREFIX, PROXY
from messages.parser import parse_msg
from messages.utils import startswith_prefix
@@ -37,8 +37,6 @@ def parse_force_model(text: str, reply_text: str) -> tuple[str, str]:
# parse from bot reply
if reply_text.startswith(f"🤖{GPT.OPENAI_MODEL_NAME}"):
force_model = GPT.OPENAI_MODEL
- elif reply_text.startswith(f"🤖{GPT.GEMINI_MODEL_NAME}"):
- force_model = GPT.GEMINI_MODEL
elif reply_text.startswith(f"🤖{GPT.DEEPSEEK_MODEL_NAME}"):
force_model = GPT.DEEPSEEK_MODEL
elif reply_text.startswith(f"🤖{GPT.QWEN_MODEL_NAME}"):
@@ -47,14 +45,12 @@ def parse_force_model(text: str, reply_text: str) -> tuple[str, str]:
force_model = GPT.DOUBAO_MODEL
elif reply_text.startswith(f"🤖{GPT.GROK_MODEL_NAME}"):
force_model = GPT.GROK_MODEL
- elif reply_text.startswith(f"🌠{AIGC.IMG_MODEL_NAME}"):
- force_model = AIGC.IMG_MODEL
+ elif reply_text.startswith(f"🤖{GEMINI.IMG_MODEL_NAME}"):
+ force_model = GEMINI.IMG_MODEL
modality = "image"
# parse from command prefix
if startswith_prefix(text, prefix=["/gpt"]):
force_model = GPT.OPENAI_MODEL
- elif startswith_prefix(text, prefix=["/gemini"]):
- force_model = GPT.GEMINI_MODEL
elif startswith_prefix(text, prefix=["/ds"]):
force_model = GPT.DEEPSEEK_MODEL
elif startswith_prefix(text, prefix=["/qwen"]):
@@ -64,8 +60,11 @@ def parse_force_model(text: str, reply_text: str) -> tuple[str, str]:
elif startswith_prefix(text, prefix=["/grok"]):
force_model = GPT.GROK_MODEL
elif startswith_prefix(text, prefix=[PREFIX.GENIMG]):
- force_model = AIGC.IMG_MODEL
+ force_model = GEMINI.IMG_MODEL
modality = "image"
+ elif startswith_prefix(text, prefix=["/gemini"]):
+ force_model = GEMINI.TEXT_MODEL
+ modality = "text"
return force_model, modality
@@ -103,7 +102,6 @@ def get_gpt_config(model_type: str, contexts: list[dict], force_model: str = "")
# align with force model
model_factory = {
GPT.OPENAI_MODEL: {"api_key": GPT.OPENAI_API_KEY, "base_url": GPT.OPENAI_BASE_URL, "model_name": GPT.OPENAI_MODEL_NAME},
- GPT.GEMINI_MODEL: {"api_key": GPT.GEMINI_API_KEY, "base_url": GPT.GEMINI_BASE_URL, "model_name": GPT.GEMINI_MODEL_NAME},
GPT.DEEPSEEK_MODEL: {"api_key": GPT.DEEPSEEK_API_KEY, "base_url": GPT.DEEPSEEK_BASE_URL, "model_name": GPT.DEEPSEEK_MODEL_NAME},
GPT.QWEN_MODEL: {"api_key": GPT.QWEN_API_KEY, "base_url": GPT.QWEN_BASE_URL, "model_name": GPT.QWEN_MODEL_NAME},
GPT.DOUBAO_MODEL: {"api_key": GPT.DOUBAO_API_KEY, "base_url": GPT.DOUBAO_BASE_URL, "model_name": GPT.DOUBAO_MODEL_NAME},
@@ -119,7 +117,6 @@ def get_gpt_config(model_type: str, contexts: list[dict], force_model: str = "")
model_type == "image" # check capabilities
and (
(force_model == GPT.OPENAI_MODEL and GPT.OPENAI_IMAGE_CAPABILITY)
- or (force_model == GPT.GEMINI_MODEL and GPT.GEMINI_IMAGE_CAPABILITY)
or (force_model == GPT.DEEPSEEK_MODEL and GPT.DEEPSEEK_IMAGE_CAPABILITY)
or (force_model == GPT.QWEN_MODEL and GPT.QWEN_IMAGE_CAPABILITY)
or (force_model == GPT.DOUBAO_MODEL and GPT.DOUBAO_IMAGE_CAPABILITY)
src/config.py
@@ -142,7 +142,7 @@ class COOKIE: # See: https://github.com/easychen/CookieCloud
class GPT: # see `llm/README.md`
- # See class AIGC for the AIGC configurations
+ # See class GEMINI for the GEMINI configurations
STREAM_MODE = os.getenv("GPT_STREAM_MODE", "1").lower() in ["1", "y", "yes", "t", "true", "on"]
TEXT_MODEL = os.getenv("GPT_TEXT_MODEL", "gpt-4o")
IMAGE_MODEL = os.getenv("GPT_IMAGE_MODEL", "gpt-4o")
@@ -173,15 +173,6 @@ class GPT: # see `llm/README.md`
MAX_RETRY = int(os.getenv("GPT_MAX_RETRY", "2"))
HELICONE_API_KEY = os.getenv("HELICONE_API_KEY", "")
- # comma separated reasoning models, add system prompt to the models to ensure the output format.
- REASONING_MODELS = os.getenv("GPT_REASONING_MODELS", "") # deprecated, we do not need this anymore
- # /gemini command
- GEMINI_MODEL = os.getenv("GPT_GEMINI_MODEL", "gemini-2.0-flash")
- GEMINI_MODEL_NAME = os.getenv("GPT_GEMINI_MODEL_NAME", "Gemini-2.0-Flash")
- GEMINI_API_KEY = os.getenv("GPT_GEMINI_API_KEY", "")
- GEMINI_BASE_URL = os.getenv("GPT_GEMINI_BASE_URL", "https://generativelanguage.googleapis.com/v1beta/openai")
- GEMINI_IMAGE_CAPABILITY = os.getenv("GPT_GEMINI_IMAGE_CAPABILITY", "1").lower() in ["1", "y", "yes", "t", "true", "on"]
- GEMINI_PREFER_LANG = os.getenv("GPT_GEMINI_PREFER_LANG", "") # Set a prefer response language for Gemini
# /gpt command
OPENAI_MODEL = os.getenv("GPT_OPENAI_MODEL", "gpt-4o")
OPENAI_MODEL_NAME = os.getenv("GPT_OPENAI_MODEL_NAME", "GPT-4o")
@@ -257,11 +248,17 @@ class ASR:
TENCENT_SECRET_KEY = os.getenv("ASR_TENCENT_SECRET_KEY", "")
-class AIGC:
+class GEMINI: # Official Gemini
# https://ai.google.dev/gemini-api/docs/image-generation
- IMG_BASR_URL = os.getenv("AIGC_IMG_BASR_URL", "https://generativelanguage.googleapis.com/")
- IMG_API_KEY = os.getenv("AIGC_IMG_API_KEY", "") # comma separated keys for load balance. e.g. "key1,key2,key3"
- IMG_MODEL = os.getenv("AIGC_IMG_MODEL", "gemini-2.0-flash-exp")
- IMG_MODEL_NAME = os.getenv("AIGC_IMG_MODEL_NAME", "Gemini-2.0-Flash")
- IMG_PROXY = os.getenv("AIGC_IMG_PROXY", None)
- IMG_MAX_PROMPT_TOKEN = int(os.getenv("AIGC_IMG_MAX_PROMPT_TOKEN", "480"))
+ BASR_URL = os.getenv("GEMINI_BASR_URL", "https://generativelanguage.googleapis.com/")
+ API_KEYS = os.getenv("GEMINI_API_KEYS", "") # comma separated keys for load balance. e.g. "key1,key2,key3"
+ PROXY = os.getenv("GEMINI_PROXY", None)
+    PREFER_LANG = os.getenv("GEMINI_PREFER_LANG", "")  # Set a preferred response language for Gemini
+
+ # response modality: text
+ TEXT_MODEL = os.getenv("GEMINI_TEXT_MODEL", "gemini-2.5-pro-exp-03-25")
+ TEXT_MODEL_NAME = os.getenv("GEMINI_TEXT_MODEL_NAME", "Gemini-2.5-Pro")
+
+ # response modality: image
+ IMG_MODEL = os.getenv("GEMINI_IMG_MODEL", "gemini-2.0-flash-exp")
+ IMG_MODEL_NAME = os.getenv("GEMINI_IMG_MODEL_NAME", "Gemini-2.0-Flash")
src/handler.py
@@ -306,7 +306,7 @@ def get_social_media_help(chat_id: int | str, ctype: str, prefixes: list[str] |
if permission["ai"]:
msg += f"\n🤖**AI对话**: `{PREFIX.GPT} /gpt /gemini /ds /qwen /doubao /grok`"
msg += f"\n📖**AI总结**: `{PREFIX.AI_SUMMARY}` 总结历史聊天记录"
- msg += f"\n🌠**AIGC**: `{PREFIX.GENIMG}`"
+ msg += f"\n🌠**AI生图**: `{PREFIX.GENIMG}`"
if permission["asr"]:
msg += f"\n🗣**语音转文字**: `{PREFIX.ASR}` 回复语音消息"
if permission["audio"]: