main
  1#!/usr/bin/env python
  2# -*- coding: utf-8 -*-
  3import contextlib
  4import hashlib
  5import json
  6from json import JSONDecodeError
  7from typing import Literal
  8
  9from glom import Coalesce, flatten, glom
 10from jsonschema import ValidationError, validate
 11from loguru import logger
 12from openai import AsyncOpenAI, DefaultAsyncHttpxClient
 13from pyrogram.client import Client
 14from pyrogram.parser.markdown import BLOCKQUOTE_DELIM
 15from pyrogram.types import Message, ReplyParameters
 16
 17from ai.texts.contexts import get_openai_response_contexts
 18from ai.utils import BOT_TIPS, EMOJI_REASONING_BEGIN, EMOJI_TEXT_BOT, beautify_llm_response, deep_merge, literal_eval, load_skills, trim_none
 19from config import AI, PROXY, TEXT_LENGTH
 20from database.r2 import set_cf_r2
 21from messages.progress import modify_progress
 22from messages.utils import blockquote, count_without_entities, delete_message, quote, smart_split
 23from utils import number_to_emoji, strings_list
 24
 25
 26async def openai_responses_api(
 27    client: Client,
 28    message: Message,
 29    *,
 30    prefix: str = "",
 31    model_id: str = AI.OPENAI_MODEL_ID,
 32    model_name: str = AI.OPENAI_MODEL_ID,
 33    openai_base_url: str = AI.OPENAI_BASE_URL,
 34    openai_api_keys: str = AI.OPENAI_API_KEYS,
 35    openai_client_config: str | dict = "",
 36    openai_default_headers: str | dict = "",
 37    openai_responses_config: str | dict = "",
 38    openai_proxy: str | None = PROXY.OPENAI,
 39    cache_response_ttl: int = 0,
 40    openai_allow_image: bool = True,  # whether to allow image in input modalities
 41    openai_allow_video: bool = False,  # whether to allow video in input modalities
 42    openai_allow_audio: bool = False,  # whether to allow audio in input modalities
 43    openai_allow_file: bool = False,  # whether to allow file in input modalities
 44    openai_media_send_as: Literal["base64", "file_id"] = "base64",
 45    additional_contexts: list[dict] | None = None,  # additional contexts to append to the contexts
 46    skills: str = "",
 47    openai_append_tool_results: bool = True,
 48    hide_thinking: bool = False,
 49    add_sender: bool | None = None,
 50    silent: bool = False,
 51    max_retries: int = 3,
 52    **kwargs,
 53) -> dict:
 54    """Get OpenAI Chat Completions.
 55
 56    Returns:
 57        dict: {"texts": str, "thoughts": str, "response_id": str, "prefix": str, "model_name": str, "sent_messages": list[Message]}
 58    """
 59    if not prefix:
 60        prefix = f"{EMOJI_TEXT_BOT}**{model_name}**:{BOT_TIPS}\n"
 61
 62    if silent or not kwargs.get("show_progress"):  # noqa: SIM108
 63        status_msg = None
 64    else:
 65        status_msg = kwargs.get("progress") or await message.reply(f"{EMOJI_TEXT_BOT}**{model_name}**: 思考中...", quote=True)
 66
 67    sent_messages = [status_msg]
 68    cache_day = round(cache_response_ttl // 86400)
 69    try:
 70        openai_client = {}
 71        if literal_eval(openai_client_config):
 72            openai_client |= literal_eval(openai_client_config)
 73        if literal_eval(openai_default_headers):
 74            openai_client = deep_merge(openai_client, {"default_headers": literal_eval(openai_default_headers)})
 75        if openai_proxy:
 76            openai_client = deep_merge(openai_client, {"http_client": DefaultAsyncHttpxClient(proxy=openai_proxy)})
 77    except Exception as e:
 78        logger.error(f"OpenAI client setup error: {e}")
 79        return {"progress": status_msg} if isinstance(status_msg, Message) else {}
 80
 81    for api_key in strings_list(openai_api_keys, shuffle=True):
 82        try:
 83            openai_client |= {"base_url": openai_base_url, "api_key": api_key}
 84            logger.trace(f"AsyncOpenAI(**{openai_client})")
 85            openai = AsyncOpenAI(**openai_client)
 86            previous_response_id, contexts = await get_openai_response_contexts(
 87                client,
 88                message,
 89                params=openai_client
 90                | {
 91                    "proxy": openai_proxy,
 92                    "model_id": model_id,
 93                    "cache_day": cache_day,
 94                    "allow_image": openai_allow_image,
 95                    "allow_video": openai_allow_video,
 96                    "allow_audio": openai_allow_audio,
 97                    "allow_file": openai_allow_file,
 98                    "openai_media_send_as": openai_media_send_as,
 99                    "add_sender": add_sender,
100                    "additional_contexts": additional_contexts,
101                },
102            )
103            params = {}
104            params |= {"model": model_id, "stream": True, "input": contexts}
105            if skills:
106                params |= {"instructions": await load_skills(skills)}
107            if literal_eval(openai_responses_config):
108                params |= literal_eval(openai_responses_config)
109            if previous_response_id:
110                params |= {"previous_response_id": previous_response_id}
111            logger.debug(f"openai.responses.create(**{params})")
112            resp = await single_api_response(
113                client,
114                status_msg,
115                openai,
116                params=params,
117                prefix=prefix,
118                hide_thinking=hide_thinking,
119                openai_append_tool_results=openai_append_tool_results,
120                silent=silent,
121                max_retries=max_retries,
122                **kwargs,
123            )
124            if not is_valid_response(resp, glom(params, "text.format.schema", default={})):
125                continue
126            sent_messages.extend(resp.get("sent_messages", []))
127            sent_messages = [m for m in sent_messages if isinstance(m, Message)]
128            if cache_response_ttl > 0:
129                day = round(cache_response_ttl // 86400)
130                for sent_msg in sent_messages:  # save the reponse to R2
131                    key_hash = hashlib.sha256(api_key.encode()).hexdigest()
132                    await set_cf_r2(
133                        f"TTL/{day}d/OpenAI/{sent_msg.chat.id}/{sent_msg.id}/{key_hash}",
134                        data=resp["full_response"],
135                        metadata={"response_id": resp["response_id"]},
136                        silent=silent,
137                    )
138            return {
139                "success": True,
140                "texts": resp["texts"],
141                "thoughts": resp["thoughts"],
142                "response_id": resp["response_id"],
143                "prefix": prefix,
144                "model_name": model_name,
145                "sent_messages": sent_messages,
146            }
147        except Exception as e:
148            logger.error(f"OpenAI API error: {e}")
149            await modify_progress(status_msg, text=f"{e}", force_update=True, **kwargs)
150    return {"progress": status_msg} if isinstance(status_msg, Message) else {}
151
152
async def single_api_response(
    client: Client,
    status_msg: Message | None,
    openai: AsyncOpenAI,
    params: dict,
    *,
    prefix: str = "",
    hide_thinking: bool = False,
    openai_append_tool_results: bool = True,
    silent: bool = False,
    retry: int = 0,
    max_retries: int = 3,
    **kwargs,
) -> dict:
    """Stream one ``openai.responses.create(**params)`` call and render it live into Telegram.

    Reasoning-summary deltas are rendered as a quote and answer deltas as plain text on ``status_msg``.
    When the rendered text exceeds ``TEXT_LENGTH``, the current message is finalized. Unless ``silent``
    is set or there is no status message, a new reply message then continues the answer.

    Three conditions retry the request by recursing with ``retry + 1``: error events reported by
    :func:`parse_error`, empty responses, and exceptions. Each retry first deletes the split messages
    created by the failed attempt and restarts from the original ``status_msg``.

    Returns:
        dict: ``{"texts": str, "thoughts": str, "full_response": dict, "response_id": str,
        "sent_messages": list[Message]}`` on success, and ``{}`` on an unrecoverable error. Once
        ``retry`` exceeds ``max_retries`` it returns ``{"texts": "", "thoughts": "", "sent_messages": []}``.
    """
    if retry > max_retries:
        return {"texts": "", "thoughts": "", "sent_messages": []}
    origin_status_msg = status_msg  # `status_msg` is rebound to each newly split message below
    answers = ""  # all model responses
    thoughts = ""  # all model thoughts
    runtime_texts = ""  # text of the single Telegram message currently being edited
    status_cid = status_msg.chat.id if isinstance(status_msg, Message) else 0
    status_mid = status_msg.id if isinstance(status_msg, Message) else 0
    sent_messages: list[Message] = []  # messages created by this attempt (origin_status_msg excluded)
    full_response = {}
    response_id = ""
    resp: dict | None = None  # last processed chunk, included in the error report on failure

    async def _discard_sent_messages() -> None:
        """Best-effort deletion of the split messages created by this attempt."""
        for msg in sent_messages:
            with contextlib.suppress(Exception):
                await delete_message(msg)

    async def _retry() -> dict:
        """Re-run the same request from the original status message with ``retry + 1``."""
        return await single_api_response(
            client,
            origin_status_msg,
            openai,
            params=params,
            prefix=prefix,
            hide_thinking=hide_thinking,
            openai_append_tool_results=openai_append_tool_results,
            silent=silent,
            retry=retry + 1,
            max_retries=max_retries,
            **kwargs,
        )

    try:
        tool_calls: list[dict] = []  # annotations collected from the completed response
        is_reasoning = False
        async for chunk in await openai.responses.create(**params):
            resp = trim_none(chunk.model_dump())
            resp.pop("item_id", None)  # too noisy
            resp.pop("output_index", None)  # too noisy
            resp.pop("summary_index", None)  # too noisy
            logger.trace(resp)
            error = await parse_error(resp, retry, max_retries, status_msg)
            if error["retry"]:
                await _discard_sent_messages()
                return await _retry()
            if error["error"]:
                await modify_progress(message=status_msg, text=error["error"], force_update=True, **kwargs)
                return {}
            response_type = resp.get("type", "")
            chunk_answer = resp.get("delta", "") if response_type == "response.output_text.delta" else ""
            chunk_thinking = resp.get("delta", "") if response_type == "response.reasoning_summary_text.delta" else ""

            # Track whether the stream is currently inside a reasoning summary
            if response_type in {"response.reasoning_summary_part.added", "response.reasoning_summary_text.delta"}:  # reasoning in progress
                is_reasoning = True
            elif response_type in {"response.content_part.added", "response.output_text.delta"}:  # reasoning finished
                is_reasoning = False

            if response_type == "response.reasoning_summary_part.added" and len(thoughts) == 0:  # first reasoning part: open the quote
                runtime_texts += quote(f"{EMOJI_REASONING_BEGIN}{chunk_thinking.lstrip()}")
            elif chunk_thinking:  # further reasoning text: keep every new line inside the quote
                runtime_texts += chunk_thinking.replace("\n", f"\n{BLOCKQUOTE_DELIM}")

            if response_type == "response.content_part.added":  # the answer begins and replaces the displayed reasoning
                runtime_texts = chunk_answer.lstrip()
            else:
                runtime_texts += chunk_answer

            thoughts += chunk_thinking
            answers += chunk_answer

            # Handle the final event before any `continue` below can skip it
            if response_type == "response.completed":
                full_response = resp
                response_id = glom(resp, "response.id", default="")
                tool_calls = flatten(glom(resp, "response.output.**.annotations", default=[]))
                thoughts = "".join(flatten(glom(resp, "response.output.*.summary.*.text", default=[])) or thoughts)
                answers = "".join(flatten(glom(resp, "response.output.*.content.*.text", default=[])) or answers)

            if hide_thinking and is_reasoning:
                continue

            runtime_texts = beautify_llm_response(runtime_texts)
            length = await count_without_entities(prefix + runtime_texts)
            if length <= TEXT_LENGTH - 10:  # leave some flexibility
                if len(runtime_texts.removeprefix(prefix)) > 10:  # start response if answer is not empty
                    await modify_progress(message=status_msg, text=prefix + runtime_texts, detail_progress=True)
            else:  # the text is too long, split it into multiple messages
                parts = await smart_split(prefix + runtime_texts)
                if len(parts) == 1:
                    continue
                if is_reasoning:
                    runtime_texts = quote(f"{EMOJI_REASONING_BEGIN}{parts[-1].lstrip()}")  # drop earlier thinking
                    await modify_progress(message=status_msg, text=parts[0], force_update=True)  # force send the first part
                else:
                    await modify_progress(message=status_msg, text=blockquote(parts[0]), force_update=True)  # force send the first part
                    runtime_texts = parts[-1]  # keep the last part
                    # Without a status message there is no chat to continue in (chat id would be 0)
                    if not silent and isinstance(status_msg, Message):
                        status_msg = await client.send_message(status_cid, text=prefix + runtime_texts, reply_parameters=ReplyParameters(message_id=status_mid))  # the new message
                        sent_messages.append(status_msg)
                        status_mid = status_msg.id

        # all chunks are processed
        if not answers.strip() and not thoughts.strip():  # empty response
            await _discard_sent_messages()
            return await _retry()
        if openai_append_tool_results:
            answers = add_tool_call_results_to_response(tool_calls, answers)
        if await count_without_entities(prefix + answers) <= TEXT_LENGTH - 10:  # short answer fits in a single message
            await modify_progress(message=status_msg, text=f"{prefix}{blockquote(answers.strip())}", force_update=True)
        else:  # the answer was already split across several messages; finalize the last one
            await modify_progress(message=status_msg, text=prefix + blockquote(runtime_texts), force_update=True)

    except Exception as e:
        error_text = f"{EMOJI_TEXT_BOT}BOT请求失败, 重试次数: {retry + 1}/{max_retries}\n{e}"
        if resp is not None:
            error_text += f"\n{resp}"
        logger.error(error_text)
        # Delete this attempt's split messages first, then report on the surviving original message
        await _discard_sent_messages()
        with contextlib.suppress(Exception):
            await modify_progress(origin_status_msg, text=error_text, force_update=True, **kwargs)
        if retry + 1 < max_retries:
            return await _retry()
        return {}
    return {"texts": answers, "thoughts": thoughts, "full_response": full_response, "response_id": response_id, "sent_messages": [m for m in sent_messages if isinstance(m, Message)]}
309
310
async def parse_error(resp: dict, retry: int, max_retries: int, status_msg: Message | None) -> dict:
    """Detect an error event in one streamed Responses API chunk.

    Chunks whose ``type`` is ``"error"`` or ``"response.failed"`` are logged as warnings. They are also
    shown on ``status_msg`` together with the retry counter. Any other chunk passes through untouched.

    Returns:
        dict: ``{"error": str, "retry": bool}``. ``error`` is the stringified chunk for error events and
        ``""`` otherwise. ``retry`` is True only for error events while ``retry < max_retries``.
    """
    event_type = glom(resp, "type", default="")
    if event_type in {"error", "response.failed"}:
        logger.warning(resp)
        await modify_progress(status_msg, text=f"{resp}\n重试次数: {retry + 1}/{max_retries}", force_update=True)
        can_retry = retry < max_retries
        return {"error": str(resp), "retry": can_retry}
    return {"error": "", "retry": False}
325
326
def add_tool_call_results_to_response(tool_calls: list[dict], answers: str) -> str:
    """Append a numbered list of cited links to ``answers``.

    Args:
        tool_calls: annotation dicts collected from the completed response. Each is read for a
            ``title``/``site_name`` and a ``url``/``link`` key (presumably ``url_citation`` annotations;
            confirm against the API output).
        answers: the model's answer text.

    Returns:
        ``answers`` stripped, followed by one ``"<n> [title](link)"`` line per distinct ``http``
        link, numbered contiguously from 1. The link itself is used as the title when no title is
        present. ``answers`` is returned unchanged when ``tool_calls`` is empty or not a list.
    """
    if not tool_calls or not isinstance(tool_calls, list):
        return answers
    lines = [answers.strip()]
    seen_links: set[str] = set()
    for tool_call in tool_calls:
        title = glom(tool_call, Coalesce("title", "site_name"), default="")
        link = glom(tool_call, Coalesce("url", "link"), default="")
        # Coalesce does not skip None values, so guard the type before calling str methods.
        # The same URL may be cited at several places in the text; list it only once.
        if not isinstance(link, str) or not link.startswith("http") or link in seen_links:
            continue
        seen_links.add(link)
        lines.append(f"{number_to_emoji(len(seen_links))} [{title or link}]({link})")
    return "\n".join(lines).strip()
337
338
def is_valid_response(resp: dict, schema: dict) -> bool:
    """Return True when ``resp`` carries non-empty ``texts`` that satisfy ``schema``.

    Without a schema, any non-empty text is accepted. With a schema, the text must parse as JSON and
    validate against it; parse and validation failures are logged and reported as False.
    """
    texts = resp.get("texts")
    if not texts:
        return False
    if not schema:
        return True
    try:
        validate(instance=json.loads(texts), schema=schema)
    except (JSONDecodeError, ValidationError) as err:
        logger.error(f"Invalid JSONSchema response: {err}")
        return False
    return True
351    return True