Commit 723d854
Changed files (12)
src/llm/ali/text2img.py
@@ -0,0 +1,94 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+import asyncio
+from pathlib import Path
+
+from glom import glom
+from loguru import logger
+from pyrogram.client import Client
+from pyrogram.types import Message
+
+from config import TEXT2IMG
+from messages.progress import modify_progress
+from messages.sender import send2tg
+from networking import download_file, hx_req
+from utils import strings_list
+
+
+async def ali_text2img(client: Client, message: Message, model_id: str, prompt: str, *, silent: bool = False, **kwargs) -> dict:
+ """Ali text to image.
+
+ Args:
+ client (Client): The Pyrogram client.
+ message (Message): The trigger message object.
+ prompt (str): Prompt. Defaults to None.
+ silent (bool, optional): Whether to disable progressing. Defaults to False.
+
+ Return:
+ {"error": str}
+ """
+ if not prompt:
+ if message.reply_to_message:
+ prompt = message.reply_to_message.content
+ else:
+ await message.reply(text="请输入图片描述。", quote=True)
+ return {}
+ model_name = model_id.split("/")[-1].title()
+ if not silent and kwargs.get("show_progress"):
+ kwargs["progress"] = (await send2tg(client, message, texts=f"🌠**{model_name}**:\n💬提示词: {prompt}", **kwargs))[0]
+ error = ""
+ payload = {"model": model_id, "input": {"prompt": prompt}}
+ if "stable-diffusion" in model_id:
+ payload |= {"parameters": {"n": 4}}
+ for api_key in strings_list(TEXT2IMG.ALI_API_KEY, shuffle=True):
+ headers = {
+ "X-DashScope-Async": "enable",
+ "Authorization": f"Bearer {api_key}",
+ "Content-Type": "application/json",
+ }
+ resp = await hx_req(
+ "https://dashscope.aliyuncs.com/api/v1/services/aigc/text2image/image-synthesis",
+ method="POST",
+ headers=headers,
+ json_data=payload,
+ timeout=10,
+ check_kv={"output.task_status": "PENDING"},
+ check_keys=["output.task_id"],
+ proxy=TEXT2IMG.ALI_PROXY,
+ )
+ if resp.get("message"):
+ error = resp["message"]
+ logger.error(error)
+ continue
+ finished = await wait_for_response(resp["output"]["task_id"], api_key)
+ if images := finished.get("images"):
+ media = [{"photo": img} for img in images]
+ await send2tg(client, message, texts=f"{prompt}\n(By **{model_name}**)", media=media, **kwargs)
+ break
+ if err := finished.get("error"): # surface task-level failures instead of silently trying the next key
+ error = err
+ logger.error(error)
+ await modify_progress(del_status=True, **kwargs)
+ return {"error": error} if error else {}
+
+
+async def wait_for_response(task_id: str, api_key: str) -> dict:
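+ """Poll a DashScope task until it reaches a terminal state; returns {"images": [...]} on success or {"error": msg} on failure."""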
+ api = f"https://dashscope.aliyuncs.com/api/v1/tasks/{task_id}"
+ headers = {"Authorization": f"Bearer {api_key}"}
+ resp = await hx_req(api, headers=headers, silent=True, proxy=TEXT2IMG.ALI_PROXY, check_keys=["output.task_status"])
+ task_status = resp["output"]["task_status"]
+ while task_status in {"PENDING", "RUNNING"}: # poll until the task reaches a terminal state
+ await asyncio.sleep(1)
+ logger.trace(f"Waiting for Ali Text2IMG, TaskID: {task_id}")
+ resp = await hx_req(api, headers=headers, silent=True, proxy=TEXT2IMG.ALI_PROXY, check_keys=["output.task_status"])
+ task_status = resp["output"]["task_status"]
+ if task_status == "FAILED":
+ return {"error": glom(resp, "output.message", default="")}
+ if task_status == "SUCCEEDED":
+ img_urls = glom(resp, "output.results.*.url", default=[])
+ tasks = [download_file(url, proxy=TEXT2IMG.ALI_PROXY) for url in img_urls]
+ paths = await asyncio.gather(*tasks)
+ if all(Path(path).is_file() for path in paths):
+ return {"images": paths}
+ return {}
src/llm/cloudflare/text2img.py
@@ -0,0 +1,53 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+from pathlib import Path
+
+import anyio
+from pyrogram.client import Client
+from pyrogram.types import Message
+
+from config import DOWNLOAD_DIR, TEXT2IMG
+from messages.progress import modify_progress
+from messages.sender import send2tg
+from networking import hx_req
+from utils import rand_string, strings_list
+
+
+async def cloudflare_text2img(client: Client, message: Message, model_id: str, prompt: str, *, silent: bool = False, **kwargs) -> dict:
+ """Cloudflare text to image.
+
+ Args:
+ client (Client): The Pyrogram client.
+ message (Message): The trigger message object.
+ prompt (str): Prompt. Defaults to None.
+ silent (bool, optional): Whether to disable progressing. Defaults to False.
+ """
+ if not prompt:
+ if message.reply_to_message:
+ prompt = message.reply_to_message.content
+ else:
+ await message.reply(text="请输入图片描述。", quote=True)
+ return {}
+
+ model_name = model_id.split("/")[-1].title()
+ if not silent and kwargs.get("show_progress"):
+ kwargs["progress"] = (await send2tg(client, message, texts=f"🌠**{model_name}**:\n💬提示词: {prompt}", **kwargs))[0]
+ for api_key in strings_list(TEXT2IMG.CF_API_KEY, shuffle=True):
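+ # each CF_API_KEY entry is expected in "account_id:api_token" form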
+ account_id, token = api_key.split(":", 1)
+ resp = await hx_req(
+ f"https://api.cloudflare.com/client/v4/accounts/{account_id}/ai/run/{model_id}",
+ method="POST",
+ headers={"Authorization": f"Bearer {token}"},
+ json_data={"prompt": prompt},
+ timeout=300,
+ proxy=TEXT2IMG.CF_PROXY,
+ rformat="content",
+ )
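+ # with rformat="content", hx_req returns the raw response body (the image bytes) under the "content" key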
+ if data := resp.get("content"):
+ path = Path(DOWNLOAD_DIR) / f"{rand_string(10)}.png"
+ async with await anyio.open_file(path, "wb") as f:
+ await f.write(data)
+ await send2tg(client, message, texts=f"{prompt}\n(By **{model_name}**)", media=[{"photo": path}], **kwargs)
+ break
+ await modify_progress(del_status=True, **kwargs)
+ return {}
src/llm/gemini/chat.py
@@ -0,0 +1,234 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+import contextlib
+import json
+
+from google import genai
+from google.genai import types
+from loguru import logger
+from pyrogram.client import Client
+from pyrogram.parser.markdown import BLOCKQUOTE_EXPANDABLE_DELIM, BLOCKQUOTE_EXPANDABLE_END_DELIM
+from pyrogram.types import Message, ReplyParameters
+
+from config import GEMINI, GPT, TEXT_LENGTH
+from llm.contexts import get_conversation_contexts, get_conversations
+from llm.gemini.utils import add_grounding_results, gemini_logging, parse_response
+from llm.hooks import hook_gemini_httpoptions
+from llm.utils import BOT_TIPS, REASONING_BEGIN, REASONING_END, beautify_llm_response, clean_cmd_prefix, shuffle_keys
+from messages.parser import parse_msg
+from messages.progress import modify_progress
+from messages.sender import send2tg
+from messages.utils import blockquote, count_without_entities, smart_split
+
+
+async def gemini_chat_completion(
+ client: Client,
+ message: Message,
+ *,
+ enable_tools: bool = True,
+ append_grounding: bool = True,
+ disable_thinking: bool = False,
+ include_thoughts: bool = True,
+ system_prompt: str | None = None,
+ silent: bool = False,
+ **kwargs,
+) -> dict:
+ r"""Get Gemini response.
+
+ Args:
+ client (Client): The Pyrogram client.
+ message (Message): The trigger message object.
+ enable_tools (bool, optional): Whether to enable tools. Defaults to True.
+ append_grounding (bool, optional): Whether to append grounding to the response. Defaults to True.
+ disable_thinking (bool, optional): Whether to disable thinking. Defaults to False.
+ include_thoughts (bool, optional): Whether to include thoughts. Defaults to True.
+ system_prompt (str | None, optional): System prompt. Defaults to None.
+ silent (bool, optional): Whether to suppress progress messages. Defaults to False.
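+
+ Example (a minimal sketch; assumes a connected Client and a trigger Message)::
+
+ result = await gemini_chat_completion(client, message, show_progress=True)
+ answer = result.get("texts", "")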
+ """
+ info = parse_msg(message, silent=True, use_cache=False)
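+ # built-in Gemini tools: URL context fetching and Google Search grounding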
+ tools = [types.Tool(url_context=types.UrlContext()), types.Tool(google_search=types.GoogleSearch())]
+ # parse config from environment variable
+ genconfig = {}
+ with contextlib.suppress(Exception):
+ extra_config_str = GEMINI.TEXT_CONFIG
+ genconfig = json.loads(extra_config_str)
+ try:
+ real_prompt = clean_cmd_prefix(info["text"]) or clean_cmd_prefix(info["reply_text"])
+ msg = f"🤖**{GEMINI.TEXT_MODEL_NAME}**: 思考中...\n👤**[{info['full_name'] or info['ctitle']}](tg://user?id={info['uid']})**: “{real_prompt}”"[:TEXT_LENGTH]
+ if not silent and kwargs.get("show_progress"):
+ kwargs["progress"] = (await send2tg(client, message, texts=msg, **kwargs))[0]
+ genconfig |= {"response_modalities": ["TEXT"]}
+ if enable_tools:
+ genconfig |= {"tools": tools}
+ if system_prompt is not None:
+ genconfig |= {"system_instruction": system_prompt}
+ elif GEMINI.PREFER_LANG:
+ genconfig |= {"system_instruction": f"请优先使用{GEMINI.PREFER_LANG}思考和回复"}
+
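+ # clamp the configured thinking budget to the model maximum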
+ if GEMINI.TEXT_THINKING_BUDGET is not None and not disable_thinking:
+ thinking_budget = min(round(float(GEMINI.TEXT_THINKING_BUDGET)), GEMINI.MAX_THINKING_BUDGET)
+ genconfig |= {"thinking_config": types.ThinkingConfig(include_thoughts=include_thoughts, thinking_budget=thinking_budget)}
+ params = {"model": GEMINI.TEXT_MODEL, "conversations": get_conversations(message), "config": types.GenerateContentConfig(**genconfig)}
+ logger.trace(params)
+ return await gemini_stream(client, message, GEMINI.TEXT_MODEL_NAME, params, append_grounding=append_grounding, silent=silent, **kwargs)
+ except Exception as e:
+ logger.error(e)
+ return {}
+
+
+async def gemini_stream(
+ client: Client,
+ message: Message,
+ model_name: str,
+ params: dict,
+ prefix: str | None = None,
+ retry: int = 0,
+ max_retry: int | None = None,
+ last_error: str = "",
+ *,
+ silent: bool = False,
+ append_grounding: bool = True,
+ single_thinking_msg: bool = True,
+ remove_thinking: bool = True,
+ **kwargs,
+) -> dict:
+ """Gemini stream response.
+
+ Args:
+ single_thinking_msg (bool, optional): Reuse a single message for displaying thinking. Defaults to True.
+ remove_thinking (bool, optional): Drop the thinking parts once the answer starts. Defaults to True.
+
+ Returns:
+ dict: {"texts": str, "thoughts": str, "prefix": str, "model_name": str, "sent_messages": list[Message]}
+ """
+ if prefix is None:
+ prefix = f"🤖**{model_name}**:{BOT_TIPS}\n"
+ answers = "" # all model responses
+ thoughts = "" # all model thoughts
+ runtime_texts = "" # for a single telegram message
+ init_status_msg = None if silent else kwargs.get("progress")
+ status_msg = init_status_msg
+ status_mid = status_msg.id if isinstance(status_msg, Message) else message.id
+ if not kwargs.get("gemini_api_keys"):
+ kwargs["gemini_api_keys"] = shuffle_keys(GEMINI.API_KEY)
+ api_keys = [x.strip() for x in kwargs["gemini_api_keys"].split(",") if x.strip()]
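+ # keys are shuffled once per request; each retry advances to the next key (api_keys[retry])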
+ max_retry = len(api_keys) - 1 if max_retry is None else max_retry
+ resp = {}
+ sent_messages = []
+ try:
+ if retry > min(len(api_keys) - 1, max_retry):
+ logger.error(f"[Gemini] Failed after {retry} retries")
+ await modify_progress(message=init_status_msg, text=last_error, force_update=True)
+ return {"error": last_error}
+ api_key = kwargs.get("gemini_api_key", api_keys[retry])
+ http_options = types.HttpOptions(base_url=GEMINI.BASE_URL, async_client_args={"proxy": GEMINI.PROXY})
+ http_options = hook_gemini_httpoptions(http_options, message)
+ app = genai.Client(api_key=api_key, http_options=http_options)
+ # Construct the request params
+ if "conversations" in params: # convert conversations to contents
+ params["contents"] = await get_conversation_contexts(client, params["conversations"], ctx_format="gemini", app=app)
+ gemini_logging(params["contents"])
+ tokens = await app.aio.models.count_tokens(model=params["model"], contents=params["contents"]) # type: ignore
+ num_tokens = tokens.total_tokens or 0
+ if num_tokens > GEMINI.TEXT_MAX_TOKEN:
+ logger.warning(f"[Gemini] Content is too long: {num_tokens} tokens, fallback to {GEMINI.TEXT_TOKENS_FALLBACK_MODEL}")
+ params["model"] = GEMINI.TEXT_TOKENS_FALLBACK_MODEL
+ params["config"].thinking_config = None
+ is_reasoning = False
+ is_reasoning_conversation = None # to indicate whether it is a reasoning conversation
+ genai_params = {"model": params["model"], "contents": params["contents"], "config": params["config"]}
+ length = 0
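+ # stream the response: thinking parts are wrapped in an expandable blockquote, answer text is appended as-is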
+ async for chunk in await app.aio.models.generate_content_stream(**genai_params):
+ resp = parse_response(chunk.model_dump())
+ answer = resp.get("texts", "")
+ thinking = resp.get("thinking", "")
+ if is_reasoning_conversation is None and thinking:
+ is_reasoning_conversation = True
+
+ if thinking and not is_reasoning: # First time receiving reasoning content
+ is_reasoning = True
+ runtime_texts += f"{BLOCKQUOTE_EXPANDABLE_DELIM}{REASONING_BEGIN}{thinking.lstrip()}"
+ elif thinking and is_reasoning: # Receiving reasoning content and is reasoning
+ runtime_texts += thinking
+ elif is_reasoning_conversation is True and is_reasoning: # Receiving response, close reasoning flag
+ is_reasoning = False
+ runtime_texts = answer.lstrip() if remove_thinking else f"{runtime_texts.rstrip()}{REASONING_END}\n{BLOCKQUOTE_EXPANDABLE_END_DELIM}\n" + answer.lstrip()
+ else:
+ runtime_texts += answer
+
+ thoughts += thinking
+ answers += answer
+ runtime_texts = beautify_llm_response(runtime_texts)
+ length = await count_without_entities(prefix + runtime_texts)
+ if length <= TEXT_LENGTH:
+ if len(runtime_texts.removeprefix(prefix)) > 10: # start response if answer is not empty
+ await modify_progress(message=status_msg, text=prefix + runtime_texts, detail_progress=True)
+ else: # the answer is too long; split it into multiple messages
+ parts = await smart_split(prefix + runtime_texts)
+ if len(parts) == 1:
+ continue
+ if is_reasoning and single_thinking_msg:
+ runtime_texts = f"{BLOCKQUOTE_EXPANDABLE_DELIM}{REASONING_BEGIN}{parts[-1].lstrip()}" # remove previous thinking
+ await modify_progress(message=status_msg, text=parts[0], force_update=True) # force send the first part
+ else:
+ await modify_progress(message=status_msg, text=blockquote(parts[0]), force_update=True) # force send the first part
+ runtime_texts = parts[-1] # keep the last part
+ if is_reasoning:
+ runtime_texts = f"{BLOCKQUOTE_EXPANDABLE_DELIM}{REASONING_BEGIN}{runtime_texts.lstrip()}"
+ if not silent:
+ status_msg = await client.send_message(message.chat.id, text=prefix + runtime_texts, reply_parameters=ReplyParameters(message_id=status_mid)) # the new message
+ sent_messages.append(status_msg)
+ status_mid = status_msg.id
+
+ # all chunks are processed
+ if not answers.strip() and not thoughts.strip(): # empty response
+ return await gemini_stream(
+ client,
+ message,
+ model_name,
+ params,
+ prefix=prefix,
+ retry=retry + 1,
+ last_error=last_error,
+ silent=silent,
+ append_grounding=append_grounding,
+ **kwargs,
+ )
+ if append_grounding: # add grounding to the response
+ answers = await add_grounding_results(answers, resp["grounding_chunks"], resp["grounding_supports"])
+ runtime_texts = await add_grounding_results(runtime_texts, resp["grounding_chunks"], resp["grounding_supports"])
+ final_thoughts = "" if remove_thinking else thoughts
+ if await count_without_entities(prefix + final_thoughts + answers) <= TEXT_LENGTH - 10: # short answer in single msg
+ if length > GPT.COLLAPSE_LENGTH: # collapse the response if the answer is too long
+ quoted = REASONING_BEGIN + final_thoughts.strip() + REASONING_END + "\n\n" + answers.strip() if final_thoughts.strip() else answers.strip()
+ await modify_progress(message=status_msg, text=f"{prefix}{blockquote(quoted)}", force_update=True)
+ else:
+ quoted = blockquote(REASONING_BEGIN + final_thoughts.strip() + REASONING_END) + "\n" if final_thoughts.strip() else ""
+ await modify_progress(message=status_msg, text=f"{prefix}{quoted}{answers}", force_update=True)
+ # total length is too long; the answer is split into multiple messages
+ elif length > GPT.COLLAPSE_LENGTH:
+ await modify_progress(message=status_msg, text=prefix + blockquote(runtime_texts), force_update=True)
+ else:
+ await modify_progress(message=status_msg, text=prefix + runtime_texts, force_update=True)
+
+ except Exception as e:
+ error = str(e)
+ if "resp" in locals():
+ error += f"\n{resp}"
+ logger.error(error)
+ with contextlib.suppress(Exception):
+ await modify_progress(message=init_status_msg, text=error, force_update=True)
+ for msg in sent_messages:
+ await modify_progress(msg, del_status=True)
+ return await gemini_stream(
+ client,
+ message,
+ model_name,
+ params,
+ prefix=prefix,
+ retry=retry + 1,
+ last_error=error,
+ silent=silent,
+ append_grounding=append_grounding,
+ **kwargs,
+ )
+ return {"texts": answers, "thoughts": thoughts, "prefix": prefix, "model_name": model_name, "sent_messages": sent_messages}
src/llm/gemini/text2img.py
@@ -0,0 +1,125 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+import contextlib
+import json
+
+from google import genai
+from google.genai import types
+from loguru import logger
+from pyrogram.client import Client
+from pyrogram.types import Message
+
+from config import CAPTION_LENGTH, GEMINI, GPT, TEXT_LENGTH
+from llm.contexts import get_conversation_contexts, get_conversations
+from llm.gemini.utils import parse_response
+from llm.hooks import hook_gemini_httpoptions
+from llm.utils import BOT_TIPS, clean_cmd_prefix, clean_gemini_sourcemarks
+from messages.parser import parse_msg
+from messages.progress import modify_progress
+from messages.sender import send2tg
+from messages.utils import blockquote, count_without_entities, smart_split
+from utils import strings_list
+
+
+async def gemini_text2img(
+ client: Client,
+ message: Message,
+ *,
+ disable_thinking: bool = False,
+ system_prompt: str | None = None,
+ silent: bool = False,
+ **kwargs,
+) -> dict:
+ """Gemini text to image.
+
+ Args:
+ client (Client): The Pyrogram client.
+ message (Message): The trigger message object.
+ disable_thinking (bool, optional): Whether to disable thinking. Defaults to False.
+ include_thoughts (bool, optional): Whether to include thoughts. Defaults to True.
+ system_prompt (str | None, optional): System prompt. Defaults to None.
+ silent (bool, optional): Whether to disable progressing. Defaults to False.
+ """
+ info = parse_msg(message, silent=True, use_cache=False)
+ # parse config from environment variable
+ genconfig = {}
+ with contextlib.suppress(Exception):
+ extra_config_str = GEMINI.IMG_CONFIG
+ genconfig = json.loads(extra_config_str)
+ try:
+ real_prompt = clean_cmd_prefix(info["text"]) or clean_cmd_prefix(info["reply_text"])
+ msg = f"🤖**{GEMINI.IMG_MODEL_NAME}**: 思考中...\n👤**[{info['full_name'] or info['ctitle']}](tg://user?id={info['uid']})**: “{real_prompt}”"[:TEXT_LENGTH]
+ if not silent and kwargs.get("show_progress"):
+ kwargs["progress"] = (await send2tg(client, message, texts=msg, **kwargs))[0]
+ genconfig |= {"response_modalities": ["TEXT", "IMAGE"]}
+ if system_prompt is not None:
+ genconfig |= {"system_instruction": system_prompt}
+ if GEMINI.IMG_THINKING_BUDGET is not None and not disable_thinking:
+ thinking_budget = min(round(float(GEMINI.IMG_THINKING_BUDGET)), GEMINI.MAX_THINKING_BUDGET)
+ genconfig |= {"thinking_config": types.ThinkingConfig(include_thoughts=False, thinking_budget=thinking_budget)}
+ params = {"model": GEMINI.IMG_MODEL, "conversations": get_conversations(message), "config": types.GenerateContentConfig(**genconfig)}
+ logger.trace(params)
+ return await gemini_non_stream(client, message, GEMINI.IMG_MODEL_NAME, params, **kwargs)
+ except Exception as e:
+ logger.error(e)
+ return {}
+
+
+async def gemini_non_stream(
+ client: Client,
+ message: Message,
+ model_name: str,
+ params: dict,
+ retry: int = 0,
+ **kwargs,
+) -> dict:
+ """Gemini non-stream response.
+
+ Returns:
+ dict: {"texts": str, "thoughts": str, "prefix": str, "model_name": str, "sent_messages": list[Message]}
+ """
+ results = {}
+ try:
+ api_keys = strings_list(kwargs.get("gemini_api_keys", GEMINI.API_KEY))
+ if retry > len(api_keys) - 1:
+ return {}
+ api_key = kwargs.get("gemini_api_key", api_keys[retry])
+ http_options = types.HttpOptions(base_url=GEMINI.BASE_URL, async_client_args={"proxy": GEMINI.PROXY})
+ http_options = hook_gemini_httpoptions(http_options, message)
+ app = genai.Client(api_key=api_key, http_options=http_options)
+ # Construct the request params
+ if "conversations" in params: # convert conversations to contents
+ params["contents"] = await get_conversation_contexts(client, params.pop("conversations"), ctx_format="gemini", app=app)
+ clean_gemini_sourcemarks(params["contents"])
+ genai_params = {"model": params["model"], "contents": params["contents"], "config": params["config"]}
+ response = await app.aio.models.generate_content(**genai_params)
+ prefix = f"🤖**{model_name}**:{BOT_TIPS}\n"
+ res = parse_response(response.model_dump())
+ texts = res.get("texts", "")
+ results |= {"prefix": prefix, "model_name": model_name, "texts": texts, "thoughts": ""}
+ media = res.get("media", [])
+ total = prefix + texts.strip()
+ length = await count_without_entities(total)
+ single_msg_length = CAPTION_LENGTH if media else TEXT_LENGTH
+ if length <= GPT.COLLAPSE_LENGTH:
+ results["sent_message"] = await send2tg(client, message, caption_above=True, texts=total, media=media, **kwargs)
+ elif GPT.COLLAPSE_LENGTH < length <= single_msg_length:
+ final = prefix + blockquote(texts.strip())
+ results["sent_message"] = await send2tg(client, message, caption_above=True, texts=final, media=media, **kwargs)
+ else: # multiple messages
+ for idx, txt in enumerate(await smart_split(total, single_msg_length)):
+ if idx == 0:
+ results["sent_message"] = await send2tg(client, message, caption_above=True, texts=txt, media=media, **kwargs)
+ else:
+ results["sent_message"] = await send2tg(client, message, texts=txt, **kwargs)
+ await modify_progress(del_status=True, **kwargs)
+ except Exception as e:
+ logger.error(e)
+ error = str(e)
+ if "res" in locals():
+ error += f"\n{res}" # type: ignore
+ if "response" in locals():
+ error += f"\n{response}" # type: ignore
+ await modify_progress(text=error, force_update=True, **kwargs)
+ return await gemini_non_stream(client, message, model_name, params, retry + 1, **kwargs) # type: ignore
+ return results
src/llm/gemini/utils.py
@@ -0,0 +1,139 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+import asyncio
+import contextlib
+from io import BytesIO
+from pathlib import Path
+
+from glom import glom
+from google.genai import types
+from loguru import logger
+from PIL import Image
+
+from config import DOWNLOAD_DIR
+from llm.utils import beautify_llm_response, clean_source_marks
+from networking import flatten_rediercts
+from utils import number_to_emoji, rand_string
+
+
+def parse_response(data: dict) -> dict:
+ """Parse gemini response, includes texts, image and websearch."""
+ parts = glom(data, "candidates.0.content.parts", default=[]) or []
+ gemini_logging(parts)
+ grounding_chunks = glom(data, "candidates.0.grounding_metadata.grounding_chunks", default=[]) or []
+ grounding_supports = glom(data, "candidates.0.grounding_metadata.grounding_supports", default=[]) or []
+ texts = ""
+ thinking = ""
+ media = []
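+ # each part carries either text (flagged as a "thought" when reasoning) or inline image data that is saved to disk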
+ for item in parts:
+ if item.get("text") is not None:
+ if item.get("thought"):
+ thinking += item["text"]
+ else:
+ texts += item["text"]
+ if item.get("inline_data") is not None:
+ image = Image.open(BytesIO(item["inline_data"]["data"]))
+ mime = item["inline_data"]["mime_type"]
+ ext = mime.split("/")[-1]
+ save_path = Path(DOWNLOAD_DIR) / f"{rand_string()}.{ext}"
+ image.save(save_path)
+ media.append({"photo": save_path})
+ return {
+ "texts": beautify_llm_response(texts, newline_level=2),
+ "thinking": beautify_llm_response(thinking, newline_level=2),
+ "media": media,
+ "grounding_chunks": grounding_chunks,
+ "grounding_supports": grounding_supports,
+ }
+
+
+async def add_grounding_results(answers: str, grounding_chunks: list[dict], grounding_supports: list[dict]) -> str:
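+ """Append numbered citation links from Gemini grounding metadata to the answer text."""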
+ urls = [glom(chunk, "web.uri", default="https://www.google.com") for chunk in grounding_chunks]
+ tasks = [flatten_rediercts(url) for url in urls]
+ try:
+ index2url = await asyncio.gather(*tasks)
+ except Exception as e:
+ logger.warning(e)
+ index2url = urls
+ for support in grounding_supports:
+ indices: list[int] = support.get("grounding_chunk_indices", [])
+ indices_with_url = " ".join([f"[[{idx + 1}]]({index2url[idx]})" for idx in indices])
+ if segment := glom(support, "segment.text", default=""):
+ answers = answers.replace(segment, f"{segment}{indices_with_url}", 1)
+ for idx, grounding in enumerate(grounding_chunks):
+ if idx > 9:
+ break
+ title = glom(grounding, "web.title", default="Web")
+ url = index2url[idx]
+ if url in answers:
+ answers += f"\n{number_to_emoji(idx + 1)}[{title}]({url})"
+ return answers
+
+
+def gemini_logging(contexts: list):
+ """Print logs of gemini contexts."""
+ msg = ""
+ with contextlib.suppress(Exception):
+ for item in contexts:
+ if isinstance(item, str):
+ msg += f"{item}\n"
+ continue
+ if isinstance(item, types.File):
+ msg += f"[{item.mime_type}]: {item.name}\n"
+ continue
+ if not isinstance(item, dict):
+ continue
+ role = item.get("role", "").upper() or "MODEL"
+
+ # Request
+ for part in item.get("parts", []):
+ if part.inline_data:
+ msg += f"[{role}]: Blob_Data "
+ if part.text:
+ msg += f"[{role}]: {part.text} "
+ # Response
+ if item.get("text", ""):
+ msg += f"[{role}]: {item['text']} "
+ if item.get("inline_data", ""):
+ msg += f"[{role}]: Blob_Data "
+
+ logger.debug(f"{msg!r}")
+
+
+def openai_context_to_gemini(context: dict, *, keep_marks: bool = True) -> types.ContentUnionDict:
+ r"""(Deprecated) Convert OpenAI context to Gemini format.
+
+ Not needed anymore.
+
+ Args:
+ context (dict): {
+ "role": role, # assistant or user
+ "content": [
+ {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,encoding"}},
+ {"type": "text", "text": "[username]: Bob\n[filename]: sample.txt\n[file content]:\nhello"}
+ ]
+ }
+
+ Returns:
+ dict: {
+ "role": role, # model or user
+ "parts: [
+ {"inlineData": {"mimeType": "image/jpeg", "data": "base64-encoded string"}},
+ {"text": "hello"}
+ ]
+ }
+ """
+ parts: list[types.Part] = []
+ role = "model" if context["role"] == "assistant" else "user"
+ for item in context["content"]:
+ if item["type"] == "text":
+ if keep_marks:
+ parts.append(types.Part.from_text(text=item["text"]))
+ else:
+ parts.append(types.Part.from_text(text=clean_source_marks(item["text"])))
+ elif item["type"] == "image_url":
+ data = item["image_url"]["url"].split(";base64,")
+ mime = data[0].removeprefix("data:")
+ parts.append(types.Part.from_bytes(mime_type=mime, data=data[1]))
+ return {"role": role, "parts": parts} # type: ignore
src/llm/gemini.py
@@ -1,466 +0,0 @@
-#!/usr/bin/env python
-# -*- coding: utf-8 -*-
-import asyncio
-import contextlib
-import json
-from io import BytesIO
-from pathlib import Path
-
-from glom import glom
-from google import genai
-from google.genai import types
-from loguru import logger
-from PIL import Image
-from pyrogram.client import Client
-from pyrogram.parser.markdown import BLOCKQUOTE_EXPANDABLE_DELIM, BLOCKQUOTE_EXPANDABLE_END_DELIM
-from pyrogram.types import Message, ReplyParameters
-
-from config import CAPTION_LENGTH, DOWNLOAD_DIR, GEMINI, GPT, PREFIX, TEXT_LENGTH
-from llm.contexts import get_conversation_contexts
-from llm.hooks import hook_gemini_httpoptions
-from llm.utils import (
- BOT_TIPS,
- REASONING_BEGIN,
- REASONING_END,
- beautify_llm_response,
- clean_cmd_prefix,
- clean_gemini_sourcemarks,
- clean_source_marks,
- shuffle_keys,
-)
-from messages.parser import parse_msg
-from messages.progress import modify_progress
-from messages.sender import send2tg
-from messages.utils import blockquote, count_without_entities, smart_split
-from networking import flatten_rediercts
-from utils import number_to_emoji, rand_string
-
-HELP = f"""🌠**AI生图**
-`{PREFIX.GENIMG}` 后接提示词即可生成
-回复消息可继续对话重新修改生成结果
-
-⚙️模型配置:
-🌠生图模型: **{GEMINI.IMG_MODEL}
-
-⚠️目前只支持生成图片
-"""
-
-
-async def gemini_response(
- client: Client,
- message: Message,
- conversations: list[Message],
- modality: str = "image",
- *,
- enable_tools: bool = True,
- append_grounding: bool = True,
- disable_thinking: bool = False,
- include_thoughts: bool = True,
- system_prompt: str | None = None,
- silent: bool = False,
- **kwargs,
-) -> dict:
- r"""Get Gemini response.
-
- Args:
- client (Client): The Pyrogram client.
- message (Message): The trigger message object.
- conversations (list[Message]): list of chat conversations.
- modality (str): response modality
- enable_tools (bool, optional): Whether to enable tools. Defaults to True.
- append_grounding (bool, optional): Whether to append grounding to the response. Defaults to True.
- disable_thinking (bool, optional): Whether to disable thinking. Defaults to False.
- include_thoughts (bool, optional): Whether to include thoughts. Defaults to True.
- system_prompt (str | None, optional): System prompt. Defaults to None.
- silent (bool, optional): Whether to disable progressing. Defaults to False.
- """
- info = parse_msg(message, silent=True, use_cache=False)
- model = GEMINI.TEXT_MODEL if modality == "text" else GEMINI.IMG_MODEL
- model_name = GEMINI.TEXT_MODEL_NAME if modality == "text" else GEMINI.IMG_MODEL_NAME
- if not GEMINI.API_KEY:
- await send2tg(client, message, texts="⚠️**未配置Gemini API, 请尝试其他模型", **kwargs)
- response_modalities = ["TEXT", "IMAGE"] if modality == "image" else ["TEXT"]
- thinking_budget = GEMINI.IMG_THINKING_BUDGET if modality == "image" else GEMINI.TEXT_THINKING_BUDGET
- tools = [types.Tool(url_context=types.UrlContext()), types.Tool(google_search=types.GoogleSearch())]
- # parse config from environment variable
- genconfig = {}
- with contextlib.suppress(Exception):
- extra_config_str = GEMINI.IMG_CONFIG if modality == "image" else GEMINI.TEXT_CONFIG
- genconfig = json.loads(extra_config_str)
- try:
- real_prompt = clean_cmd_prefix(info["text"]) or clean_cmd_prefix(info["reply_text"])
- msg = f"🤖**{model_name}**: 思考中...\n👤**[{info['full_name'] or info['ctitle']}](tg://user?id={info['uid']})**: “{real_prompt}”"[:TEXT_LENGTH]
- if not silent and kwargs.get("show_progress"):
- kwargs["progress"] = (await send2tg(client, message, texts=msg, **kwargs))[0]
- genconfig |= {"response_modalities": response_modalities}
- if enable_tools and modality == "text":
- genconfig |= {"tools": tools}
- if system_prompt is not None:
- genconfig |= {"system_instruction": system_prompt}
- elif GEMINI.PREFER_LANG and modality == "text":
- genconfig |= {"system_instruction": f"请优先使用{GEMINI.PREFER_LANG}思考和回复"}
-
- if thinking_budget is not None and not disable_thinking:
- thinking_budget = min(round(float(thinking_budget)), GEMINI.MAX_THINKING_BUDGET)
- genconfig |= {"thinking_config": types.ThinkingConfig(include_thoughts=include_thoughts, thinking_budget=thinking_budget)}
- params = {"model": model, "conversations": conversations, "config": types.GenerateContentConfig(**genconfig)}
- logger.trace(params)
- if modality == "image":
- return await gemini_nonstream(client, message, model_name, params, clean_marks=True, append_grounding=append_grounding, **kwargs)
- return await gemini_stream(client, message, model_name, params, append_grounding=append_grounding, silent=silent, **kwargs)
- except Exception as e:
- logger.error(e)
- return {}
-
-
-async def gemini_stream(
- client: Client,
- message: Message,
- model_name: str,
- params: dict,
- prefix: str | None = None,
- retry: int = 0,
- max_retry: int | None = None,
- last_error: str = "",
- *,
- silent: bool = False,
- append_grounding: bool = True,
- single_thinking_msg: bool = True,
- remove_thinking: bool = True,
- **kwargs,
-) -> dict:
- """Gemini stream response.
-
- Args:
- single_thinking_msg (bool, optional): Only use one message for displaying thinking.
- remove_thinking (bool, optional): Remove thinking parts once finished.
-
- Returns:
- dict: {"texts": str, "thoughts": str, "prefix": str, "model_name": str, "sent_messages": list[Message]}
- """
- if prefix is None:
- prefix = f"🤖**{model_name}**:{BOT_TIPS}\n"
- answers = "" # all model responses
- thoughts = "" # all model thoughts
- runtime_texts = "" # for a single telegram message
- init_status_msg = None if silent else kwargs.get("progress")
- status_msg = init_status_msg
- status_mid = status_msg.id if isinstance(status_msg, Message) else message.id
- if not kwargs.get("gemini_api_keys"):
- kwargs["gemini_api_keys"] = shuffle_keys(GEMINI.API_KEY)
- api_keys = [x.strip() for x in kwargs["gemini_api_keys"].split(",") if x.strip()]
- max_retry = len(api_keys) - 1 if max_retry is None else max_retry
- resp = {}
- sent_messages = []
- try:
- if retry > min(len(api_keys) - 1, max_retry):
- logger.error(f"[Gemini] Failed after {retry} retries")
- await modify_progress(message=init_status_msg, text=last_error, force_update=True)
- return {"error": last_error}
- api_key = kwargs.get("gemini_api_key", api_keys[retry])
- http_options = types.HttpOptions(base_url=GEMINI.BASE_URL, async_client_args={"proxy": GEMINI.PROXY})
- http_options = hook_gemini_httpoptions(http_options, message)
- app = genai.Client(api_key=api_key, http_options=http_options)
- # Construct the request params
- if "conversations" in params: # convert conversations to contents
- params["contents"] = await get_conversation_contexts(client, params["conversations"], ctx_format="gemini", app=app)
- gemini_logging(params["contents"])
- tokens = await app.aio.models.count_tokens(model=params["model"], contents=params["contents"]) # type: ignore
- num_tokens = tokens.total_tokens or 0
- if num_tokens > GEMINI.TEXT_MAX_TOKEN:
- logger.warning(f"[Gemini] Content is too long: {num_tokens} tokens, fallback to {GEMINI.TEXT_TOKENS_FALLBACK_MODEL}")
- params["model"] = GEMINI.TEXT_TOKENS_FALLBACK_MODEL
- params["config"].thinking_config = None
- is_reasoning = False
- is_reasoning_conversation = None # to indicate whether it is a reasoning conversation
- genai_params = {"model": params["model"], "contents": params["contents"], "config": params["config"]}
- length = 0
- async for chunk in await app.aio.models.generate_content_stream(**genai_params):
- resp = parse_response(chunk.model_dump())
- answer = resp.get("texts", "")
- thinking = resp.get("thinking", "")
- if is_reasoning_conversation is None and thinking:
- is_reasoning_conversation = True
-
- if thinking and not is_reasoning: # First time receiving reasoning content
- is_reasoning = True
- runtime_texts += f"{BLOCKQUOTE_EXPANDABLE_DELIM}{REASONING_BEGIN}{thinking.lstrip()}"
- elif thinking and is_reasoning: # Receiving reasoning content and is reasoning
- runtime_texts += thinking
- elif is_reasoning_conversation is True and is_reasoning: # Receiving response, close reasoning flag
- is_reasoning = False
- runtime_texts = answer.lstrip() if remove_thinking else f"{runtime_texts.rstrip()}{REASONING_END}\n{BLOCKQUOTE_EXPANDABLE_END_DELIM}\n" + answer.lstrip()
- else:
- runtime_texts += answer
-
- thoughts += thinking
- answers += answer
- runtime_texts = beautify_llm_response(runtime_texts)
- length = await count_without_entities(prefix + runtime_texts)
- if length <= TEXT_LENGTH:
- if len(runtime_texts.removeprefix(prefix)) > 10: # start response if answer is not empty
- await modify_progress(message=status_msg, text=prefix + runtime_texts, detail_progress=True)
- else: # answers is too long, split it into multiple messages
- parts = await smart_split(prefix + runtime_texts)
- if len(parts) == 1:
- continue
- if is_reasoning and single_thinking_msg:
- runtime_texts = f"{BLOCKQUOTE_EXPANDABLE_DELIM}{REASONING_BEGIN}{parts[-1].lstrip()}" # remove previous thinking
- await modify_progress(message=status_msg, text=parts[0], force_update=True) # force send the first part
- else:
- await modify_progress(message=status_msg, text=blockquote(parts[0]), force_update=True) # force send the first part
- runtime_texts = parts[-1] # keep the last part
- if is_reasoning:
- runtime_texts = f"{BLOCKQUOTE_EXPANDABLE_DELIM}{REASONING_BEGIN}{runtime_texts.lstrip()}"
- if not silent:
- status_msg = await client.send_message(message.chat.id, text=prefix + runtime_texts, reply_parameters=ReplyParameters(message_id=status_mid)) # the new message
- sent_messages.append(status_msg)
- status_mid = status_msg.id
-
- # all chunks are processed
- if not answers.strip() and not thoughts.strip(): # empty response
- return await gemini_stream(
- client,
- message,
- model_name,
- params,
- prefix=prefix,
- retry=retry + 1,
- last_error=last_error,
- silent=silent,
- append_grounding=append_grounding,
- **kwargs,
- )
- if append_grounding: # add grounding to the response
- answers = await add_grounding_results(answers, resp["grounding_chunks"], resp["grounding_supports"])
- runtime_texts = await add_grounding_results(runtime_texts, resp["grounding_chunks"], resp["grounding_supports"])
- final_thoughts = "" if remove_thinking else thoughts
- if await count_without_entities(prefix + final_thoughts + answers) <= TEXT_LENGTH - 10: # short answer in single msg
- if length > GPT.COLLAPSE_LENGTH: # collapse the response if the answer is too long
- quoted = REASONING_BEGIN + final_thoughts.strip() + REASONING_END + "\n\n" + answers.strip() if final_thoughts.strip() else answers.strip()
- await modify_progress(message=status_msg, text=f"{prefix}{blockquote(quoted)}", force_update=True)
- else:
- quoted = blockquote(REASONING_BEGIN + final_thoughts.strip() + REASONING_END) + "\n" if final_thoughts.strip() else ""
- await modify_progress(message=status_msg, text=f"{prefix}{quoted}{answers}", force_update=True)
- # total length is too long, answers are splitted into multiple messages
- elif length > GPT.COLLAPSE_LENGTH:
- await modify_progress(message=status_msg, text=prefix + blockquote(runtime_texts), force_update=True)
- else:
- await modify_progress(message=status_msg, text=prefix + runtime_texts, force_update=True)
-
- except Exception as e:
- error = str(e)
- if "resp" in locals():
- error += f"\n{resp}"
- logger.error(error)
- with contextlib.suppress(Exception):
- await modify_progress(message=init_status_msg, text=error, force_update=True)
- [await modify_progress(msg, del_status=True) for msg in sent_messages]
- return await gemini_stream(
- client,
- message,
- model_name,
- params,
- prefix=prefix,
- retry=retry + 1,
- last_error=error,
- silent=silent,
- append_grounding=append_grounding,
- **kwargs,
- )
- return {"texts": answers, "thoughts": thoughts, "prefix": prefix, "model_name": model_name, "sent_messages": sent_messages}
-
-
-async def gemini_nonstream(
- client: Client,
- message: Message,
- model_name: str,
- params: dict,
- retry: int = 0,
- *,
- clean_marks: bool = False, # useful in image generation
- append_grounding: bool = True,
- **kwargs,
-) -> dict:
- """Gemini non-stream response.
-
- Returns:
- dict: {"texts": str, "thoughts": str, "prefix": str, "model_name": str, "sent_messages": list[Message]}
- """
- results = {}
- try:
- api_keys = [x.strip() for x in GEMINI.API_KEY.split(",") if x.strip()]
- if kwargs.get("gemini_api_keys"):
- api_keys = [x.strip() for x in kwargs["gemini_api_keys"].split(",") if x.strip()]
- if retry > len(api_keys) - 1:
- return {}
- api_key = kwargs.get("gemini_api_key", api_keys[retry])
- http_options = types.HttpOptions(base_url=GEMINI.BASE_URL, async_client_args={"proxy": GEMINI.PROXY})
- http_options = hook_gemini_httpoptions(http_options, message)
- app = genai.Client(api_key=api_key, http_options=http_options)
- # Construct the request params
- if "conversations" in params: # convert conversations to contents
- params["contents"] = await get_conversation_contexts(client, params.pop("conversations"), ctx_format="gemini", app=app)
- if clean_marks:
- clean_gemini_sourcemarks(params["contents"])
- tokens = await app.aio.models.count_tokens(model=params["model"], contents=params["contents"]) # type: ignore
- num_tokens = tokens.total_tokens or 0
- if num_tokens > GEMINI.TEXT_MAX_TOKEN:
- logger.warning(f"[Gemini] Content is too long: {num_tokens} tokens, fallback to {GEMINI.TEXT_TOKENS_FALLBACK_MODEL}")
- params["model"] = GEMINI.TEXT_TOKENS_FALLBACK_MODEL
- params["config"].thinking_config = None
- params["config"].response_modalities = ["TEXT"]
- genai_params = {"model": params["model"], "contents": params["contents"], "config": params["config"]}
- response = await app.aio.models.generate_content(**genai_params)
- prefix = f"🤖**{model_name}**:{BOT_TIPS}\n"
- res = parse_response(response.model_dump())
- texts = res.get("texts", "")
- thoughts = res.get("thoughts", "")
- if append_grounding: # add grounding to the response
- texts = await add_grounding_results(texts, res["grounding_chunks"], res["grounding_supports"])
- results |= {"prefix": prefix, "model_name": model_name, "texts": texts, "thoughts": thoughts}
- media = res.get("media", [])
- total = prefix + blockquote(REASONING_BEGIN + thoughts.strip() + REASONING_END) + "\n" + texts.strip() if thoughts.strip() else prefix + texts.strip()
- length = await count_without_entities(total)
- single_msg_length = CAPTION_LENGTH if media else TEXT_LENGTH
- if length <= GPT.COLLAPSE_LENGTH:
- results["sent_message"] = await send2tg(client, message, caption_above=True, texts=total, media=media, **kwargs)
- elif GPT.COLLAPSE_LENGTH < length <= single_msg_length:
- final = prefix + blockquote(REASONING_BEGIN + thoughts.strip() + REASONING_END + "\n\n" + texts.strip()) if thoughts.strip() else prefix + blockquote(texts.strip())
- results["sent_message"] = await send2tg(client, message, caption_above=True, texts=final, media=media, **kwargs)
- else: # multiple messages
- for idx, txt in await smart_split(total, single_msg_length):
- if idx == 0:
- results["sent_message"] = await send2tg(client, message, caption_above=True, texts=txt, media=media, **kwargs)
- else:
- results["sent_message"] = await send2tg(client, message, texts=txt, **kwargs)
- await modify_progress(del_status=True, **kwargs)
- except Exception as e:
- logger.error(e)
- error = str(e)
- if "res" in locals():
- error += f"\n{res}" # type: ignore
- if "response" in locals():
- error += f"\n{response}" # type: ignore
- await modify_progress(text=error, force_update=True, **kwargs)
- return await gemini_nonstream(client, message, model_name, params, retry + 1, clean_marks=clean_marks, append_grounding=append_grounding, **kwargs) # type: ignore
- return results
-
-
-def parse_response(data: dict) -> dict:
- """Parse gemini response, includes texts, image and websearch."""
- parts = glom(data, "candidates.0.content.parts", default=[]) or []
- gemini_logging(parts)
- grounding_chunks = glom(data, "candidates.0.grounding_metadata.grounding_chunks", default=[]) or []
- grounding_supports = glom(data, "candidates.0.grounding_metadata.grounding_supports", default=[]) or []
- texts = ""
- thinking = ""
- media = []
- for item in parts:
- if item.get("text") is not None:
- if item.get("thought"):
- thinking += item["text"]
- else:
- texts += item["text"]
- if item.get("inline_data") is not None:
- image = Image.open(BytesIO(item["inline_data"]["data"]))
- mime = item["inline_data"]["mime_type"]
- ext = mime.split("/")[-1]
- save_path = Path(DOWNLOAD_DIR) / f"{rand_string()}.{ext}"
- image.save(save_path)
- media.append({"photo": save_path})
- return {
- "texts": beautify_llm_response(texts, newline_level=2),
- "thinking": beautify_llm_response(thinking, newline_level=2),
- "media": media,
- "grounding_chunks": grounding_chunks,
- "grounding_supports": grounding_supports,
- }
-
-
-async def add_grounding_results(answers: str, grounding_chunks: list[dict], grounding_supports: list[dict]) -> str:
- urls = [glom(chunk, "web.uri", default="https://www.google.com") for chunk in grounding_chunks]
- tasks = [flatten_rediercts(url) for url in urls]
- flatten_urls = await asyncio.gather(*tasks)
- index2url = {idx + 1: url for idx, url in enumerate(flatten_urls)}
- for support in grounding_supports:
- indices: list[int] = support.get("grounding_chunk_indices", [])
- indices_with_url = " ".join([f"[[{idx + 1}]]({index2url[idx + 1]})" for idx in indices])
- if segment := glom(support, "segment.text", default=""):
- answers = answers.replace(segment, f"{segment}{indices_with_url}", 1)
- for idx, grounding in enumerate(grounding_chunks):
- if idx > 9:
- break
- title = glom(grounding, "web.title", default="Web")
- url = flatten_urls[idx]
- if url in answers:
- answers += f"\n{number_to_emoji(idx + 1)}[{title}]({url})"
- return answers
-
-
-def gemini_logging(contexts: list):
- """Print logs of gemini contexts."""
- msg = ""
- with contextlib.suppress(Exception):
- for item in contexts:
- if isinstance(item, str):
- msg += f"{item}\n"
- continue
- if isinstance(item, types.File):
- msg += f"[{item.mime_type}]: {item.name}\n"
- continue
- if not isinstance(item, dict):
- continue
- role = item.get("role", "").upper() or "MODEL"
-
- # Request
- for part in item.get("parts", []):
- if part.inline_data:
- msg += f"[{role}]: Blob_Data "
- if part.text:
- msg += f"[{role}]: {part.text} "
- # Response
- if item.get("text", ""):
- msg += f"[{role}]: {item['text']} "
- if item.get("inline_data", ""):
- msg += f"[{role}]: Blob_Data "
-
- logger.debug(f"{msg!r}")
-
-
-def openai_context_to_gemini(context: dict, *, keep_marks: bool = True) -> types.ContentUnionDict:
- r"""(Deprecated) Convert OpenAI context to Gemini format.
-
- Not needed anymore.
-
- Args:
- context (dict): {
- "role": role, # assistant or user
- "content": [
- {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,encoding"}},
- {"type": "text", "text": "[username]: Bob\n[filename]: sample.txt\n[file content]:\nhello"}
- ]
- }
-
- Returns:
- dict: {
- "role": role, # model or user
- "parts: [
- {"inlineData": {"mimeType": "image/jpeg", "data": "base64-encoded string"}},
- {"text": "hello"}
- ]
- }
- """
- parts: list[types.Part] = []
- role = "model" if context["role"] == "assistant" else "user"
- for item in context["content"]:
- if item["type"] == "text":
- if keep_marks:
- parts.append(types.Part.from_text(text=item["text"]))
- else:
- parts.append(types.Part.from_text(text=clean_source_marks(item["text"])))
- elif item["type"] == "image_url":
- data = item["image_url"]["url"].split(";base64,")
- mime = data[0].removeprefix("data:")
- parts.append(types.Part.from_bytes(mime_type=mime, data=data[1]))
- return {"role": role, "parts": parts} # type: ignore
src/llm/gpt.py
@@ -1,25 +1,21 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
-import os
-from loguru import logger
from pyrogram.client import Client
from pyrogram.types import Message
-from config import GEMINI, GPT, PREFIX, TEXT_LENGTH, TID, cache
+from config import GEMINI, GPT, PREFIX, TEXT_LENGTH, cache
from llm.contexts import get_conversation_contexts, get_conversations
-from llm.gemini import HELP as AIGC_HELP
-from llm.gemini import gemini_response
-from llm.models import get_context_type, get_gpt_config, get_model_id
-from llm.response import send_to_gpt
+from llm.gemini.chat import gemini_chat_completion
+from llm.models import get_gpt_config, get_model_id
from llm.response_stream import send_to_gpt_stream
+from llm.text2img import TEXT2IMG_HELP, text2img
from llm.tools import merge_tools_response
-from llm.utils import BOT_TIPS, clean_cmd_prefix, image_emoji, llm_cleanup_files, raw_reasoning
+from llm.utils import BOT_TIPS, clean_cmd_prefix, image_emoji, llm_cleanup_files
from messages.parser import parse_msg
from messages.progress import modify_progress
from messages.sender import send2tg
-from messages.utils import count_without_entities, equal_prefix, startswith_prefix
-from utils import slim_cid, strings_list, true
+from messages.utils import count_without_entities, equal_prefix
HELP = f"""🤖**GPT对话**
`{PREFIX.GPT}` 后接提示词即可与GPT对话
@@ -43,80 +39,11 @@ HELP = f"""🤖**GPT对话**
"""
-def is_gpt_conversation(minfo: dict) -> bool:
- # to avoid potential infinitely loop,
- # we do not respond to bot message & GPT responses.
- if minfo["is_bot"]:
- return False
- if BOT_TIPS in minfo["text"]:
- return False
-
- """Customization via Environment Variables.
- Useful for running multiple bots in a same chat. (Multiple LLM providers)
-
- GPT_{cid}_BAN_{uid}=1 : Ban user for using AI chat
- GPT_{cid}_ALLOW_USERS={uids} : Only allow users (comma separated userid) for using AI chat.
- GPT_{cid}_IGNORE_REPLY=1 : Ignore messages that is replying to another message
- GPT_{cid}_IGNORE_PREFIX=/gpt,/ds : Ignore prefix for specific chat ids
- """
- cid = slim_cid(minfo["cid"])
- if (uids := os.getenv(f"GPT_{cid}_ALLOW_USERS")) and str(minfo["uid"]) not in strings_list(uids):
- return False
- if true(os.getenv(f"GPT_{cid}_BAN_{minfo['uid']}")):
- return False
- if true(os.getenv(f"GPT_{cid}_IGNORE_REPLY")) and minfo["reply_mid"]:
- return False
- if startswith_prefix(minfo["text"], prefix=os.getenv(f"GPT_{cid}_IGNORE_PREFIX", "")):
- return False
-
- # starts with /prefix
- if startswith_prefix(minfo["text"], prefix=[PREFIX.GPT, PREFIX.GENIMG]):
- return True
-
- # not starts with /prefix, but in specific chat ids
- if any(str(x) in strings_list(TID.OPENAI_CHATS) for x in [minfo["cid"], slim_cid(minfo["cid"])]):
- minfo["text"] = "/gpt " + minfo["text"]
- return True
- if any(str(x) in strings_list(TID.GEMINI_CHATS) for x in [minfo["cid"], slim_cid(minfo["cid"])]):
- minfo["text"] = "/gemini " + minfo["text"]
- return True
- if any(str(x) in strings_list(TID.GROK_CHATS) for x in [minfo["cid"], slim_cid(minfo["cid"])]):
- minfo["text"] = "/grok " + minfo["text"]
- return True
- if any(str(x) in strings_list(TID.DEEPSEEK_CHATS) for x in [minfo["cid"], slim_cid(minfo["cid"])]):
- minfo["text"] = "/ds " + minfo["text"]
- return True
- if any(str(x) in strings_list(TID.QWEN_CHATS) for x in [minfo["cid"], slim_cid(minfo["cid"])]):
- minfo["text"] = "/qwen " + minfo["text"]
- return True
- if any(str(x) in strings_list(TID.DOUBAO_CHATS) for x in [minfo["cid"], slim_cid(minfo["cid"])]):
- minfo["text"] = "/doubao " + minfo["text"]
- return True
- if any(str(x) in strings_list(TID.KIMI_CHATS) for x in [minfo["cid"], slim_cid(minfo["cid"])]):
- minfo["text"] = "/kimi " + minfo["text"]
- return True
-
- # is replying to gpt-bot response message?
- model_names = [
- GPT.OPENAI_MODEL_NAME,
- GPT.DEEPSEEK_MODEL_NAME,
- GPT.QWEN_MODEL_NAME,
- GPT.DOUBAO_MODEL_NAME,
- GPT.GROK_MODEL_NAME,
- GEMINI.TEXT_MODEL_NAME,
- GEMINI.IMG_MODEL_NAME,
- ]
- return startswith_prefix(minfo["reply_text"], prefix=[f"🤖{x}".lower() for x in model_names])
-
-
async def gpt_response(
client: Client,
message: Message,
*,
- gpt_stream: bool = True,
- system_prompt: str | None = None,
- enable_gpt_tools: bool = True,
- enable_gemini_tools: bool = True,
+ enable_tools: bool = True,
**kwargs,
) -> dict:
"""Get GPT response from Various API.
@@ -125,9 +52,7 @@ async def gpt_response(
client (Client): The Pyrogram client.
message (Message): The trigger message object.
- gpt_stream (bool): Whether to use stream mode.
- system_prompt (str | None): System prompt.
- use_gpt_tools (bool): can use GPT tools.
- use_gemini_tools (bool): can use Gemini tools.
+ enable_tools (bool): Whether to enable tool calls. Defaults to True.
Returns:
dict: {"texts": str, "thoughts": str, "prefix": str, "model_name": str, "sent_messages": list[Message]}
@@ -142,9 +67,10 @@ async def gpt_response(
info["uid"] = info["reply_uid"]
info["full_name"] = info["reply_full_name"]
if info["mtype"] == "text" and equal_prefix(info["text"], prefix=PREFIX.GENIMG) and not message.reply_to_message:
- await send2tg(client, message, texts=AIGC_HELP, **kwargs)
+ await send2tg(client, message, texts=TEXT2IMG_HELP, **kwargs)
return {}
- if not is_gpt_conversation(info):
+ model_id, resp_modality = get_model_id(info, message)
+ if not model_id:
return {}
# cache media_group message, only process once
if media_group_id := message.media_group_id:
@@ -152,36 +78,21 @@ async def gpt_response(
return {}
cache.set(f"gpt-{info['cid']}-{media_group_id}", "1", ttl=120)
kwargs["message_info"] = info # save trigger message info
- conversations = get_conversations(message)
- context_type = get_context_type(conversations) # {"type": "text", "error": None} # text, image
- model_id, resp_modality = get_model_id(info["text"], info["reply_text"], context_type)
- if "gemini" in model_id.lower():
- return await gemini_response(
- client,
- message,
- conversations,
- resp_modality,
- system_prompt=system_prompt,
- enable_gemini_tools=enable_gemini_tools,
- **kwargs,
- )
+ if resp_modality == "image":
+ return await text2img(client, message, enable_tools=enable_tools, **kwargs)
+ if model_id == GEMINI.TEXT_MODEL:
+ return await gemini_chat_completion(client, message, enable_tools=enable_tools, **kwargs)
# GPT models
config = get_gpt_config(model_id)
- if not config["client"]["api_key"].strip():
- await send2tg(client, message, texts=f"⚠️**{config['friendly_name']}** 未配置API Key, 请尝试其他命令\n\n{HELP}", **kwargs)
- return {}
- if not config["completions"]["model"].strip():
- await send2tg(client, message, texts=f"⚠️**{config['friendly_name']}** 未配置模型ID, 请尝试其他命令\n\n{HELP}", **kwargs)
- return {}
-
+ conversations = get_conversations(message)
config["completions"]["messages"] = await get_conversation_contexts(client, conversations, ctx_format="openai")
real_prompt = clean_cmd_prefix(info["text"]) or clean_cmd_prefix(info["reply_text"])
msg = f"🤖**{config['friendly_name']}**: 思考中...\n👤**[{info['full_name'] or info['ctitle']}](tg://user?id={info['uid']})**: “{real_prompt}”"[:TEXT_LENGTH]
status_msg = (await send2tg(client, message, texts=msg, **kwargs))[0]
kwargs["progress"] = status_msg
- if enable_gpt_tools:
+ if enable_tools:
config, response = await merge_tools_response(config, **kwargs)
# skip send a new request if tool_model is the same as the current model
if response and config["completions"]["model"] == GPT.TOOLS_MODEL and response.get("content"):
@@ -205,25 +116,6 @@ async def gpt_response(
await modify_progress(message=status_msg, del_status=True, **kwargs)
llm_cleanup_files(config["completions"]["messages"])
return final
- final = {}
- if not gpt_stream:
- response = await send_to_gpt(config, **kwargs)
- if content := response.get("content"):
- if reasoning := response.get("reasoning"):
- final["thoughts"] = raw_reasoning(reasoning)
- content = f"{reasoning}\n{content}"
- texts = f"🤖**{response['model']}**:{BOT_TIPS}\n{content}"
- else:
- texts = f"🤖**{response['model']}**:{BOT_TIPS}\n\n{content}"
- logger.debug(texts)
- final |= {
- "texts": content,
- "prefix": f"🤖**{response['model']}**:{BOT_TIPS}\n",
- "model_name": config["friendly_name"],
- "sent_messages": await send2tg(client, message, texts=texts, **kwargs),
- }
- await modify_progress(message=status_msg, del_status=True, **kwargs)
- else:
- final = await send_to_gpt_stream(client, status_msg, config, system_prompt=system_prompt, **kwargs) # type: ignore
+ final = await send_to_gpt_stream(client, status_msg, config, **kwargs) # type: ignore
llm_cleanup_files(config["completions"]["messages"])
return final
src/llm/models.py
@@ -1,103 +1,62 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
+import os
from openai import DefaultAsyncHttpxClient
from pyrogram.types import Message
-from config import GEMINI, GPT, PREFIX, PROXY
-from llm.utils import sample_key
+from config import GEMINI, GPT, PREFIX, PROXY, TID
+from llm.contexts import get_conversations
+from llm.utils import BOT_TIPS, enabled_providers, sample_key
from messages.parser import parse_msg
from messages.utils import startswith_prefix
+from utils import slim_cid, strings_list, true
def get_context_type(conversations: list[Message]) -> str:
- """Get model type based on conversation messages."""
+ """Get model type based on message conversations."""
context_type = "text"
- for message in conversations:
- info = parse_msg(message, silent=True)
+ for msg in conversations:
+ info = parse_msg(msg, silent=True)
if info["mtype"] == "photo":
context_type = "image"
if info["mtype"] in ["video", "audio", "voice"]:
- context_type = "gemini" # only Gemini supports audio/video
+ context_type = info["mtype"]
return context_type
-def get_model_id(text: str, reply_text: str, context_type: str) -> tuple[str, str]:
- """Get model id based on the reply text, prefix command and context type.
+def get_model_id(minfo: dict, message: Message) -> tuple[str, str]:
+ """Get model id with response modality.
Returns:
- tuple[str, str, str]: (model_id, response_modality)
+ tuple[str, str]: (model_id, response_modality); both empty if the message should be ignored.
"""
- model_id = ""
- # Parse from reply bot message.
- # For example, reply to DeepSeek bot message, use DeepSeek model.
- if reply_text.startswith(f"🤖{GPT.OPENAI_MODEL_NAME}"):
- model_id = GPT.OPENAI_MODEL
- elif reply_text.startswith(f"🤖{GPT.DEEPSEEK_MODEL_NAME}"):
- model_id = GPT.DEEPSEEK_MODEL
- elif reply_text.startswith(f"🤖{GPT.QWEN_MODEL_NAME}"):
- model_id = GPT.QWEN_MODEL
- elif reply_text.startswith(f"🤖{GPT.DOUBAO_MODEL_NAME}"):
- model_id = GPT.DOUBAO_MODEL
- elif reply_text.startswith(f"🤖{GPT.GROK_MODEL_NAME}"):
- model_id = GPT.GROK_MODEL
- elif reply_text.startswith(f"🤖{GPT.KIMI_MODEL_NAME}"):
- model_id = GPT.KIMI_MODEL
- elif reply_text.startswith(f"🤖{GEMINI.TEXT_MODEL_NAME}"):
- model_id = GEMINI.TEXT_MODEL
- elif reply_text.startswith(f"🤖{GEMINI.IMG_MODEL_NAME}"):
- model_id = GEMINI.IMG_MODEL
-
- # map providers to model_ids
- providers = {
- "openai": GPT.OPENAI_MODEL,
- "deepseek": GPT.DEEPSEEK_MODEL,
- "qwen": GPT.QWEN_MODEL,
- "doubao": GPT.DOUBAO_MODEL,
- "grok": GPT.GROK_MODEL,
- "gemini": GEMINI.TEXT_MODEL,
- "kimi": GPT.KIMI_MODEL,
- }
- # parse from command prefix. If use /ds command, force use DeepSeek model.
- if startswith_prefix(text, prefix="/gpt"):
- model_id = GPT.OPENAI_MODEL
- elif startswith_prefix(text, prefix="/ds"):
- model_id = GPT.DEEPSEEK_MODEL
- elif startswith_prefix(text, prefix="/qwen"):
- model_id = GPT.QWEN_MODEL
- elif startswith_prefix(text, prefix="/doubao"):
- model_id = GPT.DOUBAO_MODEL
- elif startswith_prefix(text, prefix="/grok"):
- model_id = GPT.GROK_MODEL
- elif startswith_prefix(text, prefix="/kimi"):
- model_id = GPT.KIMI_MODEL
- elif startswith_prefix(text, prefix=PREFIX.GENIMG):
- model_id = GEMINI.IMG_MODEL
- elif startswith_prefix(text, prefix="/gemini"):
- model_id = GEMINI.TEXT_MODEL
- else:
- model_id = providers.get(GPT.DEFAULT_PROVIDER.lower(), GPT.OPENAI_MODEL)
-
- # fallback to omni model if needed
- omni_providers = {
- "openai": "/gpt",
- "deepseek": "/ds",
- "qwen": "/qwen",
- "doubao": "/doubao",
- "grok": "/grok",
- "gemini": "/gemini",
- "kimi": "/kimi",
- }
- if model_id and (model_id == GEMINI.IMG_MODEL or reply_text.startswith(f"🤖{GEMINI.IMG_MODEL_NAME}")):
- response_modality = "image"
- elif "gemini" in model_id:
- response_modality = "text"
- else:
- response_modality = "text"
- if model_id and context_type == "text": # no need to fallback if context type is text
+ # To avoid potential infinite loops,
+ # we do not respond to bot messages & GPT responses.
+ if minfo["is_bot"]:
+ return "", ""
+ if BOT_TIPS in minfo["text"]:
+ return "", ""
+
+ model_id, response_modality = get_model_id_from_envars(minfo)
+ if model_id:
+ return model_id, response_modality
+
+ model_id, response_modality = get_model_id_from_prefix(minfo)
+ if not model_id:
+ return "", ""
+
+ # early return for non-text generation
+ if response_modality != "text":
+ return model_id, response_modality
+
+ # check if we need to fall back to the omni model
+ conversations = get_conversations(message)
+ context_type = get_context_type(conversations)  # "text", "image", "video", "audio" or "voice"
+ if context_type == "text": # no need to fallback if context type is text
return model_id, response_modality
- if context_type == "gemini": # force gemini
+ if context_type in ["video", "audio", "voice"]: # currently, only Gemini supports audio/video
return GEMINI.TEXT_MODEL, "text"
if (
@@ -108,12 +67,122 @@ def get_model_id(text: str, reply_text: str, context_type: str) -> tuple[str, st
or (model_id == GPT.GROK_MODEL and not GPT.GROK_ACCEPT_IMAGE)
or (model_id == GPT.KIMI_MODEL and not GPT.KIMI_ACCEPT_IMAGE)
):
- prefix = omni_providers.get(GPT.OMNI_PROVIDER.lower(), "/gpt")
- return get_model_id(prefix, reply_text, context_type) # parse again
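+ # e.g. (illustrative): "/grok" over a photo context with GPT.GROK_ACCEPT_IMAGE
+ # unset cannot be served by Grok, so we re-route to the omni provider's model.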
+ omni_providers = {
+ "openai": GPT.OPENAI_MODEL,
+ "deepseek": GPT.DEEPSEEK_MODEL,
+ "qwen": GPT.QWEN_MODEL,
+ "doubao": GPT.DOUBAO_MODEL,
+ "grok": GPT.GROK_MODEL,
+ "gemini": GEMINI.TEXT_MODEL,
+ "kimi": GPT.KIMI_MODEL,
+ }
+ text_providers, _ = enabled_providers()
+ # prefer gemini if OMNI_PROVIDER is not set
+ model_id = omni_providers.get(GPT.OMNI_PROVIDER.lower()) or GEMINI.TEXT_MODEL or omni_providers[text_providers[0]]
+ return model_id, "text"
return model_id, response_modality
+def get_model_id_from_envars(minfo: dict) -> tuple[str, str]:
+ """Useful for running multiple bots in a same chat.
+
+ GPT_{cid}_BAN_{uid}=1 : Ban user for using AI chat
+ GPT_{cid}_ALLOW_USERS={uids} : Only allow users (comma separated userid) for using AI chat.
+ GPT_{cid}_IGNORE_REPLY=1 : Ignore messages that is replying to another message
+ GPT_{cid}_IGNORE_PREFIX=/gpt,/ds : Ignore prefix for specific chat ids
+
+ Returns:
+ (model_id, response_modality)
+ """
+ cid = slim_cid(minfo["cid"])
+ if (uids := os.getenv(f"GPT_{cid}_ALLOW_USERS")) and str(minfo["uid"]) not in strings_list(uids):
+ return "", ""
+ if true(os.getenv(f"GPT_{cid}_BAN_{minfo['uid']}")):
+ return "", ""
+ if true(os.getenv(f"GPT_{cid}_IGNORE_REPLY")) and minfo["reply_mid"]:
+ return "", ""
+ if startswith_prefix(minfo["text"], prefix=os.getenv(f"GPT_{cid}_IGNORE_PREFIX", "")):
+ return "", ""
+
+ # does not start with a /prefix, but the chat id is in a provider-specific chat list
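+ # e.g. (illustrative): if this chat's id is listed in TID.OPENAI_CHATS, a plain
+ # "hello" is rewritten to "/gpt hello" and routed to GPT.OPENAI_MODEL.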
+ if any(str(x) in strings_list(TID.OPENAI_CHATS) for x in [minfo["cid"], slim_cid(minfo["cid"])]):
+ minfo["text"] = "/gpt " + minfo["text"]
+ return GPT.OPENAI_MODEL, "text"
+ if any(str(x) in strings_list(TID.GEMINI_CHATS) for x in [minfo["cid"], slim_cid(minfo["cid"])]):
+ minfo["text"] = "/gemini " + minfo["text"]
+ return GEMINI.TEXT_MODEL, "text"
+ if any(str(x) in strings_list(TID.GROK_CHATS) for x in [minfo["cid"], slim_cid(minfo["cid"])]):
+ minfo["text"] = "/grok " + minfo["text"]
+ return GPT.GROK_MODEL, "text"
+ if any(str(x) in strings_list(TID.DEEPSEEK_CHATS) for x in [minfo["cid"], slim_cid(minfo["cid"])]):
+ minfo["text"] = "/ds " + minfo["text"]
+ return GPT.DEEPSEEK_MODEL, "text"
+ if any(str(x) in strings_list(TID.QWEN_CHATS) for x in [minfo["cid"], slim_cid(minfo["cid"])]):
+ minfo["text"] = "/qwen " + minfo["text"]
+ return GPT.QWEN_MODEL, "text"
+ if any(str(x) in strings_list(TID.DOUBAO_CHATS) for x in [minfo["cid"], slim_cid(minfo["cid"])]):
+ minfo["text"] = "/doubao " + minfo["text"]
+ return GPT.DOUBAO_MODEL, "text"
+ if any(str(x) in strings_list(TID.KIMI_CHATS) for x in [minfo["cid"], slim_cid(minfo["cid"])]):
+ minfo["text"] = "/kimi " + minfo["text"]
+ return GPT.KIMI_MODEL, "text"
+ return "", ""
+
+
+def get_model_id_from_prefix(minfo: dict) -> tuple[str, str]:
+ text_providers, img_providers = enabled_providers()
+ if startswith_prefix(minfo["text"], prefix="/gpt") and "openai" in text_providers:
+ return GPT.OPENAI_MODEL, "text"
+ if startswith_prefix(minfo["text"], prefix="/gemini") and "gemini" in text_providers:
+ return GEMINI.TEXT_MODEL, "text"
+ if startswith_prefix(minfo["text"], prefix="/ds") and "deepseek" in text_providers:
+ return GPT.DEEPSEEK_MODEL, "text"
+ if startswith_prefix(minfo["text"], prefix="/doubao") and "doubao" in text_providers:
+ return GPT.DOUBAO_MODEL, "text"
+ if startswith_prefix(minfo["text"], prefix="/qwen") and "qwen" in text_providers:
+ return GPT.QWEN_MODEL, "text"
+ if startswith_prefix(minfo["text"], prefix="/kimi") and "kimi" in text_providers:
+ return GPT.KIMI_MODEL, "text"
+ if startswith_prefix(minfo["text"], prefix="/grok") and "grok" in text_providers:
+ return GPT.GROK_MODEL, "text"
+ if startswith_prefix(minfo["text"], prefix=PREFIX.GENIMG) and "gemini" in img_providers:
+ return GEMINI.IMG_MODEL, "image"
+ # starts with /ai: auto-detect the model_id
+ if startswith_prefix(minfo["text"], prefix="/ai") and text_providers:
+ providers = {
+ "openai": GPT.OPENAI_MODEL,
+ "deepseek": GPT.DEEPSEEK_MODEL,
+ "qwen": GPT.QWEN_MODEL,
+ "doubao": GPT.DOUBAO_MODEL,
+ "grok": GPT.GROK_MODEL,
+ "gemini": GEMINI.TEXT_MODEL,
+ "kimi": GPT.KIMI_MODEL,
+ }
+ # prefer gemini if DEFAULT_PROVIDER is not set
+ model_id = providers.get(GPT.DEFAULT_PROVIDER.lower()) or GEMINI.TEXT_MODEL or providers[text_providers[0]]
+ return model_id, "text"
+
+ # the message is replying to an AI response message
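+ # e.g. (illustrative, assuming GPT.DEEPSEEK_MODEL_NAME == "DeepSeek"): replying to
+ # "🤖DeepSeek:(回复以继续) ..." routes the follow-up to GPT.DEEPSEEK_MODEL.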
+ if startswith_prefix(minfo["reply_text"], prefix=f"🤖{GPT.OPENAI_MODEL_NAME}:{BOT_TIPS}") and "openai" in text_providers:
+ return GPT.OPENAI_MODEL, "text"
+ if startswith_prefix(minfo["reply_text"], prefix=f"🤖{GEMINI.TEXT_MODEL_NAME}:{BOT_TIPS}") and "gemini" in text_providers:
+ return GEMINI.TEXT_MODEL, "text"
+ if startswith_prefix(minfo["reply_text"], prefix=f"🤖{GPT.DEEPSEEK_MODEL_NAME}:{BOT_TIPS}") and "deepseek" in text_providers:
+ return GPT.DEEPSEEK_MODEL, "text"
+ if startswith_prefix(minfo["reply_text"], prefix=f"🤖{GPT.DOUBAO_MODEL_NAME}:{BOT_TIPS}") and "doubao" in text_providers:
+ return GPT.DOUBAO_MODEL, "text"
+ if startswith_prefix(minfo["reply_text"], prefix=f"🤖{GPT.QWEN_MODEL_NAME}:{BOT_TIPS}") and "qwen" in text_providers:
+ return GPT.QWEN_MODEL, "text"
+ if startswith_prefix(minfo["reply_text"], prefix=f"🤖{GPT.KIMI_MODEL_NAME}:{BOT_TIPS}") and "kimi" in text_providers:
+ return GPT.KIMI_MODEL, "text"
+ if startswith_prefix(minfo["reply_text"], prefix=f"🤖{GPT.GROK_MODEL_NAME}:{BOT_TIPS}") and "grok" in text_providers:
+ return GPT.GROK_MODEL, "text"
+ if startswith_prefix(minfo["reply_text"], prefix=f"🤖{GEMINI.IMG_MODEL_NAME}:{BOT_TIPS}") and "gemini" in img_providers:
+ return GEMINI.IMG_MODEL, "image"
+ return "", ""
+
+
def get_gpt_config(model_id: str = "") -> dict:
"""Get GPT configurations."""
model_factory = {
src/llm/summary.py
@@ -186,7 +186,7 @@ async def ai_summary(client: Client, message: Message, summary_prefix: str | Non
client,
ai_msg,
system_prompt=SYSTEM_PROMPT,
- enable_gpt_tools=False,
+ enable_tools=False,
include_thoughts=False,
append_grounding=False,
silent=True,
src/llm/text2img.py
@@ -0,0 +1,91 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+from collections import defaultdict
+
+from loguru import logger
+from pyrogram.client import Client
+from pyrogram.types import Message
+
+from config import GEMINI, PREFIX, TEXT2IMG
+from llm.ali.text2img import ali_text2img
+from llm.cloudflare.text2img import cloudflare_text2img
+from llm.gemini.text2img import gemini_text2img
+from llm.utils import enabled_providers
+from utils import strings_list
+
+TEXT2IMG_HELP = f"""🌠**AI生图**
+`{PREFIX.GENIMG}` 后接提示词即可生成
+
+⚙️模型配置:
+- `{PREFIX.GENIMG}`: 默认模型 (**{GEMINI.IMG_MODEL}**)
+- `{PREFIX.GENIMG} @flux`: Flux模型
+- `{PREFIX.GENIMG} @sd`: Stable Diffusion模型
+
+对于Gemini模型可通过回复消息把历史图片加入上下文, 继续对话以重新修改生成结果
+"""
+
+
+async def text2img(client: Client, message: Message, **kwargs) -> dict:
+ """Text to image generation.
+
+ Args:
+ client (Client): The Pyrogram client.
+ message (Message): The trigger message object.
+ **kwargs: Extra options forwarded to the provider implementations.
+
+ Returns:
+ dict: {"texts": str, "thoughts": str, "prefix": str, "model_name": str, "sent_messages": list[Message]}
+ """
+ texts = message.content.removeprefix(PREFIX.GENIMG).strip()
+ all_models = enabled_models()
+ if not all_models:
+ return {}
+ categories = list(all_models) # ['gemini', 'flux', 'sd']
+ models = all_models.get(TEXT2IMG.DEFAULT_MODEL, [])
+ prompt = texts
+ for category in categories:
+ if texts.lower().startswith(f"@{category}"):
+ models = all_models[category]
+ prompt = texts.removeprefix(f"@{category}").strip()
+ break
+
+ for model in models:
+ provider, model_id = model.split("/", 1)
+ try:
+ if provider == "gemini":
+ return await gemini_text2img(client, message, **kwargs)
+ if provider == "ali":
+ return await ali_text2img(client, message, model_id, prompt, **kwargs)
+ if provider == "cloudflare":
+ return await cloudflare_text2img(client, message, model_id, prompt, **kwargs)
+ except Exception as e:
+ logger.error(e)
+ return {}
+
+
+def enabled_models() -> dict[str, list]:
+ """Get all enabled text to image generation model ids.
+
+ model_id format: {provider}/{real_model_id}
+
+ Returns:
+ dict[str,list]: {
+ "gemini": ["gemini/gemini-2.0-flash"],
+ "flux": ["ali/flux-dev", "cloudflare/@cf/black-forest-labs/flux-1-schnell],
+ "sd": ["ali/stable-diffusion-3.5-large", "cloudflare/@cf/bytedance/stable-diffusion-xl-lightning"]}
+ """
+ models = defaultdict(list)
+ _, img_providers = enabled_providers()
+ for provider in img_providers:
+ if provider == "gemini":
+ models["gemini"] = [f"gemini/{GEMINI.IMG_MODEL}"]
+ if provider == "ali" and TEXT2IMG.ALI_FLUX_MODEL and "ali" in strings_list(TEXT2IMG.FLUX_PROVIDER):
+ models["flux"].extend([f"ali/{model}" for model in strings_list(TEXT2IMG.ALI_FLUX_MODEL)])
+ if provider == "ali" and TEXT2IMG.ALI_STABLE_DIFFUSION_MODEL and "ali" in strings_list(TEXT2IMG.STABLE_DIFFUSION_PROVIDER):
+ models["sd"].extend([f"ali/{model}" for model in strings_list(TEXT2IMG.ALI_STABLE_DIFFUSION_MODEL)])
+ if provider == "cloudflare" and TEXT2IMG.CF_FLUX_MODEL and "cloudflare" in strings_list(TEXT2IMG.FLUX_PROVIDER):
+ models["flux"].extend([f"cloudflare/{model}" for model in strings_list(TEXT2IMG.CF_FLUX_MODEL)])
+ if provider == "cloudflare" and TEXT2IMG.CF_STABLE_DIFFUSION_MODEL and "cloudflare" in strings_list(TEXT2IMG.STABLE_DIFFUSION_PROVIDER):
+ models["sd"].extend([f"cloudflare/{model}" for model in strings_list(TEXT2IMG.CF_STABLE_DIFFUSION_MODEL)])
+ return models
src/llm/utils.py
@@ -14,7 +14,7 @@ from loguru import logger
from markitdown import MarkItDown
from pyrogram.parser.markdown import BLOCKQUOTE_EXPANDABLE_DELIM, BLOCKQUOTE_EXPANDABLE_END_DELIM
-from config import DOWNLOAD_DIR, GEMINI, GPT, PREFIX
+from config import DOWNLOAD_DIR, GEMINI, GPT, PREFIX, TEXT2IMG
from utils import nowdt, number_to_emoji, read_text, remove_consecutive_newlines, remove_dash, remove_pound, strings_list, zhcn
BOT_TIPS = "(回复以继续)" # noqa: RUF001
@@ -22,6 +22,38 @@ REASONING_BEGIN = "🤔" # use emoji to separate model reasoning and content
REASONING_END = "💡"
+def enabled_providers() -> tuple[list[str], list[str]]:
+ """Get enabled providers.
+
+ Returns:
+ (text_providers, img_providers)
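+ e.g. (illustrative): (["openai", "gemini"], ["gemini", "ali"])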
+ """
+ text_providers = []
+ if all([GPT.OPENAI_MODEL, GPT.OPENAI_MODEL_NAME, GPT.OPENAI_API_KEY, GPT.OPENAI_BASE_URL]):
+ text_providers.append("openai")
+ if all([GPT.GROK_MODEL, GPT.GROK_MODEL_NAME, GPT.GROK_API_KEY, GPT.GROK_BASE_URL]):
+ text_providers.append("grok")
+ if all([GPT.DEEPSEEK_MODEL, GPT.DEEPSEEK_MODEL_NAME, GPT.DEEPSEEK_API_KEY, GPT.DEEPSEEK_BASE_URL]):
+ text_providers.append("deepseek")
+ if all([GPT.QWEN_MODEL, GPT.QWEN_MODEL_NAME, GPT.QWEN_API_KEY, GPT.QWEN_BASE_URL]):
+ text_providers.append("qwen")
+ if all([GPT.DOUBAO_MODEL, GPT.DOUBAO_MODEL_NAME, GPT.DOUBAO_API_KEY, GPT.DOUBAO_BASE_URL]):
+ text_providers.append("doubao")
+ if all([GPT.KIMI_MODEL, GPT.KIMI_MODEL_NAME, GPT.KIMI_API_KEY, GPT.KIMI_BASE_URL]):
+ text_providers.append("kimi")
+ if all([GEMINI.API_KEY, GEMINI.BASE_URL, GEMINI.TEXT_MODEL, GEMINI.TEXT_MODEL_NAME, GEMINI.TEXT_TOKENS_FALLBACK_MODEL]):
+ text_providers.append("gemini")
+
+ img_providers = []
+ if all([GEMINI.API_KEY, GEMINI.BASE_URL, GEMINI.IMG_MODEL, GEMINI.IMG_MODEL_NAME]):
+ img_providers.append("gemini")
+ if TEXT2IMG.ALI_API_KEY:
+ img_providers.append("ali")
+ if TEXT2IMG.CF_API_KEY:
+ img_providers.append("cloudflare")
+ return text_providers, img_providers
+
+
def llm_cleanup_files(messages: list[dict]):
"""Clean downloaded files.
src/config.py
@@ -76,7 +76,7 @@ class PREFIX:
ASR = os.getenv("PREFIX_ASR", "/asr").lower()
AUDIO = os.getenv("PREFIX_AUDIO", "/audio").lower()
CONVERT = os.getenv("PREFIX_CONVERT", "/convert").lower() # convert image file to photo
- GPT = os.getenv("PREFIX_GPT", "/ai,/gpt,/gemini,/ds,/qwen,/doubao,/grok,/kimi").lower()
+ GPT = "/ai,/gpt,/gemini,/ds,/qwen,/doubao,/grok,/kimi" # this is fixed
SUBTITLE = os.getenv("PREFIX_SUBTITLE", "/subtitle, /sub").lower()
WGET = os.getenv("PREFIX_WGET", "/wget, /curl").lower()
OCR = os.getenv("PREFIX_OCR", "/ocr").lower()
@@ -356,12 +356,12 @@ class GPT:
OPENROUTER_FALLBACK_MODELS = os.getenv("GPT_OPENROUTER_FALLBACK_MODELS", "")
# default command (/ai).
- # set a string contains "gemini" to switch to gemini (see class GEMINI below for details)
- DEFAULT_PROVIDER = os.getenv("GPT_DEFAULT_PROVIDER", "openai")
+ # "gemini" to switch to gemini (see class GEMINI below for details)
+ DEFAULT_PROVIDER = os.getenv("GPT_DEFAULT_PROVIDER", "gemini")
# omni provider (this should be a multi-modality model, like gpt-4o)
# Used when the context contains multi-modality data (text, image) that the selected model cannot handle.
# For example, the /ds command can only handle text, but the context contains images.
- OMNI_PROVIDER = os.getenv("GPT_OMNI_PROVIDER", "openai")
+ OMNI_PROVIDER = os.getenv("GPT_OMNI_PROVIDER", "gemini")
# /gpt command
OPENAI_MODEL = os.getenv("GPT_OPENAI_MODEL", "")
@@ -440,3 +440,17 @@ class GEMINI: # Official Gemini
ASR_THINKING_BUDGET = os.getenv("GEMINI_ASR_THINKING_BUDGET", None) # 0 to disable thinking. DO NOT set this if the model is not a thinking model
ASR_CONFIG = os.getenv("GEMINI_ASR_CONFIG", "{}") # default config passed to GenerateContentConfig. Should be a json string: '{"key": "value"}'
ASR_USE_GROUNDING = os.getenv("GEMINI_ASR_USE_GROUNDING", "1").lower() in ["1", "y", "yes", "t", "true", "on"] # Use Grounding with Google Search
+
+
+class TEXT2IMG:
+ DEFAULT_MODEL = os.getenv("TEXT2IMG_DEFAULT_MODEL", "gemini")
+ FLUX_PROVIDER = os.getenv("TEXT2IMG_FLUX_PROVIDER", "ali,cloudflare").lower() # comma separated
+ STABLE_DIFFUSION_PROVIDER = os.getenv("TEXT2IMG_STABLE_DIFFUSION_PROVIDER", "ali,cloudflare").lower()
+ ALI_API_KEY = os.getenv("TEXT2IMG_ALI_API_KEY", "")
+ ALI_FLUX_MODEL = os.getenv("TEXT2IMG_ALI_FLUX_MODEL", "flux-dev")
+ ALI_STABLE_DIFFUSION_MODEL = os.getenv("TEXT2IMG_ALI_STABLE_DIFFUSION_MODEL", "stable-diffusion-3.5-large")
+ ALI_PROXY = os.getenv("TEXT2IMG_ALI_PROXY", None)
+ CF_API_KEY = os.getenv("TEXT2IMG_CF_API_KEY", "") # comma separated keys. e.g. "AccountID:API_TOKEN, AccountID:API_TOKEN, ..."
+ CF_FLUX_MODEL = os.getenv("TEXT2IMG_CF_FLUX_MODEL", "@cf/black-forest-labs/flux-1-schnell")
+ CF_STABLE_DIFFUSION_MODEL = os.getenv("TEXT2IMG_CF_STABLE_DIFFUSION_MODEL", "@cf/bytedance/stable-diffusion-xl-lightning")
+ CF_PROXY = os.getenv("TEXT2IMG_CF_PROXY", None)
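+
+# Illustrative .env values for TEXT2IMG above (placeholders, not from this commit):
+#   TEXT2IMG_DEFAULT_MODEL=flux
+#   TEXT2IMG_ALI_API_KEY=sk-xxxxxxxx
+#   TEXT2IMG_CF_API_KEY=AccountID:API_TOKEN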