Commit 10c4aca
Changed files (10)
src
others
subtitles
src/llm/gemini/chat.py
@@ -25,6 +25,8 @@ async def gemini_chat_completion(
client: Client,
message: Message,
*,
+ model_id: str = GEMINI.TEXT_MODEL,
+ model_name: str = GEMINI.TEXT_MODEL_NAME,
enable_tools: bool = True,
append_grounding: bool = True,
disable_thinking: bool = False,
@@ -53,8 +55,8 @@ async def gemini_chat_completion(
extra_config_str = GEMINI.TEXT_CONFIG
genconfig = json.loads(extra_config_str)
try:
- real_prompt = clean_cmd_prefix(info["text"]) or clean_cmd_prefix(info["reply_text"])
- msg = f"🤖**{GEMINI.TEXT_MODEL_NAME}**: 思考中...\n👤**[{info['full_name'] or info['ctitle']}](tg://user?id={info['uid']})**: “{real_prompt}”"[:TEXT_LENGTH]
+ real_prompt = clean_cmd_prefix(info["text"], model_id) or clean_cmd_prefix(info["reply_text"], model_id)
+ msg = f"🤖**{model_name}**: 思考中...\n👤**[{info['full_name'] or info['ctitle']}](tg://user?id={info['uid']})**: “{real_prompt}”"[:TEXT_LENGTH]
if not silent and kwargs.get("show_progress"):
kwargs["progress"] = (await send2tg(client, message, texts=msg, **kwargs))[0]
genconfig |= {"response_modalities": ["TEXT"]}
@@ -68,9 +70,9 @@ async def gemini_chat_completion(
if GEMINI.TEXT_THINKING_BUDGET is not None and not disable_thinking:
thinking_budget = min(round(float(GEMINI.TEXT_THINKING_BUDGET)), GEMINI.MAX_THINKING_BUDGET)
genconfig |= {"thinking_config": types.ThinkingConfig(include_thoughts=include_thoughts, thinking_budget=thinking_budget)}
- params = {"model": GEMINI.TEXT_MODEL, "conversations": get_conversations(message), "config": types.GenerateContentConfig(**genconfig)}
+ params = {"model": model_id, "conversations": get_conversations(message), "config": types.GenerateContentConfig(**genconfig)}
logger.trace(params)
- return await gemini_stream(client, message, GEMINI.TEXT_MODEL_NAME, params, append_grounding=append_grounding, silent=silent, **kwargs)
+ return await gemini_stream(client, message, model_name, params, append_grounding=append_grounding, silent=silent, **kwargs)
except Exception as e:
logger.error(e)
return {}
@@ -126,7 +128,7 @@ async def gemini_stream(
app = genai.Client(api_key=api_key, http_options=http_options)
# Construct the request params
if "conversations" in params: # convert conversations to contents
- params["contents"] = await get_conversation_contexts(client, params["conversations"], ctx_format="gemini", app=app)
+ params["contents"] = await get_conversation_contexts(client, params["conversations"], model_id=params["model"], ctx_format="gemini", app=app)
gemini_logging(params["contents"])
tokens = await app.aio.models.count_tokens(model=params["model"], contents=params["contents"]) # type: ignore
num_tokens = tokens.total_tokens or 0
src/llm/gemini/text2img.py
@@ -89,7 +89,7 @@ async def gemini_non_stream(
app = genai.Client(api_key=api_key, http_options=http_options)
# Construct the request params
if "conversations" in params: # convert conversations to contents
- params["contents"] = await get_conversation_contexts(client, params.pop("conversations"), ctx_format="gemini", app=app)
+ params["contents"] = await get_conversation_contexts(client, params.pop("conversations"), model_id=params["model"], ctx_format="gemini", app=app)
clean_gemini_sourcemarks(params["contents"])
genai_params = {"model": params["model"], "contents": params["contents"], "config": params["config"]}
response = await app.aio.models.generate_content(**genai_params)
src/llm/contexts.py
@@ -36,6 +36,7 @@ def get_conversations(message: Message) -> list[Message]:
async def get_conversation_contexts(
client: Client,
conversations: list[Message],
+ model_id: str = "",
ctx_format: str = "openai",
app: genai.Client | AsyncOpenAI | None = None,
) -> list[dict]:
@@ -45,16 +46,16 @@ async def get_conversation_contexts(
"""
# parse context for each message
if ctx_format.lower() == "openai":
- contexts = [await single_gpt_context(client, message) for message in conversations]
+ contexts = [await single_gpt_context(client, message, model_id) for message in conversations]
contexts = [x for x in contexts if x.get("content")]
else:
- contexts = [await single_gemini_context(client, message, app) for message in conversations] # type: ignore
+ contexts = [await single_gemini_context(client, message, app, model_id) for message in conversations] # type: ignore
contexts = [x for x in contexts if x.get("parts")]
return contexts[: int(GPT.HISTORY_CONTEXT)]
-async def single_gpt_context(client: Client, message: Message) -> dict:
+async def single_gpt_context(client: Client, message: Message, model_id: str = "") -> dict:
"""Generate GPT contexts for a single message (Without considering reply message).
Returns:
@@ -106,7 +107,7 @@ async def single_gpt_context(client: Client, message: Message) -> dict:
}
)
# user message has entity urls, use full html
- clean_texts = clean_context(info["html"]) if role == "user" and info["entity_urls"] else clean_context(info["text"])
+ clean_texts = clean_context(info["html"], model_id) if role == "user" and info["entity_urls"] else clean_context(info["text"], model_id)
if not clean_texts:
continue
texts = f"[username]: {sender}\n[message]:\n{clean_texts}" if role == "user" and sender else clean_texts
@@ -117,7 +118,7 @@ async def single_gpt_context(client: Client, message: Message) -> dict:
return {"role": role, "content": contexts} if contexts else {}
-async def single_gemini_context(client: Client, message: Message, app: genai.Client) -> dict:
+async def single_gemini_context(client: Client, message: Message, app: genai.Client, model_id: str = "") -> dict:
"""Generate Gemini contexts for a single message (Without considering reply message).
Returns:
@@ -170,7 +171,7 @@ async def single_gemini_context(client: Client, message: Message, app: genai.Cli
Path(fpath).unlink(missing_ok=True)
parts.append(Part.from_text(text=f"[filename]: {info['file_name']}\n[file content]:\n{text.strip()}"))
# user message has entity urls, use full html
- clean_texts = clean_context(info["html"]) if role == "user" and info["entity_urls"] else clean_context(info["text"])
+ clean_texts = clean_context(info["html"], model_id) if role == "user" and info["entity_urls"] else clean_context(info["text"], model_id)
if not clean_texts:
continue
texts = f"[username]: {sender}\n[message]:\n{clean_texts}" if role == "user" and sender else clean_texts
src/llm/gpt.py
@@ -1,6 +1,9 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
+import re
+
+from loguru import logger
from pyrogram.client import Client
from pyrogram.types import Message
@@ -16,6 +19,7 @@ from messages.parser import parse_msg
from messages.progress import modify_progress
from messages.sender import send2tg
from messages.utils import count_without_entities, equal_prefix
+from utils import strings_list
HELP = f"""🤖**GPT对话**
`{PREFIX.GPT}` 后接提示词即可与GPT对话
@@ -43,16 +47,20 @@ async def gpt_response(
client: Client,
message: Message,
*,
+ custom_model_id: str = "",
enable_tools: bool = True,
**kwargs,
) -> dict:
"""Get GPT response from Various API.
+ `/ai text`: get response from LLM
+ `/ai @gemini-2.5-flash text`: get response from gemini-2.5-flash (custom model id)
+
Args:
client (Client): The Pyrogram client.
message (Message): The trigger message object.
- gpt_stream (bool): Whether to use stream mode.
- enable_tools (bool): use tools.
+ custom_model_id (str, optional): Custom model id.
+ enable_tools (bool, optional): Whether to enable tools. Defaults to True.
Returns:
dict: {"texts": str, "thoughts": str, "prefix": str, "model_name": str, "sent_messages": list[Message]}
@@ -80,14 +88,28 @@ async def gpt_response(
kwargs["message_info"] = info # save trigger message info
if resp_modality == "image":
return await text2img(client, message, enable_tools=enable_tools, **kwargs)
- if model_id == GEMINI.TEXT_MODEL:
+
+ # handle custom model_id here
+ if matched := re.match(r"^/ai @([a-zA-Z0-9_\-\.]+)(\s+)?", info["text"]): # match /ai @custom_model_id
+ custom_model_id = matched.group(1).strip()
+ logger.warning(f"Custom model id: {custom_model_id}")
+ allowed_model_ids = [x.lower() for x in strings_list(GEMINI.ALLOWED_CUSTOM_MODEL_IDS) + strings_list(GPT.ALLOWED_CUSTOM_MODEL_IDS)]
+ if custom_model_id and custom_model_id.lower() not in allowed_model_ids:
+ await send2tg(client, message, texts=f"⚠️不支持自定义模型: {custom_model_id}\n\n⚙️支持自定义模型列表:\n{'\n'.join(allowed_model_ids)}", **kwargs)
+ return {}
+ if custom_model_id.lower() in [x.lower() for x in strings_list(GEMINI.ALLOWED_CUSTOM_MODEL_IDS)]:
+ return await gemini_chat_completion(client, message, model_id=custom_model_id, model_name=custom_model_id, enable_tools=enable_tools, **kwargs)
+ if model_id == GEMINI.TEXT_MODEL and not custom_model_id:
return await gemini_chat_completion(client, message, enable_tools=enable_tools, **kwargs)
# GPT models
+ if custom_model_id:
+ model_id = custom_model_id
config = get_gpt_config(model_id)
+ config["friendly_name"] = custom_model_id or config["friendly_name"]
conversations = get_conversations(message)
- config["completions"]["messages"] = await get_conversation_contexts(client, conversations, ctx_format="openai")
- real_prompt = clean_cmd_prefix(info["text"]) or clean_cmd_prefix(info["reply_text"])
+ config["completions"]["messages"] = await get_conversation_contexts(client, conversations, model_id=model_id, ctx_format="openai")
+ real_prompt = clean_cmd_prefix(info["text"], model_id) or clean_cmd_prefix(info["reply_text"], model_id)
msg = f"🤖**{config['friendly_name']}**: 思考中...\n👤**[{info['full_name'] or info['ctitle']}](tg://user?id={info['uid']})**: “{real_prompt}”"[:TEXT_LENGTH]
status_msg = (await send2tg(client, message, texts=msg, **kwargs))[0]
kwargs["progress"] = status_msg
src/llm/models.py
@@ -186,19 +186,24 @@ def get_model_id_from_prefix(minfo: dict) -> tuple[str, str]:
def get_gpt_config(model_id: str = "") -> dict:
"""Get GPT configurations."""
model_factory = {
- GPT.OPENAI_MODEL: {"api_key": sample_key(GPT.OPENAI_API_KEY), "base_url": GPT.OPENAI_BASE_URL, "model_name": GPT.OPENAI_MODEL_NAME},
- GPT.DEEPSEEK_MODEL: {"api_key": sample_key(GPT.DEEPSEEK_API_KEY), "base_url": GPT.DEEPSEEK_BASE_URL, "model_name": GPT.DEEPSEEK_MODEL_NAME},
- GPT.QWEN_MODEL: {"api_key": sample_key(GPT.QWEN_API_KEY), "base_url": GPT.QWEN_BASE_URL, "model_name": GPT.QWEN_MODEL_NAME},
- GPT.DOUBAO_MODEL: {"api_key": sample_key(GPT.DOUBAO_API_KEY), "base_url": GPT.DOUBAO_BASE_URL, "model_name": GPT.DOUBAO_MODEL_NAME},
- GPT.GROK_MODEL: {"api_key": sample_key(GPT.GROK_API_KEY), "base_url": GPT.GROK_BASE_URL, "model_name": GPT.GROK_MODEL_NAME},
- GPT.KIMI_MODEL: {"api_key": sample_key(GPT.KIMI_API_KEY), "base_url": GPT.KIMI_BASE_URL, "model_name": GPT.KIMI_MODEL_NAME},
+ "gpt,chatgpt,o1,o3,o4": {"api_key": sample_key(GPT.OPENAI_API_KEY), "base_url": GPT.OPENAI_BASE_URL, "model_name": GPT.OPENAI_MODEL_NAME},
+ "deepseek": {"api_key": sample_key(GPT.DEEPSEEK_API_KEY), "base_url": GPT.DEEPSEEK_BASE_URL, "model_name": GPT.DEEPSEEK_MODEL_NAME},
+ "qwen,qvq,qwq": {"api_key": sample_key(GPT.QWEN_API_KEY), "base_url": GPT.QWEN_BASE_URL, "model_name": GPT.QWEN_MODEL_NAME},
+ "doubao": {"api_key": sample_key(GPT.DOUBAO_API_KEY), "base_url": GPT.DOUBAO_BASE_URL, "model_name": GPT.DOUBAO_MODEL_NAME},
+ "grok": {"api_key": sample_key(GPT.GROK_API_KEY), "base_url": GPT.GROK_BASE_URL, "model_name": GPT.GROK_MODEL_NAME},
+ "kimi": {"api_key": sample_key(GPT.KIMI_API_KEY), "base_url": GPT.KIMI_BASE_URL, "model_name": GPT.KIMI_MODEL_NAME},
}
client = {"http_client": DefaultAsyncHttpxClient(proxy=PROXY.GPT)}
if GPT.TIMEOUT is not None:
client |= {"timeout": int(GPT.TIMEOUT)}
- model_id_config = model_factory.get(model_id, {})
+ model_id_config = {}
+ for prefix, config in model_factory.items():
+ if startswith_prefix(model_id, prefix):
+ model_id_config = config
+ break
+
model_name = model_id_config.get("model_name", "")
model_id_config.pop("model_name", None)
client |= model_id_config
src/llm/summary.py
@@ -176,15 +176,11 @@ async def ai_summary(client: Client, message: Message, summary_prefix: str | Non
return
await modify_progress(text=f"🤖AI总结中...\n{msg}", force_update=True, **kwargs)
# Construct a message to call GPT
- ai_msg = Message(
- id=rand_number(),
- chat=message.chat,
- text=Str(GPT.SUMMARY_CMD),
- reply_to_message=Message(id=rand_number(), chat=message.chat, text=Str(parsed["history"])),
- )
+ ai_msg = Message(id=0, chat=message.chat, text=Str(f"/ai {parsed['history']}"))
response = await gpt_response(
client,
ai_msg,
+ custom_model_id=GPT.CHAT_SUMMARY_MODEL_ID,
system_prompt=SYSTEM_PROMPT,
enable_tools=False,
include_thoughts=False,
src/llm/utils.py
@@ -209,9 +209,11 @@ def image_emoji(capability: bool) -> str: # noqa: FBT001
return "🏞" if capability else ""
-def clean_cmd_prefix(text: str) -> str:
+def clean_cmd_prefix(text: str, model_id: str = "") -> str:
for prefix in [*strings_list(PREFIX.GPT), PREFIX.GENIMG]:
text = text.removeprefix(prefix).lstrip()
+ if model_id:
+ text = text.removeprefix(f"@{model_id}").lstrip()
return text
@@ -229,11 +231,11 @@ def clean_reasoning(text: str) -> str:
return text.removeprefix(BLOCKQUOTE_EXPANDABLE_END_DELIM).lstrip()
-def clean_context(text: str) -> str:
+def clean_context(text: str, model_id: str = "") -> str:
"""Remove bot prefix and reasoning content."""
text = re.sub(r"^👤@.*?\/\/", "", text) # remove markdown send_from_user
text = re.sub(r"^👤\<a.*?tg://user\?id=\d+.*?@.*?</a>//", "", text) # remove html send_from_user
- text = clean_cmd_prefix(text)
+ text = clean_cmd_prefix(text, model_id)
text = clean_bot_tips(text)
return clean_reasoning(text)
src/others/podcast.py
@@ -19,7 +19,7 @@ from pyrogram.types import Chat, Message
from pyrogram.types.messages_and_media.message import Str
from asr.voice_recognition import asr_file
-from config import DB, DOWNLOAD_DIR, PODCAST, PREFIX, READING_SPEED, TZ, cache
+from config import DB, DOWNLOAD_DIR, GPT, PODCAST, READING_SPEED, TZ, cache
from database.alist import upload_alist
from database.r2 import get_cf_r2, set_cf_r2
from llm.gpt import gpt_response
@@ -100,16 +100,15 @@ async def summary_pods(client: Client):
prompt = f"这是播客栏目《{feed_title}》的一期节目详情:\n节目标题: {entry['title']}\n节目播出日期: {pubdate}"
prompt += f"\n节目时长: {readable_time(entry['itunes_duration'])}\n节目简介: {desc}"
prompt += "\n请解读该播客内容, 只需关注内容本身, 不用概述播客的基本信息, 例如播客的标题, 日期, 时长等"
- ai_cmd = next((x.strip() for x in PREFIX.GPT.split(",") if x.strip()), "")
# Construct a message to call GPT
cache.delete(f"parse_msg-{txt_msg.chat.id}-{txt_msg.id}")
ai_msg = Message(
id=txt_msg.id,
chat=txt_msg.chat,
- text=Str(f"{ai_cmd} {remove_img(prompt)}"),
+ text=Str(f"/ai {remove_img(prompt)}"),
reply_to_message=Message(id=rand_number(), chat=message.chat, text=Str(subtitles)),
)
- gpt_res = await gpt_response(client, ai_msg, include_thoughts=False, append_grounding=False, show_progress=True)
+ gpt_res = await gpt_response(client, ai_msg, custom_model_id=GPT.PODCAST_SUMMARY_MODEL_ID, include_thoughts=False, append_grounding=False, show_progress=True)
cache.delete(f"parse_msg-{txt_msg.chat.id}-{txt_msg.id}")
feed_item = match_item(feed_xml, entry)
update_item(saved_xml, feed_item, prefix_desc=gpt_res.get("texts", ""))
src/subtitles/subtitle.py
@@ -11,7 +11,7 @@ from pyrogram.types import Message
from pyrogram.types.messages_and_media.message import Str
from asr.voice_recognition import asr_file
-from config import ASR, DOWNLOAD_DIR, PREFIX, READING_SPEED, TEXT_LENGTH, cache
+from config import ASR, DOWNLOAD_DIR, GPT, PREFIX, READING_SPEED, TEXT_LENGTH, cache
from llm.gpt import gpt_response
from messages.parser import parse_msg
from messages.progress import modify_progress
@@ -119,16 +119,17 @@ async def get_subtitle(client: Client, message: Message, *, to_telegraph: bool =
if description.strip():
prompt += f"节目简介: {description}\n"
prompt += "\n请解读本期节目内容。要求: 直接输出节目内容解读, 以“该节目讲述了”开头"
- ai_cmd = next((x.strip() for x in PREFIX.GPT.split(",") if x.strip()), "")
# Construct a message to call GPT
ai_msg = Message(
id=subtitle_msg.id,
chat=subtitle_msg.chat,
- text=Str(f"{ai_cmd} {prompt}"),
+ text=Str(f"/ai {prompt}"),
reply_to_message=Message(id=rand_number(), chat=subtitle_msg.chat, text=Str(subtitles)),
)
- kwargs["include_thoughts"] = False
- await gpt_response(client, ai_msg, **kwargs)
+ kwargs |= {"include_thoughts": False, "append_grounding": False, "silent": True, "custom_model_id": GPT.SUBTITLE_SUMMARY_MODEL_ID}
+ res = await gpt_response(client, ai_msg, **kwargs)
+ if res.get("texts"):
+ await send2tg(client, ai_msg, texts=res["prefix"] + res["texts"], **kwargs)
with contextlib.suppress(Exception):
[await delete_message(msg) for msg in res.get("sent_messages", [])]
await delete_message(kwargs.get("progress"))
src/config.py
@@ -360,6 +360,7 @@ class GPT:
MAX_RETRY = int(os.getenv("GPT_MAX_RETRY", "2"))
HELICONE_API_KEY = os.getenv("HELICONE_API_KEY", "") # https://docs.helicone.ai/getting-started/integration-method/gateway
COLLAPSE_LENGTH = int(os.getenv("GPT_COLLAPSE_LENGTH", "500")) # Collapse the response if the length is larger than this value
+ ALLOWED_CUSTOM_MODEL_IDS = os.getenv("GPT_ALLOWED_CUSTOM_MODEL_IDS", "") # comma separated OpenAI compatible model ids
# comma separated fallback models for OpenRouter (e.g. openai/gpt-4o,anthropic/claude-3.5-sonnet)
OPENROUTER_FALLBACK_MODELS = os.getenv("GPT_OPENROUTER_FALLBACK_MODELS", "")
@@ -408,10 +409,12 @@ class GPT:
KIMI_BASE_URL = os.getenv("GPT_KIMI_BASE_URL", "https://api.moonshot.ai/v1")
KIMI_ACCEPT_IMAGE = os.getenv("GPT_KIMI_ACCEPT_IMAGE", "1").lower() in ["1", "y", "yes", "t", "true", "on"]
- # AI summary (/summary)
- SUMMARY_CMD = os.getenv("GPT_SUMMARY_CMD", "/gemini") # add this command prefix to call AI summary
+ # AI summary
# comma separated chat ids that are allowed to use `cid` as the chatid for the summary
SUMMARY_WHITELIST_CUSTOM_CHATS = os.getenv("GPT_SUMMARY_WHITELIST_CUSTOM_CHATS", "")
+ CHAT_SUMMARY_MODEL_ID = os.getenv("CHAT_SUMMARY_MODEL_ID", "") # Specify the model id for `/summary` command (If not set, use the default model)
+ PODCAST_SUMMARY_MODEL_ID = os.getenv("PODCAST_SUMMARY_MODEL_ID", "") # for generating podcast summary (If not set, use the default AI model)
    SUBTITLE_SUMMARY_MODEL_ID = os.getenv("SUBTITLE_SUMMARY_MODEL_ID", "")  # for generating subtitle summary (If not set, use the default AI model)
# For tool_call. Some models doesn't support tool call, so we use this model to do the tool_call first.
# Then construct the new questions for the original model.
TOOLS_MODEL = os.getenv("GPT_TOOLS_MODEL", "gpt-4o-mini") # this model should be fast and cheap
@@ -427,6 +430,7 @@ class GEMINI: # Official Gemini
PREFER_LANG = os.getenv("GEMINI_PREFER_LANG", "") # Set a prefer response language for Gemini
MAX_THINKING_BUDGET = int(os.getenv("GEMINI_MAX_THINKING_BUDGET", "24576")) # 24K
CLEAN_FILES_AFTER_SECONDS = int(os.getenv("GEMINI_CLEAN_FILES_AFTER_SECONDS", "172800")) # default to 48 hours
+ ALLOWED_CUSTOM_MODEL_IDS = os.getenv("GEMINI_ALLOWED_CUSTOM_MODEL_IDS", "") # comma separated model ids
# response modality: text
TEXT_MODEL = os.getenv("GEMINI_TEXT_MODEL", "gemini-2.5-pro")