Commit ef1ae88
Changed files (6)
src
src/llm/gpt.py
@@ -9,6 +9,7 @@ from config import GPT, PREFIX, cache
from llm.contexts import get_conversation_contexts, get_conversations
from llm.models import get_model_config_with_contexts, get_model_type
from llm.response import merge_tools_response, send_to_gpt
+from llm.response_stream import send_to_gpt_stream
from llm.utils import BOT_TIPS, llm_cleanup_files
from messages.parser import parse_msg
from messages.progress import modify_progress
@@ -47,12 +48,13 @@ def is_gpt_conversation(message: Message) -> bool:
@cache.memoize(ttl=60)
-async def gpt_response(client: Client, message: Message, **kwargs):
+async def gpt_response(client: Client, message: Message, *, gpt_stream: bool = GPT.STREAM_MODE, **kwargs):
"""Get GPT response from Various API.
Args:
client (Client): The Pyrogram client.
message (Message): The trigger message object.
+ gpt_stream (bool): Whether to use stream mode.
"""
# ruff: noqa: RET502, RET503
info = parse_msg(message)
@@ -101,23 +103,28 @@ async def gpt_response(client: Client, message: Message, **kwargs):
contexts = await get_conversation_contexts(client, conversations)
config = get_model_config_with_contexts(model_type, contexts, force_model, info)
msg = f"🤖{config['friendly_name']}: 思考中..."
- if kwargs.get("show_progress"):
- res = await send2tg(client, message, texts=msg, **kwargs)
- kwargs["progress"] = res[0]
+ status_msg = (await send2tg(client, message, texts=msg, **kwargs))[0]
+ kwargs["progress"] = status_msg
- config, tool_response = await merge_tools_response(config, **kwargs)
+ config, response = await merge_tools_response(config, **kwargs)
# skip send a new request if tool_model is the same as the current model
- if tool_response and config["completions"]["model"] == GPT.TOOLS_MODEL:
- response = tool_response
- else:
+ if response and config["completions"]["model"] == GPT.TOOLS_MODEL and response.get("content"):
+ texts = f"🤖**{response['model']}**: ({BOT_TIPS})\n\n{response['content']}"
+ await modify_progress(message=status_msg, del_status=True, **kwargs)
+ llm_cleanup_files(config["completions"]["messages"])
+ return
+
+ if not gpt_stream:
response = await send_to_gpt(config, **kwargs)
- if content := response.get("content"):
- if reasoning := response.get("reasoning"):
- content = f"{reasoning}\n{content}"
- texts = f"🤖**{response['model']}**: ({BOT_TIPS})\n{content}"
- else:
- texts = f"🤖**{response['model']}**: ({BOT_TIPS})\n\n{content}"
- logger.debug(texts)
- await send2tg(client, message, texts=texts, **kwargs)
- await modify_progress(del_status=True, **kwargs)
+ if content := response.get("content"):
+ if reasoning := response.get("reasoning"):
+ content = f"{reasoning}\n{content}"
+ texts = f"🤖**{response['model']}**: ({BOT_TIPS})\n{content}"
+ else:
+ texts = f"🤖**{response['model']}**: ({BOT_TIPS})\n\n{content}"
+ logger.debug(texts)
+ await send2tg(client, message, texts=texts, **kwargs)
+ await modify_progress(message=status_msg, del_status=True, **kwargs)
+ else:
+ return await send_to_gpt_stream(client, status_msg, config, **kwargs) # type: ignore
llm_cleanup_files(config["completions"]["messages"])
src/llm/response.py
@@ -13,9 +13,8 @@ from config import GPT
from llm.models import openrouter_hook
from llm.prompts import add_search_results_to_prompts
from llm.tools import add_tools, get_online_search_result
-from llm.utils import beautify_llm_response, beautify_model_name, extract_reasoning
+from llm.utils import add_search_results_to_response, beautify_llm_response, beautify_model_name, extract_reasoning
from messages.progress import modify_progress
-from utils import number_to_emoji
async def merge_tools_response(config: dict, **kwargs) -> tuple[dict, dict]:
@@ -58,7 +57,7 @@ async def merge_tools_response(config: dict, **kwargs) -> tuple[dict, dict]:
async def send_to_gpt(config: dict, retry: int = 0, **kwargs) -> dict[str, str]:
- """Get GPT response.
+ """Get GPT response in non-stream mode.
# See `llm/README.md` for more details.
@@ -152,16 +151,3 @@ async def parse_response(config: dict, response: dict) -> dict[str, str]:
logger.error(f"Parse GPT response failed: {e}")
raise
return response
-
-
-def add_search_results_to_response(search_results: list[dict], response: str) -> str:
- """Add search results to response."""
- if not search_results or not response:
- return response
- response = response.strip()
- for idx, result in enumerate(search_results):
- title = result.get("title", "")[:20]
- link = result.get("link", "")
- if link.startswith("http") and f"({link})" in response:
- response += f"\n{number_to_emoji(idx + 1)} [{title}]({link})"
- return response.strip()
src/llm/response_stream.py
@@ -0,0 +1,140 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+import contextlib
+import json
+import re
+
+from glom import glom
+from loguru import logger
+from openai import AsyncOpenAI
+from pyrogram.client import Client
+from pyrogram.parser.markdown import BLOCKQUOTE_EXPANDABLE_DELIM, BLOCKQUOTE_EXPANDABLE_END_DELIM
+from pyrogram.types import Message
+
+from config import GPT, TEXT_LENGTH
+from llm.utils import BOT_TIPS, add_search_results_to_response
+from messages.progress import modify_progress
+from messages.utils import count_without_entities, smart_split
+
+
async def send_to_gpt_stream(client: Client, status: Message, config: dict, retry: int = 0, **kwargs) -> dict:
    """Get GPT response in stream mode.

    Streams completion chunks from an OpenAI-compatible endpoint and
    live-edits the Telegram ``status`` message as content arrives. Reasoning
    content (either from the ``reasoning_content`` delta field or embedded
    ``<think>`` / "Reasoning…" markers) is wrapped in an expandable
    blockquote. When the accumulated text exceeds ``TEXT_LENGTH`` it is
    split across multiple Telegram messages.

    Args:
        client (Client): The Pyrogram client.
        status (Message): The progress message edited in place.
        config (dict): Model/client configuration (see ``llm/README.md``).
        retry (int): Current retry count.

    Returns:
        dict: ``{}`` on failure. On success the result is delivered by
        editing Telegram messages; no meaningful value is returned.
    """
    # ruff: noqa: RUF001, RUF003
    prefix = f"🤖**{config['friendly_name']}**: ({BOT_TIPS})\n"
    try:
        openai = AsyncOpenAI(**config["client"])
        logger.trace(config)
        answers = prefix
        sent_answers = []
        is_reasoning = False
        reasoning_in_response = None  # None until we learn whether deltas carry reasoning_content
        gen = await openai.chat.completions.create(**config["completions"], stream=True)
        async for chunk in gen:
            resp = chunk.model_dump()
            logger.trace(resp)
            error = await parse_error(resp, retry, **kwargs)
            if error["retry"]:
                return await send_to_gpt_stream(client, status, config, retry=retry + 1, **kwargs)
            if error["error"]:
                # BUG FIX: previously this passed the *boolean* ``error["error"]``
                # as ``text=`` to modify_progress, which expects a string.
                # parse_error has already reported the error text via
                # modify_progress, so just abort the stream here.
                return {}
            answer = glom(resp, "choices.0.delta.content", default="") or ""
            reasoning_content = glom(resp, "choices.0.delta.reasoning_content", default="") or ""
            if reasoning_in_response is None and reasoning_content:
                reasoning_in_response = True
            if reasoning_content and not is_reasoning:  # first reasoning chunk: open the expandable quote
                is_reasoning = True
                answers += f"{BLOCKQUOTE_EXPANDABLE_DELIM}🤔{reasoning_content.lstrip()}"
            elif reasoning_content and is_reasoning:  # still reasoning: keep appending
                answers += reasoning_content
            elif reasoning_in_response is True and is_reasoning:  # first answer chunk: close the quote
                is_reasoning = False
                answers = f"{answers.rstrip()}💡\n{BLOCKQUOTE_EXPANDABLE_END_DELIM}\n" + answer.lstrip()
            else:
                answers += answer

            # Sometimes the reasoning content is included in the content field.
            # handle "<think>...</think>\n\n"
            if answers.removeprefix(prefix).lstrip().startswith("<think>"):
                is_reasoning = True
                answers = answers.replace("<think>", f"{BLOCKQUOTE_EXPANDABLE_DELIM}🤔")
                if "</think>" in answers:
                    is_reasoning = False
                    answers = re.sub(r"</think>\s*", f"💡\n{BLOCKQUOTE_EXPANDABLE_END_DELIM}", answers, count=1)

            # handle ">Reasoning...Reasoned(.*?)seconds"
            if re.search(r"^>?\s*Reasoning", answers.removeprefix(prefix).lstrip(), re.DOTALL):
                is_reasoning = True
                answers = re.sub(r">?\s*Reasoning\s*", f"{BLOCKQUOTE_EXPANDABLE_DELIM}🤔", answers, count=1, flags=re.DOTALL)
                if re.search(r"🤔(.*?)Reasoned(.*?)seconds", answers.removeprefix(prefix).lstrip(), re.DOTALL):
                    is_reasoning = False
                    answers = re.sub(r"Reasoned(.*?)seconds\s*", f"💡\n{BLOCKQUOTE_EXPANDABLE_END_DELIM}", answers, count=1, flags=re.DOTALL)

            # handle ">正在推理...,持续(.*?)秒"
            if re.search(r"^>?(正在)?推理", answers.removeprefix(prefix).lstrip(), re.DOTALL):
                is_reasoning = True
                answers = re.sub(r">?(正在)?推理\s*", f"{BLOCKQUOTE_EXPANDABLE_DELIM}🤔", answers, count=1, flags=re.DOTALL)
                if re.search(r"🤔(.*?),持续(.*?)秒", answers.removeprefix(prefix).lstrip(), re.DOTALL):
                    is_reasoning = False
                    answers = re.sub(r",持续(.*?)秒\s*", f"💡\n{BLOCKQUOTE_EXPANDABLE_END_DELIM}", answers, count=1, flags=re.DOTALL)

            if await count_without_entities(answers) <= TEXT_LENGTH:
                await modify_progress(message=status, text=answers, detail_progress=True)
            else:  # answers is too long, split it into multiple messages
                parts = await smart_split(answers)
                await modify_progress(message=status, text=parts[0], force_update=True)  # force send the first part
                sent_answers.append(parts[0])
                answers = parts[-1]  # keep the last part
                if is_reasoning:  # re-open the quote so the continuation message renders correctly
                    answers = f"{BLOCKQUOTE_EXPANDABLE_DELIM}{answers.lstrip()}"
                status = await client.send_message(status.chat.id, answers)

        sent_answers.append(answers)
        answers = add_search_results_to_response(config.get("search_results", []), "".join(sent_answers))
        answers = (await smart_split(answers))[-1]
        # Finally, force update the message
        await modify_progress(message=status, text=answers.strip(), force_update=True)

    except Exception as e:
        error = f"🤖{config['friendly_name']}请求失败, 重试次数: {retry + 1}/{GPT.MAX_RETRY + 1}\n{e}"
        logger.error(error)
        # NOTE(review): unlike the in-loop calls this does not pass
        # message=status — presumably modify_progress falls back to the
        # progress message in kwargs; confirm against messages/progress.py.
        await modify_progress(text=error, force_update=True, **kwargs)
        if retry < GPT.MAX_RETRY:
            return await send_to_gpt_stream(client, status, config, retry=retry + 1, **kwargs)
        return {}
+
+
async def parse_error(resp: dict, retry: int, **kwargs) -> dict:
    """Parse a streamed GPT chunk for errors.

    A chunk is considered valid when it carries content, reasoning content,
    a tool call, or a ``stop`` finish reason; otherwise a non-zero
    ``error.code`` is reported to the user via ``modify_progress``.

    Args:
        resp (dict): A single streamed completion chunk dumped to a dict.
        retry (int): The caller's current retry count.

    Returns:
        dict: ``{"error": bool, "retry": bool}`` — ``error`` means the chunk
        is an error; ``retry`` means the caller should retry the request.
    """
    error_result = {"error": False, "retry": False}
    error_code = glom(resp, "error.code", default=0)
    error_msg = ""
    content = None
    reasoning_content = None
    is_finished = False
    # BUG FIX: was initialized to {} — if the suppressed block below raises
    # before reassigning it, a "not None" {} made an error chunk look like
    # valid content and silently masked the error.
    tool_call = None
    with contextlib.suppress(Exception):
        metadata = glom(resp, "error.metadata.raw", default="{}")
        error_msg = glom(json.loads(metadata), "error.message", default="")
        choice = glom(resp, "choices.0", default={})
        content = glom(choice, "delta.content", default=None)
        reasoning_content = glom(choice, "delta.reasoning_content", default=None)
        tool_call = glom(choice, "delta.tool_calls.0", default=None)
        is_finished = glom(choice, "finish_reason", default="") == "stop"
    if is_finished or any(x is not None for x in [content, reasoning_content, tool_call]):
        return {"error": False, "retry": False}
    if error_code != 0:
        logger.warning(resp)
        error_result["error"] = True
        if not error_msg:  # raw metadata had no message; fall back to the top-level one
            error_msg = glom(resp, "error.message", default="")
        await modify_progress(text=f"[{error_code}] {error_msg}\n重试次数: {retry + 1}/{GPT.MAX_RETRY + 1}", force_update=True, **kwargs)
        if retry < GPT.MAX_RETRY:
            error_result["retry"] = True
    return error_result
src/llm/utils.py
@@ -7,7 +7,7 @@ import tiktoken
from loguru import logger
from config import DOWNLOAD_DIR, GPT
-from utils import remove_consecutive_newlines, remove_dash, remove_pound
+from utils import number_to_emoji, remove_consecutive_newlines, remove_dash, remove_pound
BOT_TIPS = "回复以继续"
@@ -108,19 +108,32 @@ def extract_reasoning(text: str) -> tuple[str, str]:
{content}"
"""
reasoning = ""
- if matched := re.search(r"<think>(.*?)</think>", text, re.DOTALL):
+ if matched := re.search(r"^<think>(.*?)</think>", text.lstrip(), re.DOTALL):
reasoning = matched.group(1)
text = re.sub(r"<think>(.*?)</think>", "", text, count=1, flags=re.DOTALL) # remove <think>...</think>
- if matched := re.search(r"<thinking>(.*?)</thinking>", text, re.DOTALL):
+ if matched := re.search(r"^<thinking>(.*?)</thinking>", text.lstrip(), re.DOTALL):
reasoning = matched.group(1)
text = re.sub(r"<thinking>(.*?)</thinking>", "", text, count=1, flags=re.DOTALL)
# Reverse engineered Web API
- if matched := re.search(r"^>?(正在)?推理(.*?)(,持续.*?)秒\n\n(.*)", text, re.DOTALL): # noqa: RUF001
+ if matched := re.search(r"^>?(正在)?推理(.*?)(,持续.*?)秒\n\n(.*)", text.lstrip(), re.DOTALL): # noqa: RUF001
reasoning = matched.group(2)
text = matched.group(4)
- if matched := re.search(r"^>?\s?Reasoning(.*?)Reasoned(.*?)seconds\n\n(.*)", text, re.DOTALL):
+ if matched := re.search(r"^>?\s?Reasoning(.*?)Reasoned(.*?)seconds\n\n(.*)", text.lstrip(), re.DOTALL):
reasoning = matched.group(1)
text = matched.group(3)
return reasoning.strip(), text.strip().removeprefix("{content}").strip()
+
+
def add_search_results_to_response(search_results: list[dict], response: str) -> str:
    """Append numbered footnote links for search results cited in *response*.

    A result is appended (as ``<emoji> [title](link)`` on its own line) only
    when its link is an http(s) URL that already appears inline as
    ``(link)`` in the response text.
    """
    if not response or not search_results:
        return response
    annotated = response.strip()
    for position, item in enumerate(search_results, start=1):
        url = item.get("link", "")
        if not url.startswith("http") or f"({url})" not in annotated:
            continue
        label = item.get("title", "")[:20]
        annotated += f"\n{number_to_emoji(position)} [{label}]({url})"
    return annotated.strip()
src/messages/progress.py
@@ -5,6 +5,7 @@ import asyncio
from pathlib import Path
from loguru import logger
+from pyrogram.errors import FloodWait, MessageNotModified
from pyrogram.types import Message
from config import TEXT_LENGTH, cache
@@ -16,7 +17,8 @@ async def modify_progress(
*,
detail_progress: bool = False,
del_status: bool = False,
- del_delay: int = 0,
+ ttl: float = 2,
+ del_delay: float = 0,
force_update: bool = False,
**kwargs,
):
@@ -27,7 +29,8 @@ async def modify_progress(
text (str): The new text to update.
detail_progress(bool): Whether to show the detail progress.
del_status (bool): Whether the progress is done.
- del_delay (int): Delay seconds to delete the message.
+ ttl (float): Time to live for the cache.
+ del_delay (float): Delay seconds to delete the message.
force_update (bool): Force update the message.
"""
if message is None:
@@ -50,7 +53,12 @@ async def modify_progress(
return
logger.trace(f"Progress: {text!r}")
await message.edit_text(text[:TEXT_LENGTH])
- cache.set("modify_progress", "1", ttl=2)
+ cache.set("modify_progress", "1", ttl=ttl)
+ except FloodWait as e:
+ logger.warning(e)
+ await asyncio.sleep(e.value) # type: ignore
+ except MessageNotModified:
+ pass
except Exception as e:
logger.warning(f"modify_progress: {e}")
src/config.py
@@ -135,6 +135,7 @@ class COOKIE: # See: https://github.com/easychen/CookieCloud
class GPT: # see `llm/README.md`
+ STREAM_MODE = os.getenv("GPT_STREAM_MODE", "1").lower() in ["1", "y", "yes", "t", "true", "on"]
TEXT_MODEL = os.getenv("GPT_TEXT_MODEL", "gpt-4o")
IMAGE_MODEL = os.getenv("GPT_IMAGE_MODEL", "gpt-4o")
VIDEO_MODEL = os.getenv("GPT_VIDEO_MODEL", "glm-4v-plus")