main
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3import contextlib
4
5from glom import glom
6from loguru import logger
7from openai import AsyncOpenAI, DefaultAsyncHttpxClient
8from pyrogram.client import Client
9from pyrogram.parser.markdown import BLOCKQUOTE_DELIM
10from pyrogram.types import Message, ReplyParameters
11
12from ai.texts.contexts import get_openai_completion_contexts
13from ai.utils import BOT_TIPS, EMOJI_REASONING_BEGIN, EMOJI_TEXT_BOT, beautify_llm_response, literal_eval, load_skills, split_reasoning, trim_none
14from config import AI, PROXY, TEXT_LENGTH
15from messages.progress import modify_progress
16from messages.utils import blockquote, count_without_entities, delete_message, quote, smart_split
17from utils import strings_list
18
19
async def openai_chat_completions(
    client: Client,
    message: Message,
    *,
    prefix: str = "",
    model_id: str = AI.OPENAI_MODEL_ID,
    model_name: str = AI.OPENAI_MODEL_ID,
    openai_base_url: str = AI.OPENAI_BASE_URL,
    openai_api_keys: str = AI.OPENAI_API_KEYS,
    openai_client_config: str | dict = "",
    openai_default_headers: str | dict = "",
    openai_completions_config: str | dict = "",
    openai_proxy: str | None = PROXY.OPENAI,
    openai_system_prompt: str = "",
    openai_contexts: list[dict] | None = None,
    openai_tools: list[dict] | None = None,
    skills: str = "",
    hide_thinking: bool = False,
    silent: bool = False,
    max_retries: int = 3,
    **kwargs,
) -> dict:
    """Get OpenAI Chat Completions, trying each configured API key in random order.

    Builds the AsyncOpenAI client config and request params once, then iterates
    over the shuffled API keys until one yields a non-empty answer or tool call.
    A key that raises is logged and skipped so the remaining keys still get a
    chance; only after every key has failed is the last error surfaced.

    Returns:
        dict: {"texts": str, "thoughts": str, "prefix": str, "model_name": str, "sent_messages": list[Message]}
        plus {"success": True} on success and {"progress": Message} when a
        progress message exists. Always a dict — on total failure the dict
        carries only the progress message (or is empty).
    """
    if not prefix:
        prefix = f"{EMOJI_TEXT_BOT}**{model_name}**:{BOT_TIPS}\n"

    if silent or not kwargs.get("show_progress"):  # noqa: SIM108
        status_msg = None
    else:
        status_msg = kwargs.get("progress") or await message.reply(f"{EMOJI_TEXT_BOT}**{model_name}**: 思考中...", quote=True)

    sent_messages = [status_msg]
    try:
        # Assemble AsyncOpenAI constructor kwargs from the (possibly stringified) configs.
        openai_client = {}
        if literal_eval(openai_client_config):
            openai_client |= literal_eval(openai_client_config)
        if literal_eval(openai_default_headers):
            openai_client |= {"default_headers": literal_eval(openai_default_headers)}
        if openai_proxy:
            openai_client |= {"http_client": DefaultAsyncHttpxClient(proxy=openai_proxy)}
        contexts = openai_contexts or await get_openai_completion_contexts(client, message)
        # Only prepend the system prompt when the caller did not already supply one.
        if openai_system_prompt and glom(contexts, "0.role", default="") != "system":
            contexts.insert(0, {"role": "system", "content": openai_system_prompt})
        if skills:
            contexts = inject_skills(contexts, skills=await load_skills(skills))
        params = {"model": model_id, "messages": contexts, "stream": True}
        if literal_eval(openai_completions_config):
            params |= literal_eval(openai_completions_config)
        if openai_tools:
            params |= {"tools": openai_tools, "tool_choice": "auto"}
        logger.debug(f"openai.chat.completions.create(**{params})")
    except Exception as e:
        logger.error(f"OpenAI client setup error: {e}")
        await modify_progress(status_msg, text=f"❌{e}", force_update=True, **kwargs)
        return {"progress": status_msg} if isinstance(status_msg, Message) else {}

    last_error: Exception | None = None
    for api_key in strings_list(openai_api_keys, shuffle=True):
        try:
            openai_client |= {"base_url": openai_base_url, "api_key": api_key}
            logger.trace(f"AsyncOpenAI(**{openai_client})")
            openai = AsyncOpenAI(**openai_client)
            resp = await single_api_chat_completions(
                client,
                status_msg,
                openai,
                params=params,
                prefix=prefix,
                hide_thinking=hide_thinking,
                silent=silent,
                max_retries=max_retries,
                **kwargs,
            )
            if resp.get("texts") or resp.get("tool_name"):
                resp |= {
                    "success": True,
                    "prefix": prefix,
                    "model_name": model_name,
                    "sent_messages": [m for m in sent_messages + resp["sent_messages"] if isinstance(m, Message)],
                }
                resp |= {"progress": status_msg} if isinstance(status_msg, Message) else {}
                return resp
        except Exception as e:
            # Keep the remaining keys in play: record the error and try the next key
            # instead of aborting on the first failure.
            last_error = e
            logger.error(f"OpenAI API error: {e}")

    # Every key failed or produced an empty response — report the last error (if any)
    # and still return a dict, never None.
    if last_error is not None:
        await modify_progress(status_msg, text=f"❌{last_error}", force_update=True, **kwargs)
    return {"progress": status_msg} if isinstance(status_msg, Message) else {}
109
110
async def single_api_chat_completions(
    client: Client,
    status_msg: Message | None,
    openai: AsyncOpenAI,
    params: dict,
    *,
    prefix: str = "",
    hide_thinking: bool = False,
    silent: bool = False,
    retry: int = 0,
    max_retries: int = 3,
    **kwargs,
) -> dict:
    """Stream one OpenAI Chat Completions request and live-edit the reply into Telegram.

    Chunks are accumulated into ``answers``/``thoughts`` while the in-progress
    text is pushed to ``status_msg`` via ``modify_progress``; once the text
    exceeds ``TEXT_LENGTH`` it is split and continued in new messages. Failures
    and empty responses retry recursively up to ``max_retries``.

    Args:
        client: Pyrogram client used to send overflow follow-up messages.
        status_msg: progress message being edited in place (None in silent mode).
        openai: configured AsyncOpenAI client.
        params: kwargs for ``openai.chat.completions.create`` (streaming expected).
        prefix: text prepended to every outgoing Telegram message.
        hide_thinking: collect reasoning content but do not display it live.
        silent: do not send extra Telegram messages when the text overflows.
        retry: current recursion depth (internal).
        max_retries: recursion limit for retries.

    Returns:
        dict: {"texts": str, "thoughts": str, "tool_name": str, "tool_args": str, "sent_messages": list[Message]}
    """
    if retry > max_retries:  # recursion guard: give up with an empty result
        return {"texts": "", "thoughts": "", "tool_name": "", "tool_args": "", "sent_messages": []}
    answers = ""  # all model responses accumulated so far
    thoughts = ""  # all model thoughts accumulated so far
    tool_name = ""  # first tool-call function name seen in the stream
    tool_args = ""  # tool-call argument fragments, concatenated across chunks
    runtime_texts = ""  # text for the single Telegram message currently being edited
    # Fall back to 0 when there is no progress message (silent mode).
    status_cid = status_msg.chat.id if isinstance(status_msg, Message) else 0
    status_mid = status_msg.id if isinstance(status_msg, Message) else 0
    sent_messages = []  # overflow messages created here, returned for caller bookkeeping
    resp = ""  # last raw chunk dump, surfaced in error / empty-response reporting
    try:
        reasoning_chat_flag = None  # whether the model streams reasoning content at all (None = not known yet)
        is_reasoning = False  # currently inside a reasoning segment
        async for chunk in await openai.chat.completions.create(**params):
            resp = trim_none(chunk.model_dump())
            logger.trace(resp)
            chunk_answer = glom(resp, "choices.0.delta.content", default="") or ""
            chunk_thinking = glom(resp, "choices.0.delta.reasoning_content", default="") or ""
            tool_name = tool_name or glom(resp, "choices.0.delta.tool_calls.0.function.name", default="")
            tool_args += glom(resp, "choices.0.delta.tool_calls.0.function.arguments", default="") or ""
            if not chunk_answer and not chunk_thinking:  # tool-call / housekeeping chunk: nothing to display
                continue
            if reasoning_chat_flag is None and chunk_thinking:
                reasoning_chat_flag = True
            if chunk_thinking and not is_reasoning:  # first reasoning chunk: open a quoted "thinking" section
                is_reasoning = True
                runtime_texts += quote(f"{EMOJI_REASONING_BEGIN}{chunk_thinking.lstrip()}")
            elif chunk_thinking and is_reasoning:  # still reasoning: keep newlines inside the blockquote
                runtime_texts += chunk_thinking.replace("\n", f"\n{BLOCKQUOTE_DELIM}")
            elif reasoning_chat_flag is True and is_reasoning:  # first answer chunk: close reasoning, restart display text
                is_reasoning = False
                runtime_texts = chunk_answer.lstrip()
            else:  # plain answer chunk
                runtime_texts += chunk_answer

            thoughts += chunk_thinking
            answers += chunk_answer
            if hide_thinking and is_reasoning and not tool_args:  # suppress live display while reasoning
                continue
            runtime_texts = beautify_llm_response(runtime_texts)
            length = await count_without_entities(prefix + runtime_texts)
            if length <= TEXT_LENGTH - 10:  # leave some flexibility
                if len(runtime_texts.removeprefix(prefix)) > 10:  # start response if answer is not empty
                    await modify_progress(message=status_msg, text=prefix + runtime_texts, detail_progress=True)
            else:  # answers is too long, split it into multiple messages
                parts = await smart_split(prefix + runtime_texts)
                if len(parts) == 1:  # splitter found no break point yet; wait for more chunks
                    continue
                if is_reasoning:
                    runtime_texts = quote(f"{EMOJI_REASONING_BEGIN}{parts[-1].lstrip()}")  # remove previous thinking
                    await modify_progress(message=status_msg, text=parts[0], force_update=True)  # force send the first part
                else:
                    await modify_progress(message=status_msg, text=blockquote(parts[0]), force_update=True)  # force send the first part
                    runtime_texts = parts[-1]  # keep the last part
                if not silent:
                    status_msg = await client.send_message(status_cid, text=prefix + runtime_texts, reply_parameters=ReplyParameters(message_id=status_mid))  # the new message
                    sent_messages.append(status_msg)
                    status_mid = status_msg.id
        if tool_name.strip():  # tool call requested: hand the call back to the caller to execute
            return {"texts": answers, "thoughts": thoughts, "tool_name": tool_name.strip(), "tool_args": tool_args.strip(), "sent_messages": sent_messages}
        # all chunks are processed
        if not (answers.strip() or thoughts.strip()):  # empty response
            await modify_progress(message=status_msg, text=str(resp), force_update=True)
            return await single_api_chat_completions(
                client,
                status_msg,
                openai,
                params=params,
                prefix=prefix,
                retry=retry + 1,
                max_retries=max_retries,
                hide_thinking=hide_thinking,
                silent=silent,
                **kwargs,
            )

        if not thoughts:  # no structured thinking in response
            thoughts, answers = split_reasoning(answers)

        # answers = add_search_results_to_response(config.get("search_results", []), answers)
        if await count_without_entities(prefix + answers) <= TEXT_LENGTH - 10:  # short answer in single msg
            quoted = answers.strip()
            await modify_progress(message=status_msg, text=f"{prefix}{blockquote(quoted)}", force_update=True)
        else:  # total length is too long, answers are splitted into multiple messages
            # NOTE(review): only the trailing runtime_texts is flushed here; earlier
            # parts were already sent during streaming — confirm that is the intent.
            await modify_progress(message=status_msg, text=prefix + blockquote(runtime_texts), force_update=True)

    except Exception as e:
        error = f"{EMOJI_TEXT_BOT}BOT请求失败, 重试次数: {retry + 1}/{max_retries}\n{e}\n{resp}"
        logger.error(error)
        with contextlib.suppress(Exception):
            # Best-effort cleanup: show the error and remove partial overflow messages.
            await modify_progress(status_msg, text=error, force_update=True, **kwargs)
            [await delete_message(msg) for msg in sent_messages]
        if retry + 1 < max_retries:
            return await single_api_chat_completions(
                client,
                status_msg,
                openai,
                params=params,
                prefix=prefix,
                retry=retry + 1,
                max_retries=max_retries,
                hide_thinking=hide_thinking,
                silent=silent,
                **kwargs,
            )
    return {"texts": answers, "thoughts": thoughts, "tool_name": tool_name.strip(), "tool_args": tool_args.strip(), "sent_messages": sent_messages}
236
237
def inject_skills(contexts: list[dict], skills: str) -> list[dict]:
    """Merge *skills* text into the leading system message of *contexts*.

    Mutates and returns *contexts*. When no system message leads the list,
    one is inserted containing only the skills text. Otherwise the skills
    text is appended to the existing system content — which may be a plain
    string or a list of content parts — skipping the append when it is
    already present.
    """
    if not skills:
        return contexts

    leads_with_system = bool(contexts) and contexts[0].get("role", "") == "system"
    if not leads_with_system:
        contexts.insert(0, {"role": "system", "content": skills})
        return contexts

    content = contexts[0]["content"]
    if isinstance(content, str) and skills not in content:
        content = f"{content}\n{skills}"
    elif isinstance(content, list) and {"type": "text", "text": skills} not in content:
        content.append({"type": "text", "text": skills})
    contexts[0] = {"role": "system", "content": content}
    return contexts