Commit 35dc3f6

benny-dou <60535774+benny-dou@users.noreply.github.com>
2025-04-10 02:07:48
feat(gpt): add Grok support
Add the `/grok` command to support the Grok AI model.
1 parent 043b93f
src/llm/contexts.py
@@ -69,7 +69,7 @@ async def single_context(client: Client, message: Message) -> dict:
     def clean_text(text: str) -> str:
         if not text:
             return ""
-        for prefix in [PREFIX.GPT, "/gpt", "/gemini", "/ds", "/qwen", "/doubao"]:
+        for prefix in [PREFIX.GPT, "/gpt", "/gemini", "/ds", "/qwen", "/doubao", "/grok"]:
             text = text.removeprefix(prefix).strip()
         # remove bot tips
         text = re.sub(rf"(.*?){BOT_TIPS}\)", "", text, flags=re.DOTALL).strip()
src/llm/gpt.py
@@ -18,20 +18,22 @@ from messages.sender import send2tg
 from messages.utils import count_without_entities, equal_prefix, startswith_prefix
 
 HELP = f"""🤖**GPT对话**
-`{PREFIX.GPT}` 命令当前模型:
-- 文本模型: **{GPT.TEXT_MODEL_NAME}**
-- 图片模型: **{GPT.IMAGE_MODEL_NAME}**
-
-`/gpt` 命令强制使用: **{GPT.OPENAI_MODEL_NAME}**
-`/gemini` 命令强制使用: **{GPT.GEMINI_MODEL_NAME}**
-`/ds` 命令强制使用: **{GPT.DEEPSEEK_MODEL_NAME}**
-`/qwen` 命令强制使用: **{GPT.QWEN_MODEL_NAME}**
-`/doubao` 命令强制使用: **{GPT.DOUBAO_MODEL_NAME}**
-
 使用说明:
-1. 在 `{PREFIX.GPT}` 后接提示词即可与GPT对话
+1. `{PREFIX.GPT}` 后接提示词即可与GPT对话
 2. 以 `{PREFIX.GPT}` 回复消息可将其加入上下文
-3. 暂不支持视频/音频模型, 可以先用 `{PREFIX.ASR}` 命令转为文字后再使用 `{PREFIX.GPT}`
+3. 暂不支持视频/音频, 可先用`{PREFIX.ASR}`命令转为文字后再调用`{PREFIX.GPT}`
+
+⚙️模型配置:
+`{PREFIX.GPT}` 命令默认模型: **{GPT.TEXT_MODEL_NAME}**
+如果上下文中包含图片时, 会自动切换为: **{GPT.IMAGE_MODEL_NAME}**
+
+🔄使用以下命令强制切换模型:
+`/gpt`: **{GPT.OPENAI_MODEL_NAME}**
+`/gemini`: **{GPT.GEMINI_MODEL_NAME}**
+`/ds`: **{GPT.DEEPSEEK_MODEL_NAME}**
+`/qwen`: **{GPT.QWEN_MODEL_NAME}**
+`/doubao`: **{GPT.DOUBAO_MODEL_NAME}**
+`/grok`: **{GPT.GROK_MODEL_NAME}**
 """
 
 
@@ -39,14 +41,23 @@ def is_gpt_conversation(message: Message) -> bool:
     info = parse_msg(message)
     if info["is_bot"]:  # do not process bot message
         return False
-    if startswith_prefix(info["text"], prefix=[PREFIX.GPT, "/gpt", "/gemini", "/ds", "/qwen", "/doubao"]):
+    if startswith_prefix(info["text"], prefix=[PREFIX.GPT, "/gpt", "/gemini", "/ds", "/qwen", "/doubao", "/grok"]):
         return True
     # is replying to gpt-bot response message?
     if not message.reply_to_message:
         return False
     reply_msg = message.reply_to_message
     reply_info = parse_msg(reply_msg, silent=True)
-    model_names = [GPT.OPENAI_MODEL_NAME, GPT.GEMINI_MODEL_NAME, GPT.DEEPSEEK_MODEL_NAME, GPT.QWEN_MODEL_NAME, GPT.DOUBAO_MODEL_NAME, GPT.TEXT_MODEL_NAME, GPT.IMAGE_MODEL_NAME]
+    model_names = [
+        GPT.OPENAI_MODEL_NAME,
+        GPT.GEMINI_MODEL_NAME,
+        GPT.DEEPSEEK_MODEL_NAME,
+        GPT.QWEN_MODEL_NAME,
+        GPT.DOUBAO_MODEL_NAME,
+        GPT.GROK_MODEL_NAME,
+        GPT.TEXT_MODEL_NAME,
+        GPT.IMAGE_MODEL_NAME,
+    ]
     return startswith_prefix(reply_info["text"], prefix=[f"🤖{x}".lower() for x in model_names])
 
 
@@ -61,7 +72,7 @@ async def gpt_response(client: Client, message: Message, *, gpt_stream: bool = G
     # ruff: noqa: RET502, RET503
     info = parse_msg(message)
     # send docs if message == "/ai", without reply
-    if info["mtype"] == "text" and equal_prefix(info["text"], prefix=[PREFIX.GPT, "/gpt", "/gemini", "/ds", "/qwen", "/doubao"]) and not message.reply_to_message:
+    if info["mtype"] == "text" and equal_prefix(info["text"], prefix=[PREFIX.GPT, "/gpt", "/gemini", "/ds", "/qwen", "/doubao", "/grok"]) and not message.reply_to_message:
         await send2tg(client, message, texts=HELP, **kwargs)
         return
     if not is_gpt_conversation(message):
@@ -72,7 +83,6 @@ async def gpt_response(client: Client, message: Message, *, gpt_stream: bool = G
         reply_info = parse_msg(message.reply_to_message, silent=True)
         reply_text = reply_info["text"]
 
-    # /gpt = OpenAI, /gemini = Gemini, /ds = DeepSeek, /qwen = Qwen, /doubao = Doubao
     force_model = parse_force_model(info["text"], reply_text)
 
     # cache media_group message, only process once
src/llm/models.py
@@ -30,7 +30,7 @@ def get_context_type(conversations: list[Message]) -> dict:
 def parse_force_model(text: str, reply_text: str) -> str:
     """Parse the force model from the text or reply text.
 
-    /gpt = OpenAI, /gemini = Gemini, /ds = DeepSeek, /qwen = Qwen, /doubao = Doubao
+    /gpt = OpenAI, /gemini = Gemini, /ds = DeepSeek, /qwen = Qwen, /doubao = Doubao, /grok = Grok
     """
     force_model = ""
     # parse from bot reply
@@ -44,7 +44,8 @@ def parse_force_model(text: str, reply_text: str) -> str:
         force_model = GPT.QWEN_MODEL
     elif reply_text.startswith(f"🤖{GPT.DOUBAO_MODEL_NAME}"):
         force_model = GPT.DOUBAO_MODEL
-
+    elif reply_text.startswith(f"🤖{GPT.GROK_MODEL_NAME}"):
+        force_model = GPT.GROK_MODEL
     # parse from command prefix
     if startswith_prefix(text, prefix=["/gpt"]):
         force_model = GPT.OPENAI_MODEL
@@ -56,7 +57,8 @@ def parse_force_model(text: str, reply_text: str) -> str:
         force_model = GPT.QWEN_MODEL
     elif startswith_prefix(text, prefix=["/doubao"]):
         force_model = GPT.DOUBAO_MODEL
-
+    elif startswith_prefix(text, prefix=["/grok"]):
+        force_model = GPT.GROK_MODEL
     return force_model
 
 
@@ -98,6 +100,7 @@ def get_gpt_config(model_type: str, contexts: list[dict], force_model: str = "")
         GPT.DEEPSEEK_MODEL: {"api_key": GPT.DEEPSEEK_API_KEY, "base_url": GPT.DEEPSEEK_BASE_URL, "model_name": GPT.DEEPSEEK_MODEL_NAME},
         GPT.QWEN_MODEL: {"api_key": GPT.QWEN_API_KEY, "base_url": GPT.QWEN_BASE_URL, "model_name": GPT.QWEN_MODEL_NAME},
         GPT.DOUBAO_MODEL: {"api_key": GPT.DOUBAO_API_KEY, "base_url": GPT.DOUBAO_BASE_URL, "model_name": GPT.DOUBAO_MODEL_NAME},
+        GPT.GROK_MODEL: {"api_key": GPT.GROK_API_KEY, "base_url": GPT.GROK_BASE_URL, "model_name": GPT.GROK_MODEL_NAME},
     }
     model_factory |= {GPT.SUMMARY_MODEL: {"api_key": GPT.SUMMARY_API_KEY, "base_url": GPT.SUMMARY_BASE_URL, "model_name": GPT.SUMMARY_MODEL_NAME}}
     model_factory |= {GPT.LONG_MODEL: {"api_key": GPT.LONG_API_KEY, "base_url": GPT.LONG_BASE_URL, "model_name": GPT.LONG_MODEL_NAME}}
@@ -116,6 +119,7 @@ def get_gpt_config(model_type: str, contexts: list[dict], force_model: str = "")
             or (force_model == GPT.DOUBAO_MODEL and GPT.DOUBAO_IMAGE_CAPABILITY)
             or (force_model == GPT.SUMMARY_MODEL and GPT.SUMMARY_IMAGE_CAPABILITY)
             or (force_model == GPT.LONG_MODEL and GPT.LONG_IMAGE_CAPABILITY)
+            or (force_model == GPT.GROK_MODEL and GPT.GROK_IMAGE_CAPABILITY)
         )
     ):
         client |= force_model_config
src/config.py
@@ -194,6 +194,12 @@ class GPT:  # see `llm/README.md`
     QWEN_API_KEY = os.getenv("GPT_QWEN_API_KEY", "")
     QWEN_BASE_URL = os.getenv("GPT_QWEN_BASE_URL", "https://dashscope.aliyuncs.com/compatible-mode/v1")
     QWEN_IMAGE_CAPABILITY = os.getenv("GPT_QWEN_IMAGE_CAPABILITY", "1").lower() in ["1", "y", "yes", "t", "true", "on"]
+    # /grok command
+    GROK_MODEL = os.getenv("GPT_GROK_MODEL", "grok-3")
+    GROK_MODEL_NAME = os.getenv("GPT_GROK_MODEL_NAME", "Grok-3")
+    GROK_API_KEY = os.getenv("GPT_GROK_API_KEY", "")
+    GROK_BASE_URL = os.getenv("GPT_GROK_BASE_URL", "https://api.x.ai/v1")
+    GROK_IMAGE_CAPABILITY = os.getenv("GPT_GROK_IMAGE_CAPABILITY", "1").lower() in ["1", "y", "yes", "t", "true", "on"]
     # /doubao command
     DOUBAO_MODEL = os.getenv("GPT_DOUBAO_MODEL", "doubao-1-5-vision-pro-32k-250115")
     DOUBAO_MODEL_NAME = os.getenv("GPT_DOUBAO_MODEL_NAME", "豆包-1.5-Pro")
src/handler.py
@@ -75,7 +75,7 @@ async def handle_utilities(
     """
     kwargs |= {"target_chat": target_chat, "reply_msg_id": reply_msg_id, "show_progress": show_progress, "detail_progress": detail_progress}
     if ai:
-        await gpt_response(client, message, **kwargs)  # /ai /gpt /gemini /ds /qwen /doubao
+        await gpt_response(client, message, **kwargs)  # /ai /gpt /gemini /ds /qwen /doubao /grok
     if asr:
         await voice_to_text(client, message, **kwargs)  # /asr
     if audio:
@@ -140,7 +140,7 @@ async def handle_social_media(
         cmd_prefix.extend(PREFIX.MAIN)
     ignore_prefix = ignore_prefix or ["/dl4dw"]
     # these commands are handled in `handle_utilities`
-    ignore_prefix.extend(["/ai", "/asr", "/audio", "/combine", "/doubao", "/ds", "/gemini", "/gpt", "/ocr", "/price", "/qwen", "/subtitle", "/summary", "/voice", "/wget"])
+    ignore_prefix.extend(["/ai", "/asr", "/audio", "/combine", "/doubao", "/ds", "/gemini", "/gpt", "/grok", "/ocr", "/price", "/qwen", "/subtitle", "/summary", "/voice", "/wget"])
     info = parse_msg(message)
     this_texts = info["text"]  # texts of the trigger message
     if startswith_prefix(this_texts, prefix=ignore_prefix):
@@ -270,7 +270,7 @@ def get_social_media_help(chat_id: int | str, ctype: str, prefixes: list[str] |
         msg += "\n🅱️哔哩哔哩"
         msg += "\n🆕和所有yt-dlp支持的链接\n"
     if permission["ai"]:
-        msg += f"\n🤖**GPT对话**: `{PREFIX.GPT} /gpt /gemini /ds /qwen /doubao` + 提示词"
+        msg += f"\n🤖**GPT对话**: `{PREFIX.GPT} /gpt /gemini /ds /qwen /doubao /grok` + 提示词"
     if permission["asr"]:
         msg += f"\n🗣**语音转文字**: `{PREFIX.ASR}` 回复语音消息"
     if permission["audio"]: