main
  1#!/usr/bin/env python
  2# -*- coding: utf-8 -*-
  3import contextlib
  4import json
  5from json import JSONDecodeError
  6from typing import Literal
  7
  8from glom import glom
  9from jsonschema import ValidationError, validate
 10from loguru import logger
 11from openai import AsyncOpenAI, DefaultAsyncHttpxClient
 12from pyrogram.client import Client
 13from pyrogram.parser.markdown import BLOCKQUOTE_DELIM
 14from pyrogram.types import Message, ReplyParameters
 15
 16from ai.texts.contexts import get_openai_completion_contexts
 17from ai.utils import BOT_TIPS, EMOJI_REASONING_BEGIN, EMOJI_TEXT_BOT, beautify_llm_response, literal_eval, load_skills, split_reasoning, trim_none
 18from config import AI, PROXY, TEXT_LENGTH
 19from messages.progress import modify_progress
 20from messages.utils import blockquote, count_without_entities, delete_message, quote, smart_split
 21from utils import strings_list
 22
 23
 24async def openai_chat_completions(
 25    client: Client,
 26    message: Message,
 27    *,
 28    prefix: str = "",
 29    model_id: str = AI.OPENAI_MODEL_ID,
 30    model_name: str = AI.OPENAI_MODEL_ID,
 31    openai_base_url: str = AI.OPENAI_BASE_URL,
 32    openai_api_keys: str = AI.OPENAI_API_KEYS,
 33    openai_client_config: str | dict = "",
 34    openai_default_headers: str | dict = "",
 35    openai_completions_config: str | dict = "",
 36    openai_proxy: str | None = PROXY.OPENAI,
 37    openai_system_prompt: str = "",
 38    openai_contexts: list[dict] | None = None,
 39    openai_tools: list[dict] | None = None,
 40    skills: str = "",
 41    openai_allow_image: bool = True,  # whether to allow image in input modalities
 42    openai_allow_video: bool = False,  # whether to allow video in input modalities
 43    openai_allow_audio: bool = False,  # whether to allow audio in input modalities
 44    openai_allow_file: bool = False,  # whether to allow file in input modalities
 45    openai_media_send_as: Literal["base64", "file_id"] = "base64",
 46    additional_contexts: list[dict] | None = None,  # additional contexts to append to the contexts
 47    hide_thinking: bool = False,
 48    add_sender: bool | None = None,
 49    silent: bool = False,
 50    max_retries: int = 3,
 51    **kwargs,
 52) -> dict:
 53    """Get OpenAI Chat Completions.
 54
 55    Returns:
 56        dict: {"texts": str, "thoughts": str, "prefix": str, "model_name": str,  "sent_messages": list[Message]}
 57    """
 58    if not prefix:
 59        prefix = f"{EMOJI_TEXT_BOT}**{model_name}**:{BOT_TIPS}\n"
 60
 61    if silent or not kwargs.get("show_progress"):  # noqa: SIM108
 62        status_msg = None
 63    else:
 64        status_msg = kwargs.get("progress") or await message.reply(f"{EMOJI_TEXT_BOT}**{model_name}**: 思考中...", quote=True)
 65
 66    sent_messages = [status_msg]
 67    try:
 68        openai_client = {}
 69        if literal_eval(openai_client_config):
 70            openai_client |= literal_eval(openai_client_config)
 71        if literal_eval(openai_default_headers):
 72            openai_client |= {"default_headers": literal_eval(openai_default_headers)}
 73        if openai_proxy:
 74            openai_client |= {"http_client": DefaultAsyncHttpxClient(proxy=openai_proxy)}
 75
 76        contexts = openai_contexts or await get_openai_completion_contexts(
 77            client,
 78            message,
 79            params={
 80                "add_sender": add_sender,
 81                "allow_image": openai_allow_image,
 82                "allow_video": openai_allow_video,
 83                "allow_audio": openai_allow_audio,
 84                "allow_file": openai_allow_file,
 85                "openai_media_send_as": openai_media_send_as,
 86                "additional_contexts": additional_contexts,
 87            },
 88        )
 89        if openai_system_prompt and glom(contexts, "0.role", default="") != "system":
 90            contexts.insert(0, {"role": "system", "content": openai_system_prompt})
 91        if skills:
 92            contexts = inject_skills(contexts, skills=await load_skills(skills))
 93        params = {"model": model_id, "messages": contexts, "stream": True}
 94        if literal_eval(openai_completions_config):
 95            params |= literal_eval(openai_completions_config)
 96        if openai_tools:
 97            params |= {"tools": openai_tools, "tool_choice": "auto"}
 98        logger.debug(f"openai.chat.completions.create(**{params})")
 99    except Exception as e:
100        logger.error(f"OpenAI client setup error: {e}")
101        await modify_progress(status_msg, text=f"{e}", force_update=True, **kwargs)
102        return {"progress": status_msg} if isinstance(status_msg, Message) else {}
103
104    for api_key in strings_list(openai_api_keys, shuffle=True):
105        try:
106            openai_client |= {"base_url": openai_base_url, "api_key": api_key}
107            logger.trace(f"AsyncOpenAI(**{openai_client})")
108            openai = AsyncOpenAI(**openai_client)
109            resp = await single_api_chat_completions(
110                client,
111                status_msg,
112                openai,
113                params=params,
114                prefix=prefix,
115                hide_thinking=hide_thinking,
116                silent=silent,
117                max_retries=max_retries,
118                **kwargs,
119            )
120            if not is_valid_response(resp, glom(params, "response_format.json_schema.schema", default={})):
121                continue
122            if resp.get("texts") or resp.get("tool_name"):
123                resp |= {
124                    "success": True,
125                    "prefix": prefix,
126                    "model_name": model_name,
127                    "sent_messages": [m for m in sent_messages + resp["sent_messages"] if isinstance(m, Message)],
128                }
129                resp |= {"progress": status_msg} if isinstance(status_msg, Message) else {}
130                return resp
131        except Exception as e:
132            logger.error(f"OpenAI API error: {e}")
133            await modify_progress(status_msg, text=f"{e}", force_update=True, **kwargs)
134    return {"progress": status_msg} if isinstance(status_msg, Message) else {}
135
136
async def single_api_chat_completions(
    client: Client,
    status_msg: Message | None,
    openai: AsyncOpenAI,
    params: dict,
    *,
    prefix: str = "",
    hide_thinking: bool = False,
    silent: bool = False,
    retry: int = 0,
    max_retries: int = 3,
    **kwargs,
) -> dict:
    """Stream one chat completion from a single API client and mirror it into Telegram messages.

    Reasoning chunks (``delta.reasoning_content``) are rendered as a quote and replaced by the
    answer once it starts. When the rendered text exceeds ``TEXT_LENGTH`` it is split. The
    first part is kept in the current status message, and, unless ``silent`` or there is no
    status message, the rest continues in a new reply message.

    ``max_retries`` is the total number of attempts, as shown in the error text
    "``retry + 1``/``max_retries``". At least one attempt is always made. An empty
    answer and an exception are both retried under this same cap.

    Returns:
        dict: {"texts": str, "thoughts": str, "tool_name": str, "tool_args": str, "sent_messages": list[Message]}
        ``{}`` when the last attempt ends with an exception.
    """
    if retry > max_retries:
        return {"texts": "", "thoughts": "", "tool_name": "", "tool_args": "", "sent_messages": []}
    # status_msg is rebound to each continuation message below; errors and retries must use the original.
    first_status_msg = status_msg
    answers = ""  # all model responses
    thoughts = ""  # all model thoughts
    tool_name = ""
    tool_args = ""
    runtime_texts = ""  # for a single telegram message
    status_cid = status_msg.chat.id if isinstance(status_msg, Message) else 0
    status_mid = status_msg.id if isinstance(status_msg, Message) else 0
    sent_messages = []
    resp = ""
    try:
        reasoning_chat_flag = None  # whether this is a reasoning model (set once reasoning content is seen)
        is_reasoning = False  # whether reasoning content is currently being streamed
        async for chunk in await openai.chat.completions.create(**params):
            resp = trim_none(chunk.model_dump())
            logger.trace(resp)
            chunk_answer = glom(resp, "choices.0.delta.content", default="") or ""
            chunk_thinking = glom(resp, "choices.0.delta.reasoning_content", default="") or ""
            # tool-call deltas: keep the first name seen, concatenate argument fragments
            tool_name = tool_name or glom(resp, "choices.0.delta.tool_calls.0.function.name", default="")
            tool_args += glom(resp, "choices.0.delta.tool_calls.0.function.arguments", default="") or ""
            if not chunk_answer and not chunk_thinking:
                continue
            if reasoning_chat_flag is None and chunk_thinking:
                reasoning_chat_flag = True
            if chunk_thinking and not is_reasoning:  # first reasoning chunk
                is_reasoning = True
                runtime_texts += quote(f"{EMOJI_REASONING_BEGIN}{chunk_thinking.lstrip()}")
            elif chunk_thinking and is_reasoning:  # reasoning continues
                runtime_texts += chunk_thinking.replace("\n", f"\n{BLOCKQUOTE_DELIM}")
            elif reasoning_chat_flag is True and is_reasoning:  # first answer chunk: stop reasoning, drop its display
                is_reasoning = False
                runtime_texts = chunk_answer.lstrip()
            else:
                runtime_texts += chunk_answer

            thoughts += chunk_thinking
            answers += chunk_answer
            if hide_thinking and is_reasoning and not tool_args:
                continue
            runtime_texts = beautify_llm_response(runtime_texts)
            length = await count_without_entities(prefix + runtime_texts)
            if length <= TEXT_LENGTH - 10:  # leave some flexibility
                if len(runtime_texts.removeprefix(prefix)) > 10:  # start response if answer is not empty
                    await modify_progress(message=status_msg, text=prefix + runtime_texts, detail_progress=True)
            else:  # answers is too long, split it into multiple messages
                parts = await smart_split(prefix + runtime_texts)
                if len(parts) == 1:
                    continue
                if is_reasoning:
                    runtime_texts = quote(f"{EMOJI_REASONING_BEGIN}{parts[-1].lstrip()}")  # remove previous thinking
                    await modify_progress(message=status_msg, text=parts[0], force_update=True)  # force send the first part
                else:
                    await modify_progress(message=status_msg, text=blockquote(parts[0]), force_update=True)  # force send the first part
                    runtime_texts = parts[-1]  # keep the last part
                    # A continuation message needs a real chat; without a status message (chat id 0)
                    # sending would always fail and turn every long answer into a retry.
                    if not silent and status_cid:
                        status_msg = await client.send_message(status_cid, text=prefix + runtime_texts, reply_parameters=ReplyParameters(message_id=status_mid))  # the new message
                        sent_messages.append(status_msg)
                        status_mid = status_msg.id
        if tool_name.strip():
            return {"texts": answers, "thoughts": thoughts, "tool_name": tool_name.strip(), "tool_args": tool_args.strip(), "sent_messages": sent_messages}
        # all chunks are processed
        if not (answers.strip() or thoughts.strip()):  # empty response
            await modify_progress(message=status_msg, text=str(resp), force_update=True)
            # same attempt budget as the exception path below
            if retry + 1 < max_retries:
                return await single_api_chat_completions(
                    client,
                    status_msg,
                    openai,
                    params=params,
                    prefix=prefix,
                    retry=retry + 1,
                    max_retries=max_retries,
                    hide_thinking=hide_thinking,
                    silent=silent,
                    **kwargs,
                )
            return {"texts": "", "thoughts": "", "tool_name": "", "tool_args": "", "sent_messages": sent_messages}

        if not thoughts:  # no structured thinking in response
            thoughts, answers = split_reasoning(answers)

        if await count_without_entities(prefix + answers) <= TEXT_LENGTH - 10:  # short answer in single msg
            quoted = answers.strip()
            await modify_progress(message=status_msg, text=f"{prefix}{blockquote(quoted)}", force_update=True)
        else:  # total length is too long, answers are split into multiple messages
            await modify_progress(message=status_msg, text=prefix + blockquote(runtime_texts), force_update=True)

    except Exception as e:
        error = f"{EMOJI_TEXT_BOT}BOT请求失败, 重试次数: {retry + 1}/{max_retries}\n{e}\n{resp}"
        logger.error(error)
        # Report on the original status message: the continuation messages are deleted next.
        with contextlib.suppress(Exception):
            await modify_progress(first_status_msg, text=error, force_update=True, **kwargs)
        # best-effort cleanup; one failed deletion must not skip the others
        for msg in sent_messages:
            with contextlib.suppress(Exception):
                await delete_message(msg)
        if retry + 1 < max_retries:
            return await single_api_chat_completions(
                client,
                first_status_msg,  # never retry on a continuation message that was just deleted
                openai,
                params=params,
                prefix=prefix,
                retry=retry + 1,
                max_retries=max_retries,
                hide_thinking=hide_thinking,
                silent=silent,
                **kwargs,
            )
        return {}
    return {"texts": answers, "thoughts": thoughts, "tool_name": tool_name.strip(), "tool_args": tool_args.strip(), "sent_messages": sent_messages}
263
264
def inject_skills(contexts: list[dict], skills: str) -> list[dict]:
    """Merge ``skills`` text into the leading system message of ``contexts``.

    - No ``skills``: ``contexts`` is returned unchanged.
    - First message is not a system message (or the list is empty): a new
      ``{"role": "system", "content": skills}`` is inserted at index 0.
    - System content missing, ``None`` or ``""``: it becomes ``skills``.
    - String content: ``skills`` is appended on a new line unless it already occurs in it.
    - List content: a ``{"type": "text", "text": skills}`` part is appended unless already present.

    The ``contexts`` list itself is updated in place (insert / replacement of index 0) and
    returned. Other keys of the system message (e.g. ``name``) are preserved, and a list
    content is copied instead of being appended to in place, so the caller's nested list is
    left untouched.
    """
    if not skills:
        return contexts
    first = contexts[0] if contexts else None
    if not isinstance(first, dict) or first.get("role", "") != "system":
        contexts.insert(0, {"role": "system", "content": skills})
        return contexts
    system_prompt = first.get("content")
    if system_prompt is None or system_prompt == "":
        system_prompt = skills
    elif isinstance(system_prompt, str):
        if skills not in system_prompt:
            system_prompt = f"{system_prompt}\n{skills}"
    elif isinstance(system_prompt, list):
        skills_part = {"type": "text", "text": skills}
        if skills_part not in system_prompt:
            system_prompt = [*system_prompt, skills_part]
    contexts[0] = {**first, "role": "system", "content": system_prompt}
    return contexts
278
279
def is_valid_response(resp: dict, schema: dict) -> bool:
    """Decide whether a completion result can be used.

    A result without ``texts`` (missing or empty) is always rejected. Without a
    schema any non-empty ``texts`` is accepted; with a schema, ``texts`` must be
    valid JSON that satisfies it (failures are logged and reported as ``False``).
    """
    texts = resp.get("texts")
    if not texts:
        return False
    if not schema:
        return True
    try:
        validate(instance=json.loads(texts), schema=schema)
    except (JSONDecodeError, ValidationError) as e:
        logger.error(f"Invalid JSONSchema response: {e}")
        return False
    return True
292    return True