#!/usr/bin/env python
# -*- coding: utf-8 -*-
import contextlib
import hashlib
from typing import Literal

from glom import Coalesce, flatten, glom
from loguru import logger
from openai import AsyncOpenAI, DefaultAsyncHttpxClient
from pyrogram.client import Client
from pyrogram.parser.markdown import BLOCKQUOTE_DELIM
from pyrogram.types import Message, ReplyParameters

from ai.texts.contexts import get_openai_response_contexts
from ai.utils import BOT_TIPS, EMOJI_REASONING_BEGIN, EMOJI_TEXT_BOT, beautify_llm_response, deep_merge, literal_eval, load_skills, trim_none
from config import AI, PROXY, TEXT_LENGTH
from database.r2 import set_cf_r2
from messages.parser import get_thread_id
from messages.progress import modify_progress
from messages.utils import blockquote, count_without_entities, delete_message, quote, smart_split
from utils import number_to_emoji, strings_list


async def openai_responses_api(
    client: Client,
    message: Message,
    *,
    prefix: str = "",
    model_id: str = AI.OPENAI_MODEL_ID,
    model_name: str = AI.OPENAI_MODEL_ID,
    openai_base_url: str = AI.OPENAI_BASE_URL,
    openai_api_keys: str = AI.OPENAI_API_KEYS,
    openai_client_config: str | dict = "",
    openai_default_headers: str | dict = "",
    openai_responses_config: str | dict = "",
    openai_proxy: str | None = PROXY.OPENAI,
    cache_response_ttl: int = 0,
    openai_allow_image: bool = False,  # whether to allow images in the input modalities
    openai_allow_video: bool = False,  # whether to allow videos in the input modalities
    openai_allow_file: bool = False,  # whether to allow files in the input modalities
    openai_media_send_as: Literal["base64", "file_id"] = "file_id",
    skills: str = "",
    openai_append_tool_results: bool = True,
    hide_thinking: bool = False,
    silent: bool = False,
    max_retries: int = 3,
    **kwargs,
) -> dict:
49 """Get OpenAI Chat Completions.
50
51 Returns:
52 dict: {"texts": str, "thoughts": str, "response_id": str, "prefix": str, "model_name": str, "sent_messages": list[Message]}
53 """
    if not prefix:
        prefix = f"{EMOJI_TEXT_BOT}**{model_name}**:{BOT_TIPS}\n"

    if silent or not kwargs.get("show_progress"):  # noqa: SIM108
        status_msg = None
    else:
        status_msg = kwargs.get("progress") or await message.reply(f"{EMOJI_TEXT_BOT}**{model_name}**: 思考中...", quote=True)

    sent_messages = [status_msg]
    cache_day = round(cache_response_ttl // 86400)
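    # build the AsyncOpenAI() keyword arguments from the optional literal-eval'ed config string, default headers, and proxy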
    try:
        openai_client = {}
        if literal_eval(openai_client_config):
            openai_client |= literal_eval(openai_client_config)
        if literal_eval(openai_default_headers):
            openai_client = deep_merge(openai_client, {"default_headers": literal_eval(openai_default_headers)})
        if openai_proxy:
            openai_client = deep_merge(openai_client, {"http_client": DefaultAsyncHttpxClient(proxy=openai_proxy)})
    except Exception as e:
        logger.error(f"OpenAI client setup error: {e}")
        return {"progress": status_msg} if isinstance(status_msg, Message) else {}

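    # try the configured API keys in random order until one yields a usable response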
    for api_key in strings_list(openai_api_keys, shuffle=True):
        try:
            openai_client |= {"base_url": openai_base_url, "api_key": api_key}
            logger.trace(f"AsyncOpenAI(**{openai_client})")
            openai = AsyncOpenAI(**openai_client)
            previous_response_id, contexts = await get_openai_response_contexts(
                client,
                message,
                openai_params=openai_client
                | {
                    "proxy": openai_proxy,
                    "model_id": model_id,
                    "cache_day": cache_day,
                    "allow_image": openai_allow_image,
                    "allow_video": openai_allow_video,
                    "allow_file": openai_allow_file,
                    "openai_media_send_as": openai_media_send_as,
                },
            )
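            # assemble the Responses API request: model, streamed output, the conversation contexts, optional skills instructions, and any extra config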
            params = {}
            params |= {"model": model_id, "stream": True, "input": contexts}
            if skills:
                params |= {"instructions": await load_skills(skills)}
            if literal_eval(openai_responses_config):
                params |= literal_eval(openai_responses_config)
            if previous_response_id:
                params |= {"previous_response_id": previous_response_id}
            logger.debug(f"openai.responses.create(**{params})")
            resp = await single_api_response(
                client,
                status_msg,
                openai,
                params=params,
                prefix=prefix,
                hide_thinking=hide_thinking,
                openai_append_tool_results=openai_append_tool_results,
                silent=silent,
                max_retries=max_retries,
                **kwargs,
            )
            if not resp.get("texts"):
                continue
            sent_messages.extend(resp.get("sent_messages", []))
            sent_messages = [m for m in sent_messages if isinstance(m, Message)]
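            # cache the full response payload in R2 under a TTL-prefixed key (API-key hash, chat, message and optional thread)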
            if cache_response_ttl > 0:
                for sent_msg in sent_messages:  # save the response to R2
                    key_hash = hashlib.sha256(api_key.encode()).hexdigest()
                    tid = get_thread_id(sent_msg)
                    await set_cf_r2(
                        f"TTL/{cache_day}d/OpenAI/{model_id}/{key_hash}/{sent_msg.chat.id}/{sent_msg.id}{'/' + str(tid) if tid else ''}",
                        data=resp["full_response"],
                        metadata={"response_id": resp["response_id"]},
                        silent=silent,
                    )
            return {
                "success": True,
                "texts": resp["texts"],
                "thoughts": resp["thoughts"],
                "response_id": resp["response_id"],
                "prefix": prefix,
                "model_name": model_name,
                "sent_messages": sent_messages,
            }
        except Exception as e:
            logger.error(f"OpenAI API error: {e}")
            await modify_progress(status_msg, text=f"❌{e}", force_update=True, **kwargs)
    return {"progress": status_msg} if isinstance(status_msg, Message) else {}


async def single_api_response(
    client: Client,
    status_msg: Message | None,
    openai: AsyncOpenAI,
    params: dict,
    *,
    prefix: str = "",
    hide_thinking: bool = False,
    openai_append_tool_results: bool = True,
    silent: bool = False,
    retry: int = 0,
    max_retries: int = 3,
    **kwargs,
) -> dict:
160 """Get OpenAI Chat Completions via single API.
161
162 Returns:
163 dict: {"texts": str, "thoughts": str, "full_response":dict, "response_id": str, "sent_messages": list[Message]}
164 """
    if retry > max_retries:
        return {"texts": "", "thoughts": "", "sent_messages": []}
    answers = ""  # all model responses
    thoughts = ""  # all model thoughts
    runtime_texts = ""  # buffer for a single Telegram message
    status_cid = status_msg.chat.id if isinstance(status_msg, Message) else 0
    status_mid = status_msg.id if isinstance(status_msg, Message) else 0
    sent_messages = []
    full_response = {}
    response_id = ""
    try:
        tool_calls: list[dict] = []  # tool_call results
        is_reasoning = False
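        # stream the response events; reasoning-summary deltas and output-text deltas are accumulated separately so the thinking part can be quoted or hidden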
        async for chunk in await openai.responses.create(**params):
            resp = trim_none(chunk.model_dump())
            logger.trace(resp)
            error = await parse_error(resp, retry, max_retries, status_msg)
            if error["retry"]:
                return await single_api_response(
                    client,
                    status_msg,
                    openai,
                    params=params,
                    prefix=prefix,
                    retry=retry + 1,
                    max_retries=max_retries,
                    openai_append_tool_results=openai_append_tool_results,
                    hide_thinking=hide_thinking,
                    silent=silent,
                    **kwargs,
                )
            if error["error"]:
                await modify_progress(message=status_msg, text=error["error"], force_update=True, **kwargs)
                return {}
            response_type = resp.get("type", "")
            chunk_answer = resp.get("delta", "") if response_type == "response.output_text.delta" else ""
            chunk_thinking = resp.get("delta", "") if response_type == "response.reasoning_summary_text.delta" else ""

            # set the reasoning flag
            if response_type in {"response.reasoning_summary_part.added", "response.reasoning_summary_text.delta"}:  # reasoning in progress
                is_reasoning = True
            elif response_type in {"response.content_part.added", "response.output_text.delta"}:  # reasoning finished
                is_reasoning = False

            if response_type == "response.reasoning_summary_part.added" and len(thoughts) == 0:  # first reasoning content received
                runtime_texts += quote(f"{EMOJI_REASONING_BEGIN}{chunk_thinking.lstrip()}")
            elif chunk_thinking:  # reasoning content received
                runtime_texts += chunk_thinking.replace("\n", f"\n{BLOCKQUOTE_DELIM}")

            if response_type == "response.content_part.added":  # initial answer received
                runtime_texts = chunk_answer.lstrip()
            else:
                runtime_texts += chunk_answer

            thoughts += chunk_thinking
            answers += chunk_answer
            if hide_thinking and is_reasoning:
                continue

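            # keep the live progress message within Telegram's length limit; once it would overflow, freeze the current part and continue in a fresh message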
            runtime_texts = beautify_llm_response(runtime_texts)
            length = await count_without_entities(prefix + runtime_texts)
            if length <= TEXT_LENGTH - 10:  # leave some flexibility
                if len(runtime_texts.removeprefix(prefix)) > 10:  # start updating once the answer is no longer empty
                    await modify_progress(message=status_msg, text=prefix + runtime_texts, detail_progress=True)
            else:  # the answer is too long, split it into multiple messages
                parts = await smart_split(prefix + runtime_texts)
                if len(parts) == 1:
                    continue
                if is_reasoning:
                    runtime_texts = quote(f"{EMOJI_REASONING_BEGIN}{parts[-1].lstrip()}")  # drop the previously shown thinking
                    await modify_progress(message=status_msg, text=parts[0], force_update=True)  # force-send the first part
                else:
                    await modify_progress(message=status_msg, text=blockquote(parts[0]), force_update=True)  # force-send the first part
                    runtime_texts = parts[-1]  # keep the last part
                if not silent:
                    status_msg = await client.send_message(status_cid, text=prefix + runtime_texts, reply_parameters=ReplyParameters(message_id=status_mid))  # the new message
                    sent_messages.append(status_msg)
                    status_mid = status_msg.id

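            # the final "response.completed" event carries the full output: prefer its consolidated reasoning summaries and text over the accumulated deltas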
            if response_type == "response.completed":
                full_response = resp
                response_id = glom(resp, "response.id", default="")
                tool_calls = flatten(glom(resp, "response.output.**.annotations", default=[]))
                thoughts = flatten(glom(resp, "response.output.*.summary.*.text", default=[])) or thoughts
                thoughts = "".join(thoughts)
                answers = flatten(glom(resp, "response.output.*.content.*.text", default=[])) or answers
                answers = "".join(answers)

        # all chunks are processed
        if not answers.strip() and not thoughts.strip():  # empty response
            return await single_api_response(
                client,
                status_msg,
                openai,
                params=params,
                prefix=prefix,
                retry=retry + 1,
                max_retries=max_retries,
                openai_append_tool_results=openai_append_tool_results,
                hide_thinking=hide_thinking,
                silent=silent,
                **kwargs,
            )
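        # append numbered citation links collected from the output annotations, then send or update the final message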
        if openai_append_tool_results:
            answers = add_tool_call_results_to_response(tool_calls, answers)
        if await count_without_entities(prefix + answers) <= TEXT_LENGTH - 10:  # a short answer fits in a single message
            quoted = answers.strip()
            await modify_progress(message=status_msg, text=f"{prefix}{blockquote(quoted)}", force_update=True)
        else:  # the total length is too long, so the answer was already split into multiple messages
            await modify_progress(message=status_msg, text=prefix + blockquote(runtime_texts), force_update=True)

    except Exception as e:
        error = f"{EMOJI_TEXT_BOT}BOT请求失败, 重试次数: {retry + 1}/{max_retries}\n{e}"
        if "resp" in locals():
            error += f"\n{resp}"
        logger.error(error)
        with contextlib.suppress(Exception):
            await modify_progress(status_msg, text=error, force_update=True, **kwargs)
        for msg in sent_messages:
            await delete_message(msg)
        if retry + 1 < max_retries:
            return await single_api_response(
                client,
                status_msg,
                openai,
                params=params,
                prefix=prefix,
                retry=retry + 1,
                max_retries=max_retries,
                openai_append_tool_results=openai_append_tool_results,
                hide_thinking=hide_thinking,
                silent=silent,
                **kwargs,
            )
    return {"texts": answers, "thoughts": thoughts, "full_response": full_response, "response_id": response_id, "sent_messages": [m for m in sent_messages if isinstance(m, Message)]}


async def parse_error(resp: dict, retry: int, max_retries: int, status_msg: Message | None) -> dict:
    """Parse an error event from the Responses API stream.

    Returns:
        dict: {"error": "msg", "retry": bool}
    """
    response_type = glom(resp, "type", default="")
    if response_type not in {"error", "response.failed"}:
        return {"error": "", "retry": False}
    logger.warning(resp)
    await modify_progress(status_msg, text=f"{resp}\n重试次数: {retry + 1}/{max_retries}", force_update=True)
    if retry < max_retries:
        return {"error": str(resp), "retry": True}
    return {"error": str(resp), "retry": False}


def add_tool_call_results_to_response(tool_calls: list[dict], answers: str) -> str:
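    """Append numbered Markdown links built from url/tool annotations to the end of the answer."""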
    if not tool_calls or not isinstance(tool_calls, list):
        return answers
    answers = answers.strip()
    for idx, tool_call in enumerate(tool_calls):
        title = glom(tool_call, Coalesce("title", "site_name"), default="")
        link = glom(tool_call, Coalesce("url", "link"), default="")
        if link.startswith("http"):
            answers += f"\n{number_to_emoji(idx + 1)} [{title}]({link})"
    return answers.strip()
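

# Usage sketch (hypothetical): how this module's entry point could be wired into a pyrogram
# handler. `app` and the command filter are illustrative, not part of this module; only
# `client` and `message` are required, everything else falls back to the config defaults.
#
#   from pyrogram import filters
#
#   @app.on_message(filters.command("ask"))
#   async def ask(client: Client, message: Message) -> None:
#       result = await openai_responses_api(client, message, show_progress=True)
#       if not result.get("success"):
#           logger.warning(f"no usable response: {result}")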