main
  1#!/usr/bin/env python
  2# -*- coding: utf-8 -*-
  3import contextlib
  4import hashlib
  5from typing import Literal
  6
  7from glom import Coalesce, flatten, glom
  8from loguru import logger
  9from openai import AsyncOpenAI, DefaultAsyncHttpxClient
 10from pyrogram.client import Client
 11from pyrogram.parser.markdown import BLOCKQUOTE_DELIM
 12from pyrogram.types import Message, ReplyParameters
 13
 14from ai.texts.contexts import get_openai_response_contexts
 15from ai.utils import BOT_TIPS, EMOJI_REASONING_BEGIN, EMOJI_TEXT_BOT, beautify_llm_response, deep_merge, literal_eval, load_skills, trim_none
 16from config import AI, PROXY, TEXT_LENGTH
 17from database.r2 import set_cf_r2
 18from messages.parser import get_thread_id
 19from messages.progress import modify_progress
 20from messages.utils import blockquote, count_without_entities, delete_message, quote, smart_split
 21from utils import number_to_emoji, strings_list
 22
 23
 24async def openai_responses_api(
 25    client: Client,
 26    message: Message,
 27    *,
 28    prefix: str = "",
 29    model_id: str = AI.OPENAI_MODEL_ID,
 30    model_name: str = AI.OPENAI_MODEL_ID,
 31    openai_base_url: str = AI.OPENAI_BASE_URL,
 32    openai_api_keys: str = AI.OPENAI_API_KEYS,
 33    openai_client_config: str | dict = "",
 34    openai_default_headers: str | dict = "",
 35    openai_responses_config: str | dict = "",
 36    openai_proxy: str | None = PROXY.OPENAI,
 37    cache_response_ttl: int = 0,
 38    openai_allow_image: bool = False,  # whether to allow image in input modalities
 39    openai_allow_video: bool = False,  # whether to allow video in input modalities
 40    openai_allow_file: bool = False,  # whether to allow file in input modalities
 41    openai_media_send_as: Literal["base64", "file_id"] = "file_id",
 42    skills: str = "",
 43    openai_append_tool_results: bool = True,
 44    hide_thinking: bool = False,
 45    silent: bool = False,
 46    max_retries: int = 3,
 47    **kwargs,
 48) -> dict:
 49    """Get OpenAI Chat Completions.
 50
 51    Returns:
 52        dict: {"texts": str, "thoughts": str, "response_id": str, "prefix": str, "model_name": str, "sent_messages": list[Message]}
 53    """
 54    if not prefix:
 55        prefix = f"{EMOJI_TEXT_BOT}**{model_name}**:{BOT_TIPS}\n"
 56
 57    if silent or not kwargs.get("show_progress"):  # noqa: SIM108
 58        status_msg = None
 59    else:
 60        status_msg = kwargs.get("progress") or await message.reply(f"{EMOJI_TEXT_BOT}**{model_name}**: 思考中...", quote=True)
 61
 62    sent_messages = [status_msg]
 63    cache_day = round(cache_response_ttl // 86400)
 64    try:
 65        openai_client = {}
 66        if literal_eval(openai_client_config):
 67            openai_client |= literal_eval(openai_client_config)
 68        if literal_eval(openai_default_headers):
 69            openai_client = deep_merge(openai_client, {"default_headers": literal_eval(openai_default_headers)})
 70        if openai_proxy:
 71            openai_client = deep_merge(openai_client, {"http_client": DefaultAsyncHttpxClient(proxy=openai_proxy)})
 72    except Exception as e:
 73        logger.error(f"OpenAI client setup error: {e}")
 74        return {"progress": status_msg} if isinstance(status_msg, Message) else {}
 75
 76    for api_key in strings_list(openai_api_keys, shuffle=True):
 77        try:
 78            openai_client |= {"base_url": openai_base_url, "api_key": api_key}
 79            logger.trace(f"AsyncOpenAI(**{openai_client})")
 80            openai = AsyncOpenAI(**openai_client)
 81            previous_response_id, contexts = await get_openai_response_contexts(
 82                client,
 83                message,
 84                openai_params=openai_client
 85                | {
 86                    "proxy": openai_proxy,
 87                    "model_id": model_id,
 88                    "cache_day": cache_day,
 89                    "allow_image": openai_allow_image,
 90                    "allow_video": openai_allow_video,
 91                    "allow_file": openai_allow_file,
 92                    "openai_media_send_as": openai_media_send_as,
 93                },
 94            )
 95            params = {}
 96            params |= {"model": model_id, "stream": True, "input": contexts}
 97            if skills:
 98                params |= {"instructions": await load_skills(skills)}
 99            if literal_eval(openai_responses_config):
100                params |= literal_eval(openai_responses_config)
101            if previous_response_id:
102                params |= {"previous_response_id": previous_response_id}
103            logger.debug(f"openai.responses.create(**{params})")
104            resp = await single_api_response(
105                client,
106                status_msg,
107                openai,
108                params=params,
109                prefix=prefix,
110                hide_thinking=hide_thinking,
111                openai_append_tool_results=openai_append_tool_results,
112                silent=silent,
113                max_retries=max_retries,
114                **kwargs,
115            )
116            if not resp.get("texts"):
117                continue
118            sent_messages.extend(resp.get("sent_messages", []))
119            sent_messages = [m for m in sent_messages if isinstance(m, Message)]
120            if cache_response_ttl > 0:
121                day = round(cache_response_ttl // 86400)
122                for sent_msg in sent_messages:  # save the reponse to R2
123                    key_hash = hashlib.sha256(api_key.encode()).hexdigest()
124                    tid = get_thread_id(sent_msg)
125                    await set_cf_r2(
126                        f"TTL/{day}d/OpenAI/{model_id}/{key_hash}/{sent_msg.chat.id}/{sent_msg.id}{'/' + str(tid) if tid else ''}",
127                        data=resp["full_response"],
128                        metadata={"response_id": resp["response_id"]},
129                        silent=silent,
130                    )
131            return {
132                "success": True,
133                "texts": resp["texts"],
134                "thoughts": resp["thoughts"],
135                "response_id": resp["response_id"],
136                "prefix": prefix,
137                "model_name": model_name,
138                "sent_messages": sent_messages,
139            }
140        except Exception as e:
141            logger.error(f"OpenAI API error: {e}")
142            await modify_progress(status_msg, text=f"{e}", force_update=True, **kwargs)
143    return {"progress": status_msg} if isinstance(status_msg, Message) else {}
144
145
async def single_api_response(
    client: Client,
    status_msg: Message | None,
    openai: AsyncOpenAI,
    params: dict,
    *,
    prefix: str = "",
    hide_thinking: bool = False,
    openai_append_tool_results: bool = True,
    silent: bool = False,
    retry: int = 0,
    max_retries: int = 3,
    **kwargs,
) -> dict:
    """Run one streaming ``openai.responses.create`` call and mirror it into Telegram.

    Reasoning summaries (unless hidden) and answer text are streamed into
    ``status_msg``. When the text outgrows ``TEXT_LENGTH`` the first part is frozen;
    for answers (not reasoning) the rest continues in a new reply message unless
    ``silent`` is set or there is no status message.

    A new attempt is started when ``parse_error`` asks for a retry, when the response
    is empty, or when an exception is raised. Each new attempt first deletes the
    continuation messages sent by the failed one and restarts from the original
    ``status_msg``.

    Args:
        client: Pyrogram client used to send continuation messages.
        status_msg: Progress message to edit, or None.
        openai: Configured ``AsyncOpenAI`` client.
        params: Keyword arguments for ``openai.responses.create``.
        prefix: Header prepended to every displayed message.
        hide_thinking: Skip display updates while the model is reasoning.
        openai_append_tool_results: Append citation links to the final answer.
        silent: Never send new messages.
        retry: Current attempt index, starting at 0.
        max_retries: Retry budget.
        **kwargs: Forwarded to ``modify_progress`` on error paths and to retries.

    Returns:
        dict: {"texts": str, "thoughts": str, "full_response": dict, "response_id": str,
        "sent_messages": list[Message]} on success; a dict with empty "texts" when the
        retries are exhausted; ``{}`` when the API reported a non-retryable error.
    """
    if retry > max_retries:
        return {"texts": "", "thoughts": "", "sent_messages": []}
    # Continuation messages rebind `status_msg`; a retry must restart from this one.
    origin_status_msg = status_msg
    answers = ""  # all model responses
    thoughts = ""  # all model thoughts
    runtime_texts = ""  # text of the Telegram message currently being edited
    status_cid = status_msg.chat.id if isinstance(status_msg, Message) else 0
    status_mid = status_msg.id if isinstance(status_msg, Message) else 0
    sent_messages: list[Message] = []  # continuation messages sent by this attempt
    full_response = {}
    response_id = ""
    resp = None  # last decoded chunk, included in error reports

    async def discard_sent_messages() -> None:
        """Best-effort deletion of the continuation messages sent by this attempt."""
        for sent_msg in sent_messages:
            with contextlib.suppress(Exception):
                await delete_message(sent_msg)
        sent_messages.clear()

    async def retry_from_origin() -> dict:
        """Drop this attempt's messages and run the next attempt from the original status message."""
        await discard_sent_messages()
        return await single_api_response(
            client,
            origin_status_msg,
            openai,
            params=params,
            prefix=prefix,
            retry=retry + 1,
            max_retries=max_retries,
            openai_append_tool_results=openai_append_tool_results,
            hide_thinking=hide_thinking,
            silent=silent,
            **kwargs,
        )

    try:
        tool_calls: list[dict] = []  # citation annotations from the completed response
        is_reasoning = False
        retry_requested = False
        stream = await openai.responses.create(**params)
        async with stream:  # closes the HTTP stream on every exit path, including early returns
            async for chunk in stream:
                resp = trim_none(chunk.model_dump())
                logger.trace(resp)
                parsed_error = await parse_error(resp, retry, max_retries, status_msg)
                if parsed_error["retry"]:
                    retry_requested = True
                    break
                if parsed_error["error"]:
                    await modify_progress(message=status_msg, text=parsed_error["error"], force_update=True, **kwargs)
                    return {}
                response_type = resp.get("type", "")
                delta = resp.get("delta", "")
                chunk_answer = delta if response_type == "response.output_text.delta" else ""
                chunk_thinking = delta if response_type == "response.reasoning_summary_text.delta" else ""

                # Track whether the stream is currently in the reasoning phase.
                if response_type in {"response.reasoning_summary_part.added", "response.reasoning_summary_text.delta"}:
                    is_reasoning = True
                elif response_type in {"response.content_part.added", "response.output_text.delta"}:
                    is_reasoning = False

                if response_type == "response.reasoning_summary_part.added" and len(thoughts) == 0:
                    # First reasoning part: open a quote block with the reasoning marker.
                    runtime_texts += quote(f"{EMOJI_REASONING_BEGIN}{chunk_thinking.lstrip()}")
                elif chunk_thinking:
                    # Keep later reasoning lines inside the quote block.
                    runtime_texts += chunk_thinking.replace("\n", f"\n{BLOCKQUOTE_DELIM}")

                if response_type == "response.content_part.added":
                    # The answer starts: the displayed reasoning is replaced by the answer.
                    runtime_texts = chunk_answer.lstrip()
                else:
                    runtime_texts += chunk_answer

                thoughts += chunk_thinking
                answers += chunk_answer

                # Collect the final results before any `continue` below can skip them.
                if response_type == "response.completed":
                    full_response = resp
                    response_id = glom(resp, "response.id", default="")
                    tool_calls = flatten(glom(resp, "response.output.**.annotations", default=[]))
                    thoughts = "".join(flatten(glom(resp, "response.output.*.summary.*.text", default=[])) or thoughts)
                    answers = "".join(flatten(glom(resp, "response.output.*.content.*.text", default=[])) or answers)

                if hide_thinking and is_reasoning:
                    continue

                runtime_texts = beautify_llm_response(runtime_texts)
                length = await count_without_entities(prefix + runtime_texts)
                if length <= TEXT_LENGTH - 10:  # leave some flexibility
                    if len(runtime_texts.removeprefix(prefix)) > 10:  # start displaying once there is some content
                        await modify_progress(message=status_msg, text=prefix + runtime_texts, detail_progress=True)
                else:  # too long for one message: split it
                    parts = await smart_split(prefix + runtime_texts)
                    if len(parts) == 1:
                        continue
                    if is_reasoning:
                        # Keep only the newest reasoning text in the same message.
                        runtime_texts = quote(f"{EMOJI_REASONING_BEGIN}{parts[-1].lstrip()}")
                        await modify_progress(message=status_msg, text=parts[0], force_update=True)
                    else:
                        # Freeze the first part and continue the answer in a new reply.
                        await modify_progress(message=status_msg, text=blockquote(parts[0]), force_update=True)
                        runtime_texts = parts[-1]  # keep the last part
                        # Without a status message there is no chat to reply in (chat id would be 0).
                        if not silent and isinstance(status_msg, Message):
                            status_msg = await client.send_message(status_cid, text=prefix + runtime_texts, reply_parameters=ReplyParameters(message_id=status_mid))
                            sent_messages.append(status_msg)
                            status_mid = status_msg.id

        # All chunks are processed, or a retryable error stopped the stream.
        if retry_requested or (not answers.strip() and not thoughts.strip()):
            return await retry_from_origin()
        if openai_append_tool_results:
            answers = add_tool_call_results_to_response(tool_calls, answers)
        if await count_without_entities(prefix + answers) <= TEXT_LENGTH - 10:  # short answer fits in one message
            await modify_progress(message=status_msg, text=f"{prefix}{blockquote(answers.strip())}", force_update=True)
        else:  # the answer spans several messages; finalize the last one
            await modify_progress(message=status_msg, text=prefix + blockquote(runtime_texts), force_update=True)

    except Exception as e:
        error = f"{EMOJI_TEXT_BOT}BOT请求失败, 重试次数: {retry + 1}/{max_retries}\n{e}"
        if resp is not None:
            error += f"\n{resp}"
        logger.error(error)
        # Delete partial continuation messages first, then report on the original message,
        # so the error is not written into a message that is about to be deleted.
        await discard_sent_messages()
        with contextlib.suppress(Exception):
            await modify_progress(origin_status_msg, text=error, force_update=True, **kwargs)
        if retry + 1 < max_retries:
            return await retry_from_origin()
        # Out of retries: report failure instead of returning partial text for deleted messages.
        return {"texts": "", "thoughts": "", "sent_messages": []}
    return {"texts": answers, "thoughts": thoughts, "full_response": full_response, "response_id": response_id, "sent_messages": [m for m in sent_messages if isinstance(m, Message)]}
298
299
async def parse_error(resp: dict, retry: int, max_retries: int, status_msg: Message | None) -> dict:
    """Inspect one streamed event and classify it as an API error or not.

    Events whose ``type`` is ``"error"`` or ``"response.failed"`` are logged and
    written into ``status_msg`` together with the retry counter. Such an event is
    retryable while ``retry < max_retries``.

    Returns:
        {"error": "msg", "retry": bool}. ``error`` is ``str(resp)`` for error events
        and ``""`` otherwise.
    """
    event_type = glom(resp, "type", default="")
    if event_type in {"error", "response.failed"}:
        logger.warning(resp)
        await modify_progress(status_msg, text=f"{resp}\n重试次数: {retry + 1}/{max_retries}", force_update=True)
        can_retry = retry < max_retries
        return {"error": str(resp), "retry": can_retry}
    return {"error": "", "retry": False}
314
315
def add_tool_call_results_to_response(tool_calls: list[dict], answers: str) -> str:
    """Append numbered Markdown links for tool-call citations to ``answers``.

    The title of each entry comes from ``title``, falling back to ``site_name``. The
    link comes from ``url``, falling back to ``link``.

    Only entries whose link is a string starting with ``http`` are rendered. A link
    that was already rendered is skipped. The emoji numbering counts only the links
    actually rendered, so it has no gaps.

    Args:
        tool_calls: Citation/annotation entries collected from the response.
        answers: The model's answer text.

    Returns:
        ``answers`` unchanged when ``tool_calls`` is empty or not a list. Otherwise the
        stripped answer followed by one numbered ``[title](link)`` line per rendered
        link, stripped again.
    """
    if not tool_calls or not isinstance(tool_calls, list):
        return answers
    answers = answers.strip()
    seen_links: set[str] = set()
    for tool_call in tool_calls:
        title = glom(tool_call, Coalesce("title", "site_name"), default="")
        link = glom(tool_call, Coalesce("url", "link"), default="")
        # Skip non-string links (e.g. odd payloads), non-http links and repeats.
        if not isinstance(link, str) or not link.startswith("http") or link in seen_links:
            continue
        seen_links.add(link)
        answers += f"\n{number_to_emoji(len(seen_links))} [{title}]({link})"
    return answers.strip()