main
  1#!/usr/bin/env python
  2# -*- coding: utf-8 -*-
  3import contextlib
  4
  5from glom import glom
  6from loguru import logger
  7from openai import AsyncOpenAI, DefaultAsyncHttpxClient
  8from pyrogram.client import Client
  9from pyrogram.parser.markdown import BLOCKQUOTE_DELIM
 10from pyrogram.types import Message, ReplyParameters
 11
 12from ai.texts.contexts import get_openai_completion_contexts
 13from ai.utils import BOT_TIPS, EMOJI_REASONING_BEGIN, EMOJI_TEXT_BOT, beautify_llm_response, literal_eval, load_skills, split_reasoning, trim_none
 14from config import AI, PROXY, TEXT_LENGTH
 15from messages.progress import modify_progress
 16from messages.utils import blockquote, count_without_entities, delete_message, quote, smart_split
 17from utils import strings_list
 18
 19
 20async def openai_chat_completions(
 21    client: Client,
 22    message: Message,
 23    *,
 24    prefix: str = "",
 25    model_id: str = AI.OPENAI_MODEL_ID,
 26    model_name: str = AI.OPENAI_MODEL_ID,
 27    openai_base_url: str = AI.OPENAI_BASE_URL,
 28    openai_api_keys: str = AI.OPENAI_API_KEYS,
 29    openai_client_config: str | dict = "",
 30    openai_default_headers: str | dict = "",
 31    openai_completions_config: str | dict = "",
 32    openai_proxy: str | None = PROXY.OPENAI,
 33    openai_system_prompt: str = "",
 34    openai_contexts: list[dict] | None = None,
 35    openai_tools: list[dict] | None = None,
 36    skills: str = "",
 37    hide_thinking: bool = False,
 38    silent: bool = False,
 39    max_retries: int = 3,
 40    **kwargs,
 41) -> dict:
 42    """Get OpenAI Chat Completions.
 43
 44    Returns:
 45        dict: {"texts": str, "thoughts": str, "prefix": str, "model_name": str,  "sent_messages": list[Message]}
 46    """
 47    if not prefix:
 48        prefix = f"{EMOJI_TEXT_BOT}**{model_name}**:{BOT_TIPS}\n"
 49
 50    if silent or not kwargs.get("show_progress"):  # noqa: SIM108
 51        status_msg = None
 52    else:
 53        status_msg = kwargs.get("progress") or await message.reply(f"{EMOJI_TEXT_BOT}**{model_name}**: 思考中...", quote=True)
 54
 55    sent_messages = [status_msg]
 56    try:
 57        openai_client = {}
 58        if literal_eval(openai_client_config):
 59            openai_client |= literal_eval(openai_client_config)
 60        if literal_eval(openai_default_headers):
 61            openai_client |= {"default_headers": literal_eval(openai_default_headers)}
 62        if openai_proxy:
 63            openai_client |= {"http_client": DefaultAsyncHttpxClient(proxy=openai_proxy)}
 64        contexts = openai_contexts or await get_openai_completion_contexts(client, message)
 65        if openai_system_prompt and glom(contexts, "0.role", default="") != "system":
 66            contexts.insert(0, {"role": "system", "content": openai_system_prompt})
 67        if skills:
 68            contexts = inject_skills(contexts, skills=await load_skills(skills))
 69        params = {"model": model_id, "messages": contexts, "stream": True}
 70        if literal_eval(openai_completions_config):
 71            params |= literal_eval(openai_completions_config)
 72        if openai_tools:
 73            params |= {"tools": openai_tools, "tool_choice": "auto"}
 74        logger.debug(f"openai.chat.completions.create(**{params})")
 75    except Exception as e:
 76        logger.error(f"OpenAI client setup error: {e}")
 77        await modify_progress(status_msg, text=f"{e}", force_update=True, **kwargs)
 78        return {"progress": status_msg} if isinstance(status_msg, Message) else {}
 79
 80    for api_key in strings_list(openai_api_keys, shuffle=True):
 81        try:
 82            openai_client |= {"base_url": openai_base_url, "api_key": api_key}
 83            logger.trace(f"AsyncOpenAI(**{openai_client})")
 84            openai = AsyncOpenAI(**openai_client)
 85            resp = await single_api_chat_completions(
 86                client,
 87                status_msg,
 88                openai,
 89                params=params,
 90                prefix=prefix,
 91                hide_thinking=hide_thinking,
 92                silent=silent,
 93                max_retries=max_retries,
 94                **kwargs,
 95            )
 96            if resp.get("texts") or resp.get("tool_name"):
 97                resp |= {
 98                    "success": True,
 99                    "prefix": prefix,
100                    "model_name": model_name,
101                    "sent_messages": [m for m in sent_messages + resp["sent_messages"] if isinstance(m, Message)],
102                }
103                resp |= {"progress": status_msg} if isinstance(status_msg, Message) else {}
104                return resp
105        except Exception as e:
106            logger.error(f"OpenAI API error: {e}")
107            await modify_progress(status_msg, text=f"{e}", force_update=True, **kwargs)
108    return {"progress": status_msg} if isinstance(status_msg, Message) else {}
109
110
async def single_api_chat_completions(
    client: Client,
    status_msg: Message | None,
    openai: AsyncOpenAI,
    params: dict,
    *,
    prefix: str = "",
    hide_thinking: bool = False,
    silent: bool = False,
    retry: int = 0,
    max_retries: int = 3,
    **kwargs,
) -> dict:
    """Get OpenAI Chat Completions via single API.

    Streams one chat completion, mirroring the growing text into ``status_msg``
    through ``modify_progress`` and starting follow-up messages when the text
    exceeds ``TEXT_LENGTH``. The call retries itself on an empty response or on
    an exception.

    Args:
        client: Pyrogram client used to send follow-up messages.
        status_msg: Message that shows the streamed text; may be ``None``.
        openai: Configured ``AsyncOpenAI`` client.
        params: Keyword arguments for ``openai.chat.completions.create``.
        prefix: Text placed before the answer in every message.
        hide_thinking: Skip progress updates while reasoning chunks arrive
            (unless tool arguments have started to arrive).
        silent: Do not send follow-up messages when the text is split.
        retry: Current attempt index; incremented on each recursive retry.
        max_retries: Retry budget.
        **kwargs: Forwarded to ``modify_progress`` on errors and to recursive retries.

    Returns:
        dict: {"texts": str, "thoughts": str, "tool_name": str, "tool_args": str, "sent_messages": list[Message]}
    """
    if retry > max_retries:
        return {"texts": "", "thoughts": "", "tool_name": "", "tool_args": "", "sent_messages": []}
    answers = ""  # all model responses
    thoughts = ""  # all model thoughts
    tool_name = ""
    tool_args = ""
    runtime_texts = ""  # for a single telegram message
    # `status_msg` is rebound to follow-up messages below; those are handed to
    # `delete_message` on failure, so keep the original for error reports and retries.
    first_status_msg = status_msg
    status_cid = status_msg.chat.id if isinstance(status_msg, Message) else 0
    status_mid = status_msg.id if isinstance(status_msg, Message) else 0
    sent_messages = []
    resp = ""
    try:
        reasoning_chat_flag = None  # set to True once the model emits reasoning content
        is_reasoning = False  # True while reasoning chunks are being received
        async for chunk in await openai.chat.completions.create(**params):
            resp = trim_none(chunk.model_dump())
            logger.trace(resp)
            chunk_answer = glom(resp, "choices.0.delta.content", default="") or ""
            chunk_thinking = glom(resp, "choices.0.delta.reasoning_content", default="") or ""
            # `or ""` keeps tool_name a str even if the field is present but None.
            tool_name = tool_name or glom(resp, "choices.0.delta.tool_calls.0.function.name", default="") or ""
            tool_args += glom(resp, "choices.0.delta.tool_calls.0.function.arguments", default="") or ""
            if not chunk_answer and not chunk_thinking:
                continue
            if reasoning_chat_flag is None and chunk_thinking:
                reasoning_chat_flag = True
            if chunk_thinking and not is_reasoning:  # first reasoning chunk
                is_reasoning = True
                runtime_texts += quote(f"{EMOJI_REASONING_BEGIN}{chunk_thinking.lstrip()}")
            elif chunk_thinking and is_reasoning:  # more reasoning while already reasoning
                runtime_texts += chunk_thinking.replace("\n", f"\n{BLOCKQUOTE_DELIM}")
            elif reasoning_chat_flag is True and is_reasoning:  # first answer chunk: reasoning is over
                is_reasoning = False
                runtime_texts = chunk_answer.lstrip()  # the displayed reasoning is replaced by the answer
            else:
                runtime_texts += chunk_answer

            thoughts += chunk_thinking
            answers += chunk_answer
            if hide_thinking and is_reasoning and not tool_args:
                continue
            runtime_texts = beautify_llm_response(runtime_texts)
            length = await count_without_entities(prefix + runtime_texts)
            if length <= TEXT_LENGTH - 10:  # leave some flexibility
                if len(runtime_texts.removeprefix(prefix)) > 10:  # start response if answer is not empty
                    await modify_progress(message=status_msg, text=prefix + runtime_texts, detail_progress=True)
            else:  # answers is too long, split it into multiple messages
                parts = await smart_split(prefix + runtime_texts)
                if len(parts) == 1:
                    continue
                if is_reasoning:
                    runtime_texts = quote(f"{EMOJI_REASONING_BEGIN}{parts[-1].lstrip()}")  # remove previous thinking
                    await modify_progress(message=status_msg, text=parts[0], force_update=True)  # force send the first part
                else:
                    await modify_progress(message=status_msg, text=blockquote(parts[0]), force_update=True)  # force send the first part
                    runtime_texts = parts[-1]  # keep the last part
                    if not silent:
                        status_msg = await client.send_message(status_cid, text=prefix + runtime_texts, reply_parameters=ReplyParameters(message_id=status_mid))  # the new message
                        sent_messages.append(status_msg)
                        status_mid = status_msg.id
        if tool_name.strip():
            return {"texts": answers, "thoughts": thoughts, "tool_name": tool_name.strip(), "tool_args": tool_args.strip(), "sent_messages": sent_messages}
        # all chunks are processed
        if not (answers.strip() or thoughts.strip()):  # empty response
            await modify_progress(message=status_msg, text=str(resp), force_update=True)
            return await single_api_chat_completions(
                client,
                status_msg,
                openai,
                params=params,
                prefix=prefix,
                retry=retry + 1,
                max_retries=max_retries,
                hide_thinking=hide_thinking,
                silent=silent,
                **kwargs,
            )

        if not thoughts:  # no structured thinking in response
            thoughts, answers = split_reasoning(answers)

        # answers = add_search_results_to_response(config.get("search_results", []), answers)
        if await count_without_entities(prefix + answers) <= TEXT_LENGTH - 10:  # short answer in single msg
            quoted = answers.strip()
            await modify_progress(message=status_msg, text=f"{prefix}{blockquote(quoted)}", force_update=True)
        else:  # total length is too long, answers are split into multiple messages
            await modify_progress(message=status_msg, text=prefix + blockquote(runtime_texts), force_update=True)

    except Exception as e:
        error = f"{EMOJI_TEXT_BOT}BOT请求失败, 重试次数: {retry + 1}/{max_retries}\n{e}\n{resp}"
        logger.error(error)
        # Report on the original status message: any follow-up message is cleaned up right below.
        with contextlib.suppress(Exception):
            await modify_progress(first_status_msg, text=error, force_update=True, **kwargs)
        # Best-effort cleanup; one failing deletion must not skip the remaining ones.
        for msg in sent_messages:
            with contextlib.suppress(Exception):
                await delete_message(msg)
        if retry + 1 < max_retries:
            return await single_api_chat_completions(
                client,
                first_status_msg,
                openai,
                params=params,
                prefix=prefix,
                retry=retry + 1,
                max_retries=max_retries,
                hide_thinking=hide_thinking,
                silent=silent,
                **kwargs,
            )
    # NOTE(review): after the last failed attempt this still returns the partial `answers`
    # and the already cleaned-up `sent_messages`; confirm callers expect that.
    return {"texts": answers, "thoughts": thoughts, "tool_name": tool_name.strip(), "tool_args": tool_args.strip(), "sent_messages": sent_messages}
236
237
def inject_skills(contexts: list[dict], skills: str) -> list[dict]:
    """Add the skills text to the system message of a chat context list.

    ``contexts`` is modified in place and also returned.

    - Empty ``skills``: nothing changes.
    - No leading system message: a new ``{"role": "system", "content": skills}``
      message is inserted at index 0.
    - Leading system message with string content: ``skills`` is appended on a
      new line unless the content already contains it.
    - Leading system message with list content: a ``{"type": "text", "text": skills}``
      part is appended unless an identical part is already present.
    - Leading system message with missing or ``None`` content: the content
      becomes ``skills``.

    Any other keys on the existing system message (for example ``name``) are kept.

    Args:
        contexts: Chat messages in OpenAI format.
        skills: Text to inject.

    Returns:
        The same ``contexts`` list.
    """
    if not skills:
        return contexts

    head = contexts[0] if contexts else None
    if not isinstance(head, dict) or head.get("role", "") != "system":
        contexts.insert(0, {"role": "system", "content": skills})
        return contexts

    system_prompt = head.get("content")
    if system_prompt is None:
        system_prompt = skills
    elif isinstance(system_prompt, str):
        if skills not in system_prompt:
            system_prompt = f"{system_prompt}\n{skills}"
    elif isinstance(system_prompt, list):
        skills_part = {"type": "text", "text": skills}
        if skills_part not in system_prompt:
            system_prompt.append(skills_part)

    # Rebuild from the existing message so extra keys are not dropped.
    contexts[0] = {**head, "role": "system", "content": system_prompt}
    return contexts