Commit ef1ae88

benny-dou <60535774+benny-dou@users.noreply.github.com>
2025-03-10 15:14:53
feat(gpt): support stream mode
1 parent 51e0cbe
src/llm/gpt.py
@@ -9,6 +9,7 @@ from config import GPT, PREFIX, cache
 from llm.contexts import get_conversation_contexts, get_conversations
 from llm.models import get_model_config_with_contexts, get_model_type
 from llm.response import merge_tools_response, send_to_gpt
+from llm.response_stream import send_to_gpt_stream
 from llm.utils import BOT_TIPS, llm_cleanup_files
 from messages.parser import parse_msg
 from messages.progress import modify_progress
@@ -47,12 +48,13 @@ def is_gpt_conversation(message: Message) -> bool:
 
 
 @cache.memoize(ttl=60)
-async def gpt_response(client: Client, message: Message, **kwargs):
+async def gpt_response(client: Client, message: Message, *, gpt_stream: bool = GPT.STREAM_MODE, **kwargs):
     """Get GPT response from Various API.
 
     Args:
         client (Client): The Pyrogram client.
         message (Message): The trigger message object.
+        gpt_stream (bool): Whether to stream the response; defaults to GPT.STREAM_MODE.
     """
     # ruff: noqa: RET502, RET503
     info = parse_msg(message)
@@ -101,23 +103,28 @@ async def gpt_response(client: Client, message: Message, **kwargs):
     contexts = await get_conversation_contexts(client, conversations)
     config = get_model_config_with_contexts(model_type, contexts, force_model, info)
     msg = f"🤖{config['friendly_name']}: 思考中..."
-    if kwargs.get("show_progress"):
-        res = await send2tg(client, message, texts=msg, **kwargs)
-        kwargs["progress"] = res[0]
+    status_msg = (await send2tg(client, message, texts=msg, **kwargs))[0]
+    kwargs["progress"] = status_msg
 
-    config, tool_response = await merge_tools_response(config, **kwargs)
+    config, response = await merge_tools_response(config, **kwargs)
     # skip sending a new request if the tool model is the same as the current model
-    if tool_response and config["completions"]["model"] == GPT.TOOLS_MODEL:
-        response = tool_response
-    else:
+    if response and config["completions"]["model"] == GPT.TOOLS_MODEL and response.get("content"):
+        texts = f"🤖**{response['model']}**: ({BOT_TIPS})\n\n{response['content']}"
+        await send2tg(client, message, texts=texts, **kwargs)
+        await modify_progress(message=status_msg, del_status=True, **kwargs)
+        llm_cleanup_files(config["completions"]["messages"])
+        return
+
+    if not gpt_stream:
         response = await send_to_gpt(config, **kwargs)
-    if content := response.get("content"):
-        if reasoning := response.get("reasoning"):
-            content = f"{reasoning}\n{content}"
-            texts = f"🤖**{response['model']}**: ({BOT_TIPS})\n{content}"
-        else:
-            texts = f"🤖**{response['model']}**: ({BOT_TIPS})\n\n{content}"
-        logger.debug(texts)
-        await send2tg(client, message, texts=texts, **kwargs)
-        await modify_progress(del_status=True, **kwargs)
+        if content := response.get("content"):
+            if reasoning := response.get("reasoning"):
+                content = f"{reasoning}\n{content}"
+                texts = f"🤖**{response['model']}**: ({BOT_TIPS})\n{content}"
+            else:
+                texts = f"🤖**{response['model']}**: ({BOT_TIPS})\n\n{content}"
+            logger.debug(texts)
+            await send2tg(client, message, texts=texts, **kwargs)
+            await modify_progress(message=status_msg, del_status=True, **kwargs)
+    else:
+        # Stream mode edits Telegram messages in place; fall through so files are cleaned up.
+        await send_to_gpt_stream(client, status_msg, config, **kwargs)  # type: ignore
     llm_cleanup_files(config["completions"]["messages"])
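
With the new keyword a call site can override the env-configured default. A minimal sketch of a hypothetical Pyrogram handler (the handler itself is an assumption, not part of this commit):

from llm.gpt import gpt_response

async def on_message(client, message):
    # Uses the default from GPT.STREAM_MODE (env var GPT_STREAM_MODE).
    await gpt_response(client, message)
    # Force a one-shot, non-stream reply regardless of the env default.
    await gpt_response(client, message, gpt_stream=False)
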
src/llm/response.py
@@ -13,9 +13,8 @@ from config import GPT
 from llm.models import openrouter_hook
 from llm.prompts import add_search_results_to_prompts
 from llm.tools import add_tools, get_online_search_result
-from llm.utils import beautify_llm_response, beautify_model_name, extract_reasoning
+from llm.utils import add_search_results_to_response, beautify_llm_response, beautify_model_name, extract_reasoning
 from messages.progress import modify_progress
-from utils import number_to_emoji
 
 
 async def merge_tools_response(config: dict, **kwargs) -> tuple[dict, dict]:
@@ -58,7 +57,7 @@ async def merge_tools_response(config: dict, **kwargs) -> tuple[dict, dict]:
 
 
 async def send_to_gpt(config: dict, retry: int = 0, **kwargs) -> dict[str, str]:
-    """Get GPT response.
+    """Get GPT response in non-stream mode.
 
     # See `llm/README.md` for more details.
 
@@ -152,16 +151,3 @@ async def parse_response(config: dict, response: dict) -> dict[str, str]:
         logger.error(f"Parse  GPT response failed: {e}")
         raise
     return response
-
-
-def add_search_results_to_response(search_results: list[dict], response: str) -> str:
-    """Add search results to response."""
-    if not search_results or not response:
-        return response
-    response = response.strip()
-    for idx, result in enumerate(search_results):
-        title = result.get("title", "")[:20]
-        link = result.get("link", "")
-        if link.startswith("http") and f"({link})" in response:
-            response += f"\n{number_to_emoji(idx + 1)} [{title}]({link})"
-    return response.strip()
src/llm/response_stream.py
@@ -0,0 +1,140 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+import contextlib
+import json
+import re
+
+from glom import glom
+from loguru import logger
+from openai import AsyncOpenAI
+from pyrogram.client import Client
+from pyrogram.parser.markdown import BLOCKQUOTE_EXPANDABLE_DELIM, BLOCKQUOTE_EXPANDABLE_END_DELIM
+from pyrogram.types import Message
+
+from config import GPT, TEXT_LENGTH
+from llm.utils import BOT_TIPS, add_search_results_to_response
+from messages.progress import modify_progress
+from messages.utils import count_without_entities, smart_split
+
+
+async def send_to_gpt_stream(client: Client, status: Message, config: dict, retry: int = 0, **kwargs) -> dict:
+    """Get GPT response in stream mode.
+
+    Returns:
+        {"content": str, "reasoning": str, "model": str}
+    """
+    # ruff: noqa: RUF001, RUF003
+    prefix = f"🤖**{config['friendly_name']}**: ({BOT_TIPS})\n"
+    try:
+        openai = AsyncOpenAI(**config["client"])
+        logger.trace(config)
+        answers = prefix
+        sent_answers = []
+        is_reasoning = False
+        reasoning_in_response = None
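+        # State: `is_reasoning` tracks whether the expandable "thinking"
+        # blockquote is currently open; `reasoning_in_response` records
+        # whether this model ever emits a separate reasoning_content delta
+        # (it stays None until the first chunk decides).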
+        gen = await openai.chat.completions.create(**config["completions"], stream=True)
+        async for chunk in gen:
+            resp = chunk.model_dump()
+            logger.trace(resp)
+            error = await parse_error(resp, retry, **kwargs)
+            if error["retry"]:
+                return await send_to_gpt_stream(client, status, config, retry=retry + 1, **kwargs)
+            if error["error"]:
+                # parse_error has already surfaced the error to the user via modify_progress.
+                return {}
+            answer = glom(resp, "choices.0.delta.content", default="") or ""
+            reasoning_content = glom(resp, "choices.0.delta.reasoning_content", default="") or ""
+            if reasoning_in_response is None and reasoning_content:
+                reasoning_in_response = True
+            if reasoning_content and not is_reasoning:  # first reasoning chunk received
+                is_reasoning = True
+                answers += f"{BLOCKQUOTE_EXPANDABLE_DELIM}🤔{reasoning_content.lstrip()}"
+            elif reasoning_content and is_reasoning:  # reasoning content continues
+                answers += reasoning_content
+            elif reasoning_in_response is True and is_reasoning:  # answer arrived; close the reasoning block
+                is_reasoning = False
+                answers = f"{answers.rstrip()}💡\n{BLOCKQUOTE_EXPANDABLE_END_DELIM}\n" + answer.lstrip()
+            else:
+                answers += answer
+
+            # Sometimes the reasoning content is included in the content field.
+            # handle "<think>...</think>\n\n"
+            if answers.removeprefix(prefix).lstrip().startswith("<think>"):
+                is_reasoning = True
+                answers = answers.replace("<think>", f"{BLOCKQUOTE_EXPANDABLE_DELIM}🤔")
+            if "</think>" in answers:
+                is_reasoning = False
+                answers = re.sub(r"</think>\s*", f"💡\n{BLOCKQUOTE_EXPANDABLE_END_DELIM}", answers, count=1)
+
+            # handle ">Reasoning...Reasoned(.*?)seconds"
+            if re.search(r"^>?\s*Reasoning", answers.removeprefix(prefix).lstrip(), re.DOTALL):
+                is_reasoning = True
+                answers = re.sub(r">?\s*Reasoning\s*", f"{BLOCKQUOTE_EXPANDABLE_DELIM}🤔", answers, count=1, flags=re.DOTALL)
+            if re.search(r"🤔(.*?)Reasoned(.*?)seconds", answers.removeprefix(prefix).lstrip(), re.DOTALL):
+                is_reasoning = False
+                answers = re.sub(r"Reasoned(.*?)seconds\s*", f"💡\n{BLOCKQUOTE_EXPANDABLE_END_DELIM}", answers, count=1, flags=re.DOTALL)
+
+            # handle ">正在推理...,持续(.*?)秒"
+            if re.search(r"^>?(正在)?推理", answers.removeprefix(prefix).lstrip(), re.DOTALL):
+                is_reasoning = True
+                answers = re.sub(r">?(正在)?推理\s*", f"{BLOCKQUOTE_EXPANDABLE_DELIM}🤔", answers, count=1, flags=re.DOTALL)
+            if re.search(r"🤔(.*?),持续(.*?)秒", answers.removeprefix(prefix).lstrip(), re.DOTALL):
+                is_reasoning = False
+                answers = re.sub(r",持续(.*?)秒\s*", f"💡\n{BLOCKQUOTE_EXPANDABLE_END_DELIM}", answers, count=1, flags=re.DOTALL)
+
+            if await count_without_entities(answers) <= TEXT_LENGTH:
+                await modify_progress(message=status, text=answers, detail_progress=True)
+            else:  # answers is too long, split it into multiple messages
+                parts = await smart_split(answers)
+                await modify_progress(message=status, text=parts[0], force_update=True)  # force send the first part
+                sent_answers.append(parts[0])
+                answers = parts[-1]  # keep the last part
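+                # If the split happened mid-reasoning, reopen the expandable
+                # quote in the continuation so the markdown stays balanced.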
+                if is_reasoning:
+                    answers = f"{BLOCKQUOTE_EXPANDABLE_DELIM}{answers.lstrip()}"
+                status = await client.send_message(status.chat.id, answers)
+
+        sent_answers.append(answers)
+        answers = add_search_results_to_response(config.get("search_results", []), "".join(sent_answers))
+        answers = (await smart_split(answers))[-1]
+        # Finally, force update the message
+        await modify_progress(message=status, text=answers.strip(), force_update=True)
+
+    except Exception as e:
+        error = f"🤖{config['friendly_name']}请求失败, 重试次数: {retry + 1}/{GPT.MAX_RETRY + 1}\n{e}"
+        logger.error(error)
+        await modify_progress(message=status, text=error, force_update=True, **kwargs)
+        if retry < GPT.MAX_RETRY:
+            return await send_to_gpt_stream(client, status, config, retry=retry + 1, **kwargs)
+    return {}
+
+
+async def parse_error(resp: dict, retry: int, **kwargs) -> dict:
+    """Parse GPT error.
+
+    Returns:
+        {"error": bool, "retry": bool}
+    """
+    error_result = {"error": False, "retry": False}
+    error_code = glom(resp, "error.code", default=0)
+    error_msg = ""
+    content = None
+    reasoning_content = None
+    is_finished = False
+    tool_call = None  # match the glom default below so the "is not None" check stays meaningful
+    with contextlib.suppress(Exception):
+        metadata = glom(resp, "error.metadata.raw", default="{}")
+        error_msg = glom(json.loads(metadata), "error.message", default="")
+        choice = glom(resp, "choices.0", default={})
+        content = glom(choice, "delta.content", default=None)
+        reasoning_content = glom(choice, "delta.reasoning_content", default=None)
+        tool_call = glom(choice, "delta.tool_calls.0", default=None)
+        is_finished = glom(choice, "finish_reason", default="") == "stop"
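+    # Any normal delta (content / reasoning / tool call) or a clean finish
+    # means this chunk is not an error frame.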
+    if is_finished or any(x is not None for x in [content, reasoning_content, tool_call]):
+        return {"error": False, "retry": False}
+    if error_code != 0:
+        logger.warning(resp)
+        error_result["error"] = True
+        await modify_progress(text=f"[{error_code}] {error_msg}\n重试次数: {retry + 1}/{GPT.MAX_RETRY + 1}", force_update=True, **kwargs)
+        if retry < GPT.MAX_RETRY:
+            error_result["retry"] = True
+    return error_result
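
The <think> normalization above can be exercised in isolation. A minimal sketch with stand-in strings for pyrogram's BLOCKQUOTE_EXPANDABLE_DELIM / BLOCKQUOTE_EXPANDABLE_END_DELIM (the delimiter values below are assumptions; the real module imports the actual constants):

import re

START = "**>"  # assumption: stand-in for BLOCKQUOTE_EXPANDABLE_DELIM
END = "<**"    # assumption: stand-in for BLOCKQUOTE_EXPANDABLE_END_DELIM

answers = "<think>weighing the options...</think>\n\nFinal answer."
if answers.lstrip().startswith("<think>"):
    answers = answers.replace("<think>", f"{START}🤔")
if "</think>" in answers:
    answers = re.sub(r"</think>\s*", f"💡\n{END}", answers, count=1)
print(answers)  # "**>🤔weighing the options...💡\n<**Final answer."
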
src/llm/utils.py
@@ -7,7 +7,7 @@ import tiktoken
 from loguru import logger
 
 from config import DOWNLOAD_DIR, GPT
-from utils import remove_consecutive_newlines, remove_dash, remove_pound
+from utils import number_to_emoji, remove_consecutive_newlines, remove_dash, remove_pound
 
 BOT_TIPS = "回复以继续"
 
@@ -108,19 +108,32 @@ def extract_reasoning(text: str) -> tuple[str, str]:
     {content}"
     """
     reasoning = ""
-    if matched := re.search(r"<think>(.*?)</think>", text, re.DOTALL):
+    if matched := re.search(r"^<think>(.*?)</think>", text.lstrip(), re.DOTALL):
         reasoning = matched.group(1)
         text = re.sub(r"<think>(.*?)</think>", "", text, count=1, flags=re.DOTALL)  # remove <think>...</think>
-    if matched := re.search(r"<thinking>(.*?)</thinking>", text, re.DOTALL):
+    if matched := re.search(r"^<thinking>(.*?)</thinking>", text.lstrip(), re.DOTALL):
         reasoning = matched.group(1)
         text = re.sub(r"<thinking>(.*?)</thinking>", "", text, count=1, flags=re.DOTALL)
 
     # Reverse engineered Web API
-    if matched := re.search(r"^>?(正在)?推理(.*?)(,持续.*?)秒\n\n(.*)", text, re.DOTALL):  # noqa: RUF001
+    if matched := re.search(r"^>?(正在)?推理(.*?)(,持续.*?)秒\n\n(.*)", text.lstrip(), re.DOTALL):  # noqa: RUF001
         reasoning = matched.group(2)
         text = matched.group(4)
-    if matched := re.search(r"^>?\s?Reasoning(.*?)Reasoned(.*?)seconds\n\n(.*)", text, re.DOTALL):
+    if matched := re.search(r"^>?\s?Reasoning(.*?)Reasoned(.*?)seconds\n\n(.*)", text.lstrip(), re.DOTALL):
         reasoning = matched.group(1)
         text = matched.group(3)
 
     return reasoning.strip(), text.strip().removeprefix("{content}").strip()
+
+
+def add_search_results_to_response(search_results: list[dict], response: str) -> str:
+    """Add search results to response."""
+    if not search_results or not response:
+        return response
+    response = response.strip()
+    for idx, result in enumerate(search_results):
+        title = result.get("title", "")[:20]
+        link = result.get("link", "")
+        if link.startswith("http") and f"({link})" in response:
+            response += f"\n{number_to_emoji(idx + 1)} [{title}]({link})"
+    return response.strip()
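
A quick usage sketch of the relocated helper (the emoji footnote in the comment assumes number_to_emoji maps 1 to 1️⃣):

from llm.utils import add_search_results_to_response

results = [{"title": "Python 3 docs", "link": "https://docs.python.org"}]
text = "See the [official docs](https://docs.python.org) for details."
print(add_search_results_to_response(results, text))
# A footnote is appended only for links actually cited in the response, e.g.:
# 1️⃣ [Python 3 docs](https://docs.python.org)
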
src/messages/progress.py
@@ -5,6 +5,7 @@ import asyncio
 from pathlib import Path
 
 from loguru import logger
+from pyrogram.errors import FloodWait, MessageNotModified
 from pyrogram.types import Message
 
 from config import TEXT_LENGTH, cache
@@ -16,7 +17,8 @@ async def modify_progress(
     *,
     detail_progress: bool = False,
     del_status: bool = False,
-    del_delay: int = 0,
+    ttl: float = 2,
+    del_delay: float = 0,
     force_update: bool = False,
     **kwargs,
 ):
@@ -27,7 +29,8 @@ async def modify_progress(
         text (str): The new text to update.
         detail_progress(bool): Whether to show the detail progress.
         del_status (bool): Whether the progress is done.
-        del_delay (int): Delay seconds to delete the message.
+        ttl (float): Seconds before the progress message may be edited again (rate-limit cache TTL).
+        del_delay (float): Delay seconds to delete the message.
         force_update (bool): Force update the message.
     """
     if message is None:
@@ -50,7 +53,12 @@ async def modify_progress(
             return
         logger.trace(f"Progress: {text!r}")
         await message.edit_text(text[:TEXT_LENGTH])
-        cache.set("modify_progress", "1", ttl=2)
+        cache.set("modify_progress", "1", ttl=ttl)
+    except FloodWait as e:
+        logger.warning(e)
+        await asyncio.sleep(e.value)  # type: ignore
+    except MessageNotModified:
+        pass
     except Exception as e:
         logger.warning(f"modify_progress: {e}")
 
src/config.py
@@ -135,6 +135,7 @@ class COOKIE:  # See: https://github.com/easychen/CookieCloud
 
 
 class GPT:  # see `llm/README.md`
+    STREAM_MODE = os.getenv("GPT_STREAM_MODE", "1").lower() in ["1", "y", "yes", "t", "true", "on"]
     TEXT_MODEL = os.getenv("GPT_TEXT_MODEL", "gpt-4o")
     IMAGE_MODEL = os.getenv("GPT_IMAGE_MODEL", "gpt-4o")
     VIDEO_MODEL = os.getenv("GPT_VIDEO_MODEL", "glm-4v-plus")
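
STREAM_MODE therefore defaults to on; set GPT_STREAM_MODE to anything outside the truthy list (e.g. "0" or "off") to disable it. A minimal check of the parsing (env_flag is an illustrative helper, not part of the codebase):

import os

def env_flag(name, default="1"):
    # Same truthy set as GPT.STREAM_MODE above.
    return os.getenv(name, default).lower() in ["1", "y", "yes", "t", "true", "on"]

assert env_flag("GPT_STREAM_MODE_DEMO")  # unset -> default "1" -> True
os.environ["GPT_STREAM_MODE_DEMO"] = "off"
assert env_flag("GPT_STREAM_MODE_DEMO") is False  # "off" is not in the truthy list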