main
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3import contextlib
4import json
5from json import JSONDecodeError
6from typing import Literal
7
8from glom import glom
9from jsonschema import ValidationError, validate
10from loguru import logger
11from openai import AsyncOpenAI, DefaultAsyncHttpxClient
12from pyrogram.client import Client
13from pyrogram.parser.markdown import BLOCKQUOTE_DELIM
14from pyrogram.types import Message, ReplyParameters
15
16from ai.texts.contexts import get_openai_completion_contexts
17from ai.utils import BOT_TIPS, EMOJI_REASONING_BEGIN, EMOJI_TEXT_BOT, beautify_llm_response, literal_eval, load_skills, split_reasoning, trim_none
18from config import AI, PROXY, TEXT_LENGTH
19from messages.progress import modify_progress
20from messages.utils import blockquote, count_without_entities, delete_message, quote, smart_split
21from utils import strings_list
22
23
async def openai_chat_completions(
    client: Client,
    message: Message,
    *,
    prefix: str = "",
    model_id: str = AI.OPENAI_MODEL_ID,
    model_name: str = AI.OPENAI_MODEL_ID,
    openai_base_url: str = AI.OPENAI_BASE_URL,
    openai_api_keys: str = AI.OPENAI_API_KEYS,
    openai_client_config: str | dict = "",
    openai_default_headers: str | dict = "",
    openai_completions_config: str | dict = "",
    openai_proxy: str | None = PROXY.OPENAI,
    openai_system_prompt: str = "",
    openai_contexts: list[dict] | None = None,
    openai_tools: list[dict] | None = None,
    skills: str = "",
    openai_allow_image: bool = True,  # whether to allow image in input modalities
    openai_allow_video: bool = False,  # whether to allow video in input modalities
    openai_allow_audio: bool = False,  # whether to allow audio in input modalities
    openai_allow_file: bool = False,  # whether to allow file in input modalities
    openai_media_send_as: Literal["base64", "file_id"] = "base64",
    additional_contexts: list[dict] | None = None,  # additional contexts to append to the contexts
    hide_thinking: bool = False,
    add_sender: bool | None = None,
    silent: bool = False,
    max_retries: int = 3,
    **kwargs,
) -> dict:
    """Get OpenAI Chat Completions, trying each configured API key until one succeeds.

    The AsyncOpenAI client options are built from ``openai_client_config``, ``openai_default_headers``
    and ``openai_proxy``. The message list is ``openai_contexts`` (copied, never mutated) or is
    built from ``message`` via ``get_openai_completion_contexts``. ``openai_system_prompt`` is prepended
    unless the contexts already start with a system message, and ``skills`` are injected into it.
    The request always streams; ``openai_completions_config`` and ``openai_tools`` are merged into it.

    The keys in ``openai_api_keys`` are tried in shuffled order. A key's result is accepted when it
    contains a tool call, or when its text passes ``is_valid_response``; the JSON schema is read from
    ``response_format.json_schema.schema`` when present.

    Args:
        client: Pyrogram client.
        message: Message that triggered the request; the progress message replies to it.
        prefix: Header shown before the answer; defaults to the bot emoji plus ``model_name``.
        openai_api_keys: One or more API keys, parsed by ``strings_list``.
        max_retries: Retry limit forwarded to ``single_api_chat_completions``.
        **kwargs: ``show_progress`` / ``progress`` control the progress message. All kwargs are
            forwarded to ``single_api_chat_completions`` and ``modify_progress``.

    Returns:
        dict: on success, the result of ``single_api_chat_completions``
        ({"texts", "thoughts", "tool_name", "tool_args"}) plus {"success": True, "prefix": str,
        "model_name": str, "sent_messages": list[Message]}, and "progress" when a progress message
        exists. On failure: {"progress": Message} or {}.
    """
    if not prefix:
        prefix = f"{EMOJI_TEXT_BOT}**{model_name}**:{BOT_TIPS}\n"

    if silent or not kwargs.get("show_progress"):  # noqa: SIM108
        status_msg = None
    else:
        status_msg = kwargs.get("progress") or await message.reply(f"{EMOJI_TEXT_BOT}**{model_name}**: 思考中...", quote=True)

    sent_messages = [status_msg]
    try:
        openai_client = {}
        if client_config := literal_eval(openai_client_config):
            openai_client |= client_config
        if default_headers := literal_eval(openai_default_headers):
            openai_client |= {"default_headers": default_headers}
        if openai_proxy:
            openai_client |= {"http_client": DefaultAsyncHttpxClient(proxy=openai_proxy)}

        if openai_contexts:
            # Work on a copy so the system prompt / skills inserted below never leak into the caller's list.
            contexts = list(openai_contexts)
        else:
            contexts = await get_openai_completion_contexts(
                client,
                message,
                params={
                    "add_sender": add_sender,
                    "allow_image": openai_allow_image,
                    "allow_video": openai_allow_video,
                    "allow_audio": openai_allow_audio,
                    "allow_file": openai_allow_file,
                    "openai_media_send_as": openai_media_send_as,
                    "additional_contexts": additional_contexts,
                },
            )
        if openai_system_prompt and glom(contexts, "0.role", default="") != "system":
            contexts.insert(0, {"role": "system", "content": openai_system_prompt})
        if skills:
            contexts = inject_skills(contexts, skills=await load_skills(skills))
        params = {"model": model_id, "messages": contexts, "stream": True}
        if completions_config := literal_eval(openai_completions_config):
            params |= completions_config
        if openai_tools:
            params |= {"tools": openai_tools, "tool_choice": "auto"}
        # Schema used to validate structured (JSON) responses; {} means "any non-empty text".
        response_schema = glom(params, "response_format.json_schema.schema", default={})
        logger.debug(f"openai.chat.completions.create(**{params})")
    except Exception as e:
        logger.error(f"OpenAI client setup error: {e}")
        await modify_progress(status_msg, text=f"❌{e}", force_update=True, **kwargs)
        return {"progress": status_msg} if isinstance(status_msg, Message) else {}

    for api_key in strings_list(openai_api_keys, shuffle=True):
        try:
            openai_client |= {"base_url": openai_base_url, "api_key": api_key}
            # Mask the API key so the secret never reaches the log files.
            masked_client = {**openai_client, "api_key": "***"}
            logger.trace(f"AsyncOpenAI(**{masked_client})")
            openai = AsyncOpenAI(**openai_client)
            resp = await single_api_chat_completions(
                client,
                status_msg,
                openai,
                params=params,
                prefix=prefix,
                hide_thinking=hide_thinking,
                silent=silent,
                max_retries=max_retries,
                **kwargs,
            )
            # A tool call usually arrives with empty text content. It is a valid result by itself and
            # must not be rejected by the text / JSON-schema validation.
            if not resp.get("tool_name") and not is_valid_response(resp, response_schema):
                continue
            resp |= {
                "success": True,
                "prefix": prefix,
                "model_name": model_name,
                "sent_messages": [m for m in sent_messages + resp.get("sent_messages", []) if isinstance(m, Message)],
            }
            if isinstance(status_msg, Message):
                resp |= {"progress": status_msg}
            return resp
        except Exception as e:
            logger.error(f"OpenAI API error: {e}")
            await modify_progress(status_msg, text=f"❌{e}", force_update=True, **kwargs)
    # Every API key failed or produced an invalid response.
    return {"progress": status_msg} if isinstance(status_msg, Message) else {}
135
136
async def single_api_chat_completions(
    client: Client,
    status_msg: Message | None,
    openai: AsyncOpenAI,
    params: dict,
    *,
    prefix: str = "",
    hide_thinking: bool = False,
    silent: bool = False,
    retry: int = 0,
    max_retries: int = 3,
    **kwargs,
) -> dict:
    """Get OpenAI Chat Completions via a single API client, streaming progress into Telegram.

    ``openai.chat.completions.create(**params)`` is iterated chunk by chunk (``params`` is expected to
    request streaming). Reasoning deltas (``reasoning_content``) and answer deltas (``content``) are
    accumulated separately, and the visible text is pushed into ``status_msg`` via ``modify_progress``.
    When the visible text exceeds ``TEXT_LENGTH``, the finished part is flushed. The remainder then
    continues in a new reply message, unless ``silent`` is set or there is no status message to reply to.

    An empty response is retried. So is an exception: the continuation messages sent by the failed
    attempt are deleted first, and the retry starts again from the original status message.

    Args:
        client: Pyrogram client used to send continuation messages.
        status_msg: Message that displays progress, or None when no progress is shown.
        openai: Configured AsyncOpenAI client.
        params: Keyword arguments for ``openai.chat.completions.create``.
        prefix: Header prepended to every displayed text.
        hide_thinking: Skip progress updates while the model is reasoning.
        silent: Never send continuation messages.
        retry: Current retry counter (used by the recursive retries).
        max_retries: Retry limit.
        **kwargs: Forwarded to ``modify_progress`` when reporting errors.

    Returns:
        dict: {"texts": str, "thoughts": str, "tool_name": str, "tool_args": str, "sent_messages": list[Message]},
        or {} when the last allowed attempt raised.
    """
    # NOTE(review): the empty-response path retries until `retry > max_retries`, while the exception
    # path stops at `retry + 1 >= max_retries`; the two limits differ by design or by accident.
    if retry > max_retries:
        return {"texts": "", "thoughts": "", "tool_name": "", "tool_args": "", "sent_messages": []}
    initial_status_msg = status_msg  # `status_msg` is rebound to each continuation message below
    answers = ""  # all model responses
    thoughts = ""  # all model thoughts
    tool_name = ""
    tool_args = ""
    runtime_texts = ""  # for a single telegram message
    status_cid = status_msg.chat.id if isinstance(status_msg, Message) else 0
    status_mid = status_msg.id if isinstance(status_msg, Message) else 0
    sent_messages = []
    resp = ""
    try:
        reasoning_chat_flag = None  # becomes True once any reasoning content was received (reasoning model)
        is_reasoning = False  # whether the model is currently streaming reasoning content
        async for chunk in await openai.chat.completions.create(**params):
            resp = trim_none(chunk.model_dump())
            logger.trace(resp)
            chunk_answer = glom(resp, "choices.0.delta.content", default="") or ""
            chunk_thinking = glom(resp, "choices.0.delta.reasoning_content", default="") or ""
            # Tool-call deltas: the name is kept once found, the arguments are concatenated piece by piece.
            tool_name = tool_name or glom(resp, "choices.0.delta.tool_calls.0.function.name", default="")
            tool_args += glom(resp, "choices.0.delta.tool_calls.0.function.arguments", default="") or ""
            if not chunk_answer and not chunk_thinking:
                continue
            if reasoning_chat_flag is None and chunk_thinking:
                reasoning_chat_flag = True
            if chunk_thinking and not is_reasoning:  # first reasoning chunk: open a quote block
                is_reasoning = True
                runtime_texts += quote(f"{EMOJI_REASONING_BEGIN}{chunk_thinking.lstrip()}")
            elif chunk_thinking and is_reasoning:  # more reasoning: keep every new line inside the quote
                runtime_texts += chunk_thinking.replace("\n", f"\n{BLOCKQUOTE_DELIM}")
            elif reasoning_chat_flag is True and is_reasoning:  # first answer after reasoning: replace the reasoning text
                is_reasoning = False
                runtime_texts = chunk_answer.lstrip()
            else:
                runtime_texts += chunk_answer

            thoughts += chunk_thinking
            answers += chunk_answer
            if hide_thinking and is_reasoning and not tool_args:
                continue
            runtime_texts = beautify_llm_response(runtime_texts)
            length = await count_without_entities(prefix + runtime_texts)
            if length <= TEXT_LENGTH - 10:  # leave some flexibility
                if len(runtime_texts.removeprefix(prefix)) > 10:  # start response if answer is not empty
                    await modify_progress(message=status_msg, text=prefix + runtime_texts, detail_progress=True)
            else:  # answers is too long, split it into multiple messages
                parts = await smart_split(prefix + runtime_texts)
                if len(parts) == 1:
                    continue
                if is_reasoning:
                    runtime_texts = quote(f"{EMOJI_REASONING_BEGIN}{parts[-1].lstrip()}")  # remove previous thinking
                    await modify_progress(message=status_msg, text=parts[0], force_update=True)  # force send the first part
                else:
                    await modify_progress(message=status_msg, text=blockquote(parts[0]), force_update=True)  # force send the first part
                    runtime_texts = parts[-1]  # keep the last part
                # Continue in a new reply only when there is a real status message to reply to:
                # without one, status_cid is 0 and send_message would fail on every attempt.
                if not silent and isinstance(status_msg, Message):
                    status_msg = await client.send_message(status_cid, text=prefix + runtime_texts, reply_parameters=ReplyParameters(message_id=status_mid))  # the new message
                    sent_messages.append(status_msg)
                    status_mid = status_msg.id
        if tool_name.strip():
            return {"texts": answers, "thoughts": thoughts, "tool_name": tool_name.strip(), "tool_args": tool_args.strip(), "sent_messages": sent_messages}
        # all chunks are processed
        if not (answers.strip() or thoughts.strip()):  # empty response: show the last chunk, then retry
            await modify_progress(message=status_msg, text=str(resp), force_update=True)
            return await single_api_chat_completions(
                client,
                status_msg,
                openai,
                params=params,
                prefix=prefix,
                retry=retry + 1,
                max_retries=max_retries,
                hide_thinking=hide_thinking,
                silent=silent,
                **kwargs,
            )

        if not thoughts:  # no structured thinking in response
            thoughts, answers = split_reasoning(answers)

        if await count_without_entities(prefix + answers) <= TEXT_LENGTH - 10:  # short answer in single msg
            await modify_progress(message=status_msg, text=f"{prefix}{blockquote(answers.strip())}", force_update=True)
        else:  # total length is too long, answers were split into multiple messages
            await modify_progress(message=status_msg, text=prefix + blockquote(runtime_texts), force_update=True)

    except Exception as e:
        error = f"{EMOJI_TEXT_BOT}BOT请求失败, 重试次数: {retry + 1}/{max_retries}\n{e}\n{resp}"
        logger.error(error)
        # Report on the original status message: `status_msg` may be a continuation message
        # that is about to be deleted together with the rest of this attempt's messages.
        with contextlib.suppress(Exception):
            await modify_progress(initial_status_msg, text=error, force_update=True, **kwargs)
        for msg in sent_messages:
            await delete_message(msg)
        if retry + 1 < max_retries:
            return await single_api_chat_completions(
                client,
                initial_status_msg,  # never retry on a message that was just deleted
                openai,
                params=params,
                prefix=prefix,
                retry=retry + 1,
                max_retries=max_retries,
                hide_thinking=hide_thinking,
                silent=silent,
                **kwargs,
            )
        return {}
    return {"texts": answers, "thoughts": thoughts, "tool_name": tool_name.strip(), "tool_args": tool_args.strip(), "sent_messages": sent_messages}
263
264
def inject_skills(contexts: list[dict], skills: str) -> list[dict]:
    """Merge the ``skills`` text into the system message of ``contexts``.

    If ``contexts`` does not start with a system message, a new system message whose content is
    ``skills`` is inserted at index 0. Otherwise the skills are added to the existing system content:
    appended on a new line for string content, or as an extra ``{"type": "text", ...}`` part for list
    content. In both cases nothing is added when the skills are already present. ``contexts`` is
    modified in place and also returned.
    """
    if not skills:
        return contexts

    leading_role = glom(contexts, "0.role", default="")
    if leading_role != "system":
        contexts.insert(0, {"role": "system", "content": skills})
        return contexts

    content = contexts[0]["content"]
    skills_part = {"type": "text", "text": skills}
    if isinstance(content, str):
        if skills not in content:
            content = f"{content}\n{skills}"
    elif isinstance(content, list):
        if skills_part not in content:
            content.append(skills_part)
    contexts[0] = {"role": "system", "content": content}
    return contexts
278
279
def is_valid_response(resp: dict, schema: dict) -> bool:
    """Decide whether a chat-completion result is usable.

    A response without non-empty ``texts`` is never valid. Without a ``schema``, non-empty ``texts``
    is enough. With a ``schema``, ``texts`` must also parse as JSON and validate against that JSON
    Schema. Parse or validation failures are logged and reported as invalid.
    """
    texts = resp.get("texts")
    if not texts:
        return False
    if not schema:
        return True
    try:
        validate(instance=json.loads(texts), schema=schema)
    except (JSONDecodeError, ValidationError) as e:
        logger.error(f"Invalid JSONSchema response: {e}")
        return False
    return True