Commit 2cb4070

benny-dou <60535774+benny-dou@users.noreply.github.com>
2025-06-13 09:41:29
feat(gpt): support chat ids to always use AI models
1 parent 712b891
Changed files (3)
src/llm/gpt.py
@@ -5,7 +5,7 @@ from loguru import logger
 from pyrogram.client import Client
 from pyrogram.types import Message
 
-from config import GEMINI, GPT, PREFIX, TEXT_LENGTH, cache
+from config import GEMINI, GPT, PREFIX, TEXT_LENGTH, TID, cache
 from llm.contexts import get_conversation_contexts, get_conversations
 from llm.gemini import HELP as AIGC_HELP
 from llm.gemini import gemini_response
@@ -18,6 +18,8 @@ from messages.parser import parse_msg
 from messages.progress import modify_progress
 from messages.sender import send2tg
 from messages.utils import count_without_entities, equal_prefix, startswith_prefix
+from permission import slim_cid
+from utils import env_list
 
 HELP = f"""🤖**GPT对话**
 `{PREFIX.GPT}` 后接提示词即可与GPT对话
@@ -40,21 +42,39 @@ HELP = f"""🤖**GPT对话**
 """
 
 
-def is_gpt_conversation(message: Message) -> bool:
-    info = parse_msg(message)
+def is_gpt_conversation(minfo: dict, reply_text: str) -> bool:
     # to avoid a potential infinite loop,
     # we do not respond to bot messages & GPT responses.
-    if info["is_bot"]:
+    if minfo["is_bot"]:
         return False
-    if BOT_TIPS in info["text"]:
+    if BOT_TIPS in minfo["text"]:
         return False
-    if startswith_prefix(info["text"], prefix=[PREFIX.GPT, PREFIX.GENIMG]):
+
+    # message starts with an explicit /prefix command
+    if startswith_prefix(minfo["text"], prefix=[PREFIX.GPT, PREFIX.GENIMG]):
+        return True
+
+    # no /prefix, but the chat id is configured to always use a specific model
+    if any(str(x) in env_list(TID.OPENAI_CHATS) for x in [minfo["cid"], slim_cid(minfo["cid"])]):
+        minfo["text"] = "/gpt " + minfo["text"]
+        return True
+    if any(str(x) in env_list(TID.GEMINI_CHATS) for x in [minfo["cid"], slim_cid(minfo["cid"])]):
+        minfo["text"] = "/gemini " + minfo["text"]
+        return True
+    if any(str(x) in env_list(TID.GROK_CHATS) for x in [minfo["cid"], slim_cid(minfo["cid"])]):
+        minfo["text"] = "/grok " + minfo["text"]
+        return True
+    if any(str(x) in env_list(TID.DEEPSEEK_CHATS) for x in [minfo["cid"], slim_cid(minfo["cid"])]):
+        minfo["text"] = "/ds " + minfo["text"]
         return True
+    if any(str(x) in env_list(TID.QWEN_CHATS) for x in [minfo["cid"], slim_cid(minfo["cid"])]):
+        minfo["text"] = "/qwen " + minfo["text"]
+        return True
+    if any(str(x) in env_list(TID.DOUBAO_CHATS) for x in [minfo["cid"], slim_cid(minfo["cid"])]):
+        minfo["text"] = "/doubao " + minfo["text"]
+        return True
+
     # is replying to gpt-bot response message?
-    if not message.reply_to_message:
-        return False
-    reply_msg = message.reply_to_message
-    reply_info = parse_msg(reply_msg, silent=True)
     model_names = [
         GPT.OPENAI_MODEL_NAME,
         GPT.DEEPSEEK_MODEL_NAME,
@@ -64,7 +84,7 @@ def is_gpt_conversation(message: Message) -> bool:
         GPT.GEMINI_MODEL_NAME,
         GEMINI.IMG_MODEL_NAME,
     ]
-    return startswith_prefix(reply_info["text"], prefix=[f"🤖{x}".lower() for x in model_names])
+    return startswith_prefix(reply_text, prefix=[f"🤖{x}".lower() for x in model_names])
 
 
 async def gpt_response(client: Client, message: Message, *, gpt_stream: bool = True, **kwargs) -> dict:
@@ -81,18 +101,18 @@ async def gpt_response(client: Client, message: Message, *, gpt_stream: bool = T
     # ruff: noqa: RET502, RET503
     info = parse_msg(message)
     # send docs if message == "/ai", without reply
-    if info["mtype"] == "text" and equal_prefix(info["text"], prefix=[PREFIX.GPT]) and not message.reply_to_message:
+    if info["mtype"] == "text" and equal_prefix(info["text"], prefix=PREFIX.GPT) and not message.reply_to_message:
         await send2tg(client, message, texts=HELP, **kwargs)
         return {}
-    if info["mtype"] == "text" and equal_prefix(info["text"], prefix=[PREFIX.GENIMG]) and not message.reply_to_message:
+    if info["mtype"] == "text" and equal_prefix(info["text"], prefix=PREFIX.GENIMG) and not message.reply_to_message:
         await send2tg(client, message, texts=AIGC_HELP, **kwargs)
         return {}
-    if not is_gpt_conversation(message):
-        return {}
     reply_text = ""
     if message.reply_to_message:
         reply_info = parse_msg(message.reply_to_message, silent=True)
         reply_text = reply_info["text"]
+    if not is_gpt_conversation(info, reply_text):
+        return {}
 
     # cache media_group message, only process once
     if media_group_id := message.media_group_id:
src/config.py
@@ -181,6 +181,12 @@ class TID:  # see more TID usecase in `src/permission.py`
     # back up ytdlp audio if the user does not request it
     CHANNEL_YTDLP_BACKUP = os.getenv("TID_CHANNEL_YTDLP_BACKUP", "me")
     DAILY_SUMMARY = os.getenv("TID_DAILY_SUMMARY", "{}")  # {"source-chat-id": "target-chat-id"}, e.g. '{"-1001234567890": "-1009876543210"}'
+    GEMINI_CHATS = os.getenv("TID_GEMINI_CHATS", "")  # comma separated chat ids to always use gemini models (no need `/gemini`)
+    OPENAI_CHATS = os.getenv("TID_OPENAI_CHATS", "")  # comma separated chat ids to always use openai models (no need `/gpt`)
+    DEEPSEEK_CHATS = os.getenv("TID_DEEPSEEK_CHATS", "")  # comma separated chat ids to always use deepseek models (no need `/ds`)
+    QWEN_CHATS = os.getenv("TID_QWEN_CHATS", "")  # comma separated chat ids to always use qwen models (no need `/qwen`)
+    DOUBAO_CHATS = os.getenv("TID_DOUBAO_CHATS", "")  # comma separated chat ids to always use doubao models (no need `/doubao`)
+    GROK_CHATS = os.getenv("TID_GROK_CHATS", "")  # comma separated chat ids to always use grok models (no need `/grok`)
 
 
 class DB:
src/utils.py
@@ -2,6 +2,7 @@
 # -*- coding: utf-8 -*-
 import contextlib
 import json
+import os
 import random
 import re
 import string
@@ -289,6 +290,13 @@ def ts_to_dt(ts: str | float | None) -> datetime | None:
         return None
 
 
+def env_list(value: str | None = None, *, env_key: str = "", separator: str = ",") -> list[str]:
+    """Split *value* (or, if it is None, the env var *env_key*) on *separator*; return stripped, non-empty items."""
+    if value is None:
+        value = os.getenv(env_key, "")
+    return [s.strip() for s in value.split(separator) if s.strip()]
+
+
 def parse_time(timestr: str) -> dict[str, int]:
     """Parse time string.