main
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3import contextlib
4import hashlib
5import json
6from json import JSONDecodeError
7from typing import Literal
8
9from glom import Coalesce, flatten, glom
10from jsonschema import ValidationError, validate
11from loguru import logger
12from openai import AsyncOpenAI, DefaultAsyncHttpxClient
13from pyrogram.client import Client
14from pyrogram.parser.markdown import BLOCKQUOTE_DELIM
15from pyrogram.types import Message, ReplyParameters
16
17from ai.texts.contexts import get_openai_response_contexts
18from ai.utils import BOT_TIPS, EMOJI_REASONING_BEGIN, EMOJI_TEXT_BOT, beautify_llm_response, deep_merge, literal_eval, load_skills, trim_none
19from config import AI, PROXY, TEXT_LENGTH
20from database.r2 import set_cf_r2
21from messages.progress import modify_progress
22from messages.utils import blockquote, count_without_entities, delete_message, quote, smart_split
23from utils import number_to_emoji, strings_list
24
25
async def openai_responses_api(
    client: Client,
    message: Message,
    *,
    prefix: str = "",
    model_id: str = AI.OPENAI_MODEL_ID,
    model_name: str = AI.OPENAI_MODEL_ID,
    openai_base_url: str = AI.OPENAI_BASE_URL,
    openai_api_keys: str = AI.OPENAI_API_KEYS,
    openai_client_config: str | dict = "",
    openai_default_headers: str | dict = "",
    openai_responses_config: str | dict = "",
    openai_proxy: str | None = PROXY.OPENAI,
    cache_response_ttl: int = 0,
    openai_allow_image: bool = True,  # whether to allow image in input modalities
    openai_allow_video: bool = False,  # whether to allow video in input modalities
    openai_allow_audio: bool = False,  # whether to allow audio in input modalities
    openai_allow_file: bool = False,  # whether to allow file in input modalities
    openai_media_send_as: Literal["base64", "file_id"] = "base64",
    additional_contexts: list[dict] | None = None,  # additional contexts to append to the contexts
    skills: str = "",
    openai_append_tool_results: bool = True,
    hide_thinking: bool = False,
    add_sender: bool | None = None,
    silent: bool = False,
    max_retries: int = 3,
    **kwargs,
) -> dict:
    """Answer a message through the OpenAI Responses API, trying each configured API key in turn.

    For every key returned by ``strings_list(openai_api_keys, shuffle=True)`` the
    function builds an ``AsyncOpenAI`` client, collects the conversation contexts,
    streams one response via ``single_api_response`` and validates it with
    ``is_valid_response``. The first valid response wins; an invalid response or
    an exception moves on to the next key.

    Args:
        client: Pyrogram client, forwarded to the context builder and the streaming helper.
        message: Message being answered; also used to send the initial status reply.
        prefix: Text placed in front of the bot output. Defaults to a header built from ``model_name``.
        model_id: Value sent as ``model`` in the request.
        model_name: Display name used in the default prefix and the status message.
        openai_base_url: ``base_url`` passed to ``AsyncOpenAI``.
        openai_api_keys: API keys, parsed by ``strings_list``.
        openai_client_config: Extra ``AsyncOpenAI`` keyword arguments (parsed by ``literal_eval``).
        openai_default_headers: Merged into the client config as ``default_headers``.
        openai_responses_config: Extra keyword arguments merged into the ``responses.create`` call.
        openai_proxy: When set, an ``DefaultAsyncHttpxClient`` using this proxy is passed as ``http_client``.
        cache_response_ttl: When > 0, the full response is stored with ``set_cf_r2`` once per sent
            message, under a key containing the TTL in whole days and a SHA-256 of the API key.
        openai_allow_image: Forwarded to the context builder as ``allow_image``.
        openai_allow_video: Forwarded to the context builder as ``allow_video``.
        openai_allow_audio: Forwarded to the context builder as ``allow_audio``.
        openai_allow_file: Forwarded to the context builder as ``allow_file``.
        openai_media_send_as: Forwarded to the context builder unchanged.
        additional_contexts: Forwarded to the context builder unchanged.
        skills: When non-empty, loaded with ``load_skills`` and sent as ``instructions``.
        openai_append_tool_results: Forwarded to ``single_api_response``.
        hide_thinking: Forwarded to ``single_api_response``.
        add_sender: Forwarded to the context builder unchanged.
        silent: Suppresses the status message; also forwarded to the helpers.
        max_retries: Forwarded to ``single_api_response``.
        **kwargs: ``show_progress`` and ``progress`` are read here; everything is also
            forwarded to ``single_api_response`` and ``modify_progress``.

    Returns:
        dict: On success
            {"success": True, "texts": str, "thoughts": str, "response_id": str,
             "prefix": str, "model_name": str, "sent_messages": list[Message]}.
            On failure ``{"progress": status_msg}`` when a status message exists, otherwise ``{}``.
    """
    if not prefix:
        prefix = f"{EMOJI_TEXT_BOT}**{model_name}**:{BOT_TIPS}\n"

    if silent or not kwargs.get("show_progress"):  # noqa: SIM108
        status_msg = None
    else:
        status_msg = kwargs.get("progress") or await message.reply(f"{EMOJI_TEXT_BOT}**{model_name}**: 思考中...", quote=True)

    # status_msg is never rebound below, so the failure payload can be built once.
    failure = {"progress": status_msg} if isinstance(status_msg, Message) else {}

    sent_messages = [status_msg]
    cache_day = round(cache_response_ttl // 86400)  # round() keeps the value an int even for a float TTL
    try:
        openai_client = {}
        client_config = literal_eval(openai_client_config)
        if client_config:
            openai_client |= client_config
        default_headers = literal_eval(openai_default_headers)
        if default_headers:
            openai_client = deep_merge(openai_client, {"default_headers": default_headers})
        if openai_proxy:
            openai_client = deep_merge(openai_client, {"http_client": DefaultAsyncHttpxClient(proxy=openai_proxy)})
    except Exception as e:
        logger.error(f"OpenAI client setup error: {e}")
        return failure

    for api_key in strings_list(openai_api_keys, shuffle=True):
        try:
            openai_client |= {"base_url": openai_base_url, "api_key": api_key}
            # Never write the secret key to the log; everything else is logged as before.
            redacted_client = openai_client | {"api_key": "***"}
            logger.trace(f"AsyncOpenAI(**{redacted_client})")
            openai = AsyncOpenAI(**openai_client)
            previous_response_id, contexts = await get_openai_response_contexts(
                client,
                message,
                params=openai_client
                | {
                    "proxy": openai_proxy,
                    "model_id": model_id,
                    "cache_day": cache_day,
                    "allow_image": openai_allow_image,
                    "allow_video": openai_allow_video,
                    "allow_audio": openai_allow_audio,
                    "allow_file": openai_allow_file,
                    "openai_media_send_as": openai_media_send_as,
                    "add_sender": add_sender,
                    "additional_contexts": additional_contexts,
                },
            )
            params = {"model": model_id, "stream": True, "input": contexts}
            if skills:
                params |= {"instructions": await load_skills(skills)}
            responses_config = literal_eval(openai_responses_config)
            if responses_config:
                params |= responses_config
            if previous_response_id:
                params |= {"previous_response_id": previous_response_id}
            logger.debug(f"openai.responses.create(**{params})")
            resp = await single_api_response(
                client,
                status_msg,
                openai,
                params=params,
                prefix=prefix,
                hide_thinking=hide_thinking,
                openai_append_tool_results=openai_append_tool_results,
                silent=silent,
                max_retries=max_retries,
                **kwargs,
            )
            # When a structured-output schema is configured the text must validate against it.
            if not is_valid_response(resp, glom(params, "text.format.schema", default={})):
                continue
            sent_messages.extend(resp.get("sent_messages", []))
            sent_messages = [m for m in sent_messages if isinstance(m, Message)]
            if cache_response_ttl > 0:
                # The hash depends only on the key, so compute it once for all messages.
                key_hash = hashlib.sha256(api_key.encode()).hexdigest()
                for sent_msg in sent_messages:  # save the response to R2
                    await set_cf_r2(
                        f"TTL/{cache_day}d/OpenAI/{sent_msg.chat.id}/{sent_msg.id}/{key_hash}",
                        data=resp["full_response"],
                        metadata={"response_id": resp["response_id"]},
                        silent=silent,
                    )
            return {
                "success": True,
                "texts": resp["texts"],
                "thoughts": resp["thoughts"],
                "response_id": resp["response_id"],
                "prefix": prefix,
                "model_name": model_name,
                "sent_messages": sent_messages,
            }
        except Exception as e:
            logger.error(f"OpenAI API error: {e}")
            await modify_progress(status_msg, text=f"❌{e}", force_update=True, **kwargs)
    # Every key was tried without a valid response: always hand back a dict, never None.
    return failure
151
152
async def single_api_response(
    client: Client,
    status_msg: Message | None,
    openai: AsyncOpenAI,
    params: dict,
    *,
    prefix: str = "",
    hide_thinking: bool = False,
    openai_append_tool_results: bool = True,
    silent: bool = False,
    retry: int = 0,
    max_retries: int = 3,
    **kwargs,
) -> dict:
    """Stream one OpenAI Responses API request and mirror the output into chat messages.

    Text and reasoning-summary deltas are accumulated and pushed to ``status_msg``
    through ``modify_progress``. When the rendered text exceeds ``TEXT_LENGTH`` it is
    split with ``smart_split``; the first part finalises the current message and
    (unless ``silent``) a new message is sent for the remainder. Error events,
    empty responses and exceptions trigger a recursive retry.

    Args:
        client: Pyrogram client used to send follow-up messages.
        status_msg: Message that is edited while streaming; may be ``None``.
        openai: Client whose ``responses.create`` is called with ``params``.
        params: Keyword arguments for ``responses.create``.
        prefix: Text placed in front of every rendered message.
        hide_thinking: Skip display updates while reasoning chunks are arriving.
        openai_append_tool_results: Append annotation links to the final answer.
        silent: Do not send follow-up messages when the text is split.
        retry: Current attempt index (0-based).
        max_retries: Upper bound for ``retry``.
        **kwargs: Forwarded to ``modify_progress`` on error paths and to retries.

    Returns:
        dict: {"texts": str, "thoughts": str, "full_response": dict, "response_id": str,
            "sent_messages": list[Message]} on success;
            {"texts": "", "thoughts": "", "sent_messages": []} when ``retry`` exceeds
            ``max_retries``; ``{}`` when the request ultimately fails.
    """
    if retry > max_retries:
        return {"texts": "", "thoughts": "", "sent_messages": []}

    async def _retry(target_msg: Message | None) -> dict:
        # Single place for the recursive call so no option is dropped on any retry path.
        return await single_api_response(
            client,
            target_msg,
            openai,
            params=params,
            prefix=prefix,
            retry=retry + 1,
            max_retries=max_retries,
            openai_append_tool_results=openai_append_tool_results,
            hide_thinking=hide_thinking,
            silent=silent,
            **kwargs,
        )

    answers = ""  # all model responses
    thoughts = ""  # all model thoughts
    runtime_texts = ""  # for a single telegram message
    initial_status_msg = status_msg  # survives the cleanup of follow-up messages on failure
    status_cid = status_msg.chat.id if isinstance(status_msg, Message) else 0
    status_mid = status_msg.id if isinstance(status_msg, Message) else 0
    sent_messages = []
    full_response = {}
    response_id = ""
    resp = None  # last decoded stream event, kept for error reporting
    try:
        tool_calls: list[dict] = []  # tool_call results
        is_reasoning = False
        async for chunk in await openai.responses.create(**params):
            resp = trim_none(chunk.model_dump())
            resp.pop("item_id", None)  # too noisy
            resp.pop("output_index", None)  # too noisy
            resp.pop("summary_index", None)  # too noisy
            logger.trace(resp)
            error = await parse_error(resp, retry, max_retries, status_msg)
            if error["retry"]:
                return await _retry(status_msg)
            if error["error"]:
                await modify_progress(message=status_msg, text=error["error"], force_update=True, **kwargs)
                return {}
            response_type = resp.get("type", "")
            chunk_answer = resp.get("delta", "") if response_type == "response.output_text.delta" else ""
            chunk_thinking = resp.get("delta", "") if response_type == "response.reasoning_summary_text.delta" else ""

            # Track whether the model is currently emitting reasoning or answer text.
            if response_type in {"response.reasoning_summary_part.added", "response.reasoning_summary_text.delta"}:  # reasoning in progress
                is_reasoning = True
            elif response_type in {"response.content_part.added", "response.output_text.delta"}:  # reasoning finished
                is_reasoning = False

            if response_type == "response.reasoning_summary_part.added" and len(thoughts) == 0:  # first reasoning part
                runtime_texts += quote(f"{EMOJI_REASONING_BEGIN}{chunk_thinking.lstrip()}")
            elif chunk_thinking:  # more reasoning text
                runtime_texts += chunk_thinking.replace("\n", f"\n{BLOCKQUOTE_DELIM}")

            if response_type == "response.content_part.added":  # answer starts: drop the displayed reasoning
                runtime_texts = chunk_answer.lstrip()
            else:
                runtime_texts += chunk_answer

            thoughts += chunk_thinking
            answers += chunk_answer

            # Capture the final payload before any `continue` below can skip it;
            # otherwise response_id/full_response are lost when the last event is
            # hidden (hide_thinking) or the split yields a single part.
            if response_type == "response.completed":
                full_response = resp
                response_id = glom(resp, "response.id", default="")
                tool_calls = flatten(glom(resp, "response.output.**.annotations", default=[]))
                thoughts = flatten(glom(resp, "response.output.*.summary.*.text", default=[])) or thoughts
                thoughts = "".join(thoughts)
                answers = flatten(glom(resp, "response.output.*.content.*.text", default=[])) or answers
                answers = "".join(answers)

            if hide_thinking and is_reasoning:
                continue

            runtime_texts = beautify_llm_response(runtime_texts)
            length = await count_without_entities(prefix + runtime_texts)
            if length <= TEXT_LENGTH - 10:  # leave some flexibility
                if len(runtime_texts.removeprefix(prefix)) > 10:  # start response if answer is not empty
                    await modify_progress(message=status_msg, text=prefix + runtime_texts, detail_progress=True)
            else:  # answers is too long, split it into multiple messages
                parts = await smart_split(prefix + runtime_texts)
                if len(parts) == 1:
                    continue
                if is_reasoning:
                    runtime_texts = quote(f"{EMOJI_REASONING_BEGIN}{parts[-1].lstrip()}")  # remove previous thinking
                    await modify_progress(message=status_msg, text=parts[0], force_update=True)  # force send the first part
                else:
                    await modify_progress(message=status_msg, text=blockquote(parts[0]), force_update=True)  # force send the first part
                    runtime_texts = parts[-1]  # keep the last part
                if not silent:
                    status_msg = await client.send_message(status_cid, text=prefix + runtime_texts, reply_parameters=ReplyParameters(message_id=status_mid))  # the new message
                    sent_messages.append(status_msg)
                    status_mid = status_msg.id

        # all chunks are processed
        if not answers.strip() and not thoughts.strip():  # empty response
            return await _retry(status_msg)
        if openai_append_tool_results:
            answers = add_tool_call_results_to_response(tool_calls, answers)
        if await count_without_entities(prefix + answers) <= TEXT_LENGTH - 10:  # short answer in single msg
            quoted = answers.strip()
            await modify_progress(message=status_msg, text=f"{prefix}{blockquote(quoted)}", force_update=True)
        else:  # total length is too long, answers are split into multiple messages
            await modify_progress(message=status_msg, text=prefix + blockquote(runtime_texts), force_update=True)

    except Exception as e:
        error = f"{EMOJI_TEXT_BOT}BOT请求失败, 重试次数: {retry + 1}/{max_retries}\n{e}"
        if resp is not None:
            error += f"\n{resp}"
        logger.error(error)
        # Report on the initial status message: every follow-up message is deleted
        # right below, so an error written there would vanish with it.
        with contextlib.suppress(Exception):
            await modify_progress(initial_status_msg, text=error, force_update=True, **kwargs)
        for msg in sent_messages:
            with contextlib.suppress(Exception):  # best effort: one failed delete must not block the rest
                await delete_message(msg)
        if retry + 1 < max_retries:
            # Restart from the message that still exists instead of a deleted follow-up.
            return await _retry(initial_status_msg)
        return {}
    return {"texts": answers, "thoughts": thoughts, "full_response": full_response, "response_id": response_id, "sent_messages": [m for m in sent_messages if isinstance(m, Message)]}
309
310
async def parse_error(resp: dict, retry: int, max_retries: int, status_msg: Message | None) -> dict:
    """Inspect one stream event and report whether it is an API error.

    Events whose ``type`` is neither ``"error"`` nor ``"response.failed"`` are not
    errors. For an error event the payload is logged, shown on ``status_msg``
    together with the attempt counter, and returned as text.

    Returns:
        {"error": "msg", "retry": bool} - ``error`` is empty for non-error events;
        ``retry`` is True only for an error event while ``retry < max_retries``.
    """
    event_type = glom(resp, "type", default="")
    if event_type in {"error", "response.failed"}:
        logger.warning(resp)
        await modify_progress(status_msg, text=f"{resp}\n重试次数: {retry + 1}/{max_retries}", force_update=True)
        return {"error": str(resp), "retry": retry < max_retries}
    return {"error": "", "retry": False}
325
326
def add_tool_call_results_to_response(tool_calls: list[dict], answers: str) -> str:
    """Append a numbered Markdown link list built from ``tool_calls`` to ``answers``.

    Each entry contributes one line ``<emoji number> [title](link)`` where the title
    comes from ``title`` or ``site_name`` and the link from ``url`` or ``link``.
    Entries whose link does not start with ``http`` are skipped but still consume
    their number. When ``tool_calls`` is empty or not a list, ``answers`` is
    returned untouched (not stripped).
    """
    if not (tool_calls and isinstance(tool_calls, list)):
        return answers
    pieces = [answers.strip()]
    for position, call in enumerate(tool_calls, start=1):
        title = glom(call, Coalesce("title", "site_name"), default="")
        link = glom(call, Coalesce("url", "link"), default="")
        if not link.startswith("http"):
            continue
        pieces.append(f"\n{number_to_emoji(position)} [{title}]({link})")
    return "".join(pieces).strip()
337
338
def is_valid_response(resp: dict, schema: dict) -> bool:
    """Tell whether ``resp`` carries a usable answer.

    A response without non-empty ``texts`` is never valid. Without a schema any
    non-empty text is accepted; with a schema the text must parse as JSON and
    pass ``validate`` against it. Decode and validation failures are logged.
    """
    texts = resp.get("texts")
    if not texts:
        return False
    if not schema:
        return True
    try:
        validate(instance=json.loads(texts), schema=schema)
    except (JSONDecodeError, ValidationError) as e:
        logger.error(f"Invalid JSONSchema response: {e}")
        return False
    return True