main
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3import contextlib
4import hashlib
5from typing import Literal
6from urllib.parse import quote_plus
7
8from anthropic import AsyncAnthropic, DefaultAioHttpClient
9from glom import Coalesce, glom
10from loguru import logger
11from pyrogram.client import Client
12from pyrogram.parser.markdown import BLOCKQUOTE_DELIM
13from pyrogram.types import Message, ReplyParameters
14
15from ai.texts.contexts import get_anthropic_contexts
16from ai.utils import BOT_TIPS, EMOJI_REASONING_BEGIN, EMOJI_TEXT_BOT, beautify_llm_response, literal_eval, load_skills, trim_none
17from config import AI, PROXY, TEXT_LENGTH
18from messages.progress import modify_progress
19from messages.utils import blockquote, count_without_entities, delete_message, quote, smart_split
20from utils import number_to_emoji, rand_string, strings_list
21
22
async def anthropic_responses(
    client: Client,
    message: Message,
    *,
    prefix: str = "",
    model_id: str = AI.ANTHROPIC_MODEL_ID,
    model_name: str = AI.ANTHROPIC_MODEL_ID,
    anthropic_base_url: str = AI.ANTHROPIC_BASE_URL,
    anthropic_api_keys: str = AI.ANTHROPIC_API_KEYS,
    anthropic_client_config: str | dict = "",
    anthropic_default_headers: str | dict = "",
    anthropic_responses_config: str | dict = "",
    anthropic_proxy: str | None = PROXY.ANTHROPIC,
    cache_response_ttl: int = 0,
    anthropic_media_send_as: Literal["base64", "file_id"] = "base64",
    anthropic_append_citation: bool = True,
    skills: str = "",
    hide_thinking: bool = False,
    add_sender: bool | None = None,
    silent: bool = False,
    max_retries: int = 3,
    **kwargs,
) -> dict:
    """Answer ``message`` with an Anthropic model, trying the configured API keys in turn.

    Args:
        client: Telegram client, forwarded to the context builder and the streaming helper.
        message: The message to answer; also the message the status reply is attached to.
        prefix: Text put in front of every reply. Built from ``model_name`` when empty.
        model_id: Value sent as the ``model`` request parameter.
        model_name: Display name used in the prefix and the status message.
        anthropic_base_url: ``base_url`` passed to ``AsyncAnthropic``.
        anthropic_api_keys: Keys expanded with ``strings_list(..., shuffle=True)``.
        anthropic_client_config: Extra ``AsyncAnthropic`` keyword arguments (parsed with ``literal_eval``).
        anthropic_default_headers: Value for the client's ``default_headers`` (parsed with ``literal_eval``).
        anthropic_responses_config: Extra request parameters merged over the defaults (parsed with ``literal_eval``).
        anthropic_proxy: When set, requests go through ``DefaultAioHttpClient(proxy=...)``.
        cache_response_ttl: Converted to whole hours and forwarded to the context builder as ``cache_hour``.
        anthropic_media_send_as: Forwarded to the context builder as ``media_send_as``.
        anthropic_append_citation: Forwarded to ``single_api_response`` as ``append_citation``.
        skills: When non-empty, ``load_skills(skills)`` is sent as the ``system`` parameter.
        hide_thinking: Forwarded to ``single_api_response``.
        add_sender: Forwarded to the context builder.
        silent: When true, no status message is created.
        max_retries: Forwarded to ``single_api_response``.
        **kwargs: ``show_progress`` and ``progress`` are read here; everything is also
            forwarded to ``single_api_response`` and ``modify_progress``.

    Returns:
        dict: On success ``{"success": True, "texts": str, "thoughts": str, "prefix": str,
        "model_name": str, "sent_messages": list[Message]}``. On failure
        ``{"progress": status_msg}`` when a status message exists, otherwise ``{}``.
    """
    if not prefix:
        prefix = f"{EMOJI_TEXT_BOT}**{model_name}**:{BOT_TIPS}\n"

    if silent or not kwargs.get("show_progress"):  # noqa: SIM108
        status_msg = None
    else:
        status_msg = kwargs.get("progress") or await message.reply(f"{EMOJI_TEXT_BOT}**{model_name}**: 思考中...", quote=True)

    # Every failure path returns this same shape.
    failure: dict = {"progress": status_msg} if isinstance(status_msg, Message) else {}
    sent_messages = [status_msg]
    cache_hour = round(cache_response_ttl // 3600)

    try:
        anthropic_client: dict = {}
        if client_config := literal_eval(anthropic_client_config):
            anthropic_client |= client_config
        if default_headers := literal_eval(anthropic_default_headers):
            anthropic_client |= {"default_headers": default_headers}
        if anthropic_proxy:
            anthropic_client |= {"http_client": DefaultAioHttpClient(proxy=anthropic_proxy)}
    except Exception as e:
        logger.error(f"Anthropic client setup error: {e}")
        return failure

    for api_key in strings_list(anthropic_api_keys, shuffle=True):
        try:
            anthropic_client |= {"base_url": anthropic_base_url, "api_key": api_key}
            # Never write the credential itself to the log.
            redacted = anthropic_client | {"api_key": "***"}
            logger.trace(f"AsyncAnthropic(**{redacted})")
            anthropic = AsyncAnthropic(**anthropic_client)
            params: dict = {
                "model": model_id,
                "max_tokens": 4096,
                "messages": await get_anthropic_contexts(
                    client,
                    message,
                    anthropic=anthropic,
                    cache_hour=cache_hour,
                    media_send_as=anthropic_media_send_as,
                    add_sender=add_sender,
                ),
            }
            # User-supplied request parameters override the defaults above.
            if responses_config := literal_eval(anthropic_responses_config):
                params |= responses_config
            if skills:
                params |= {"system": await load_skills(skills)}
            logger.debug(f"anthropic.messages.create(**{params})")
            resp = await single_api_response(
                client,
                status_msg,
                anthropic,
                params=params,
                prefix=prefix,
                hide_thinking=hide_thinking,
                silent=silent,
                max_retries=max_retries,
                append_citation=anthropic_append_citation,
                **kwargs,
            )
            if not resp.get("texts"):
                continue  # no usable answer with this key, try the next one
            sent_messages.extend(resp.get("sent_messages", []))
            return {
                "success": True,
                "texts": resp["texts"],
                "thoughts": resp["thoughts"],
                "prefix": prefix,
                "model_name": model_name,
                "sent_messages": [m for m in sent_messages if isinstance(m, Message)],
            }
        except Exception as e:
            # An exception aborts the whole call; only empty answers move on to the next key.
            logger.error(f"Anthropic API error: {e}")
            await modify_progress(status_msg, text=f"❌{e}", force_update=True, **kwargs)
            return failure

    # No key produced an answer (or no key was configured): fail explicitly
    # instead of implicitly returning None from a function annotated ``-> dict``.
    logger.warning("Anthropic API returned no usable response for any API key")
    return failure
121
122
async def single_api_response(
    client: Client,
    status_msg: Message | None,
    anthropic: AsyncAnthropic,
    params: dict,
    *,
    prefix: str = "",
    append_citation: bool = True,
    hide_thinking: bool = False,
    silent: bool = False,
    retry: int = 0,
    max_retries: int = 3,
    **kwargs,
) -> dict:
    """Stream one Anthropic response and mirror it into Telegram messages.

    The streamed text is written into ``status_msg`` as it arrives. When it
    outgrows ``TEXT_LENGTH`` the current message is finalised and, unless
    ``silent`` is set or there is no status message, a follow-up message is
    sent as a reply. Empty responses and exceptions trigger a recursive retry.

    Args:
        client: Telegram client used to send follow-up messages.
        status_msg: Message that displays the progress, or ``None``.
        anthropic: Configured Anthropic client.
        params: Keyword arguments for ``anthropic.beta.messages.stream``.
        prefix: Text put in front of every displayed message.
        append_citation: Forwarded to ``parse_final_block``.
        hide_thinking: When true, the display is not updated while the model is reasoning.
        silent: When true, no follow-up messages are sent.
        retry: Index of the current attempt (0 for the first call).
        max_retries: Retry budget.
        **kwargs: Forwarded to ``modify_progress`` on errors and to retries.

    Returns:
        dict: ``{"texts": str, "thoughts": str, "sent_messages": list[Message]}``.
        ``{}`` is returned when the last attempt ended with an exception.
    """
    if retry > max_retries:
        return {"texts": "", "thoughts": "", "sent_messages": []}
    answers = ""  # all model responses
    thoughts = ""  # all model thoughts
    runtime_texts = ""  # for a single telegram message
    initial_status = status_msg  # kept so a retry never targets a deleted follow-up message
    status_cid = status_msg.chat.id if isinstance(status_msg, Message) else 0
    status_mid = status_msg.id if isinstance(status_msg, Message) else 0
    sent_messages = []
    resp: dict | None = None  # last decoded stream event, used for error reports and final parsing
    try:
        is_reasoning = False
        async with anthropic.beta.messages.stream(**params) as stream:
            async for chunk in stream:
                resp = trim_none(chunk.model_dump())
                logger.trace(resp)
                response_type = glom(resp, Coalesce("delta.type", "content_block.type"), default="") or ""
                chunk_answer = glom(resp, "delta.text", default="") or ""
                chunk_thinking = glom(resp, "delta.thinking", default="") or ""
                # Track whether the model is currently reasoning.
                if response_type == "thinking_delta":  # reasoning in progress
                    is_reasoning = True
                elif response_type == "text_delta":  # reasoning finished
                    is_reasoning = False

                if response_type == "thinking" and len(thoughts) == 0:  # first reasoning block
                    runtime_texts += quote(f"{EMOJI_REASONING_BEGIN}{chunk_thinking.lstrip()}")
                elif chunk_thinking:  # more reasoning text
                    runtime_texts += chunk_thinking.replace("\n", f"\n{BLOCKQUOTE_DELIM}")

                if response_type == "text":  # answer block starts: drop what was displayed so far
                    runtime_texts = chunk_answer.lstrip()
                else:
                    runtime_texts += chunk_answer

                if not chunk_answer and not chunk_thinking:
                    continue
                thoughts += chunk_thinking
                answers += chunk_answer
                if hide_thinking and is_reasoning:
                    continue
                runtime_texts = beautify_llm_response(runtime_texts)
                length = await count_without_entities(prefix + runtime_texts)
                if length <= TEXT_LENGTH - 10:  # leave some flexibility
                    if len(runtime_texts.removeprefix(prefix)) > 10:  # start response if answer is not empty
                        await modify_progress(message=status_msg, text=prefix + runtime_texts, detail_progress=True)
                else:  # answers is too long, split it into multiple messages
                    parts = await smart_split(prefix + runtime_texts)
                    if len(parts) == 1:
                        continue
                    if is_reasoning:
                        runtime_texts = quote(f"{EMOJI_REASONING_BEGIN}{parts[-1].lstrip()}")  # remove previous thinking
                        await modify_progress(message=status_msg, text=parts[0], force_update=True)  # force send the first part
                    else:
                        await modify_progress(message=status_msg, text=blockquote(parts[0]), force_update=True)  # force send the first part
                        runtime_texts = parts[-1]  # keep the last part
                    # A follow-up needs a chat to go to; without a status message
                    # there is none (chat id 0), so sending would only raise.
                    if not silent and isinstance(status_msg, Message):
                        status_msg = await client.send_message(status_cid, text=prefix + runtime_texts, reply_parameters=ReplyParameters(message_id=status_mid))  # the new message
                        sent_messages.append(status_msg)
                        status_mid = status_msg.id

        # all chunks are processed
        if not answers.strip() and not thoughts.strip():  # empty response
            return await single_api_response(
                client,
                status_msg,
                anthropic,
                params=params,
                prefix=prefix,
                append_citation=append_citation,  # keep the caller's choice on retry
                retry=retry + 1,
                max_retries=max_retries,
                hide_thinking=hide_thinking,
                silent=silent,
                **kwargs,
            )
        thoughts, answers = parse_final_block(resp or {}, thoughts, answers, append_citation=append_citation)
        if await count_without_entities(prefix + answers) <= TEXT_LENGTH - 10:  # short answer in single msg
            quoted = answers.strip()
            await modify_progress(message=status_msg, text=f"{prefix}{blockquote(quoted)}", force_update=True)
        else:  # total length is too long, answers are split into multiple messages
            await modify_progress(message=status_msg, text=prefix + blockquote(runtime_texts), force_update=True)

    except Exception as e:
        error = f"{EMOJI_TEXT_BOT}BOT请求失败, 重试次数: {retry + 1}/{max_retries}\n{e}"
        if resp is not None:
            error += f"\n{resp}"
        logger.error(error)
        # Best effort: reporting and cleanup must not mask the original failure.
        with contextlib.suppress(Exception):
            await modify_progress(status_msg, text=error, force_update=True, **kwargs)
            for msg in sent_messages:
                await delete_message(msg)
        # NOTE(review): this path stops after ``max_retries`` attempts while the
        # empty-response path above allows ``retry`` to reach ``max_retries`` -
        # confirm which limit is intended.
        if retry + 1 < max_retries:
            return await single_api_response(
                client,
                initial_status,  # follow-up messages were just deleted; restart on the original one
                anthropic,
                params=params,
                prefix=prefix,
                append_citation=append_citation,
                retry=retry + 1,
                max_retries=max_retries,
                hide_thinking=hide_thinking,
                silent=silent,
                **kwargs,
            )
        return {}
    return {
        "texts": answers,
        "thoughts": thoughts,
        "sent_messages": [m for m in sent_messages if isinstance(m, Message)],
    }
250
251
def parse_final_block(chunk: dict, thoughts: str, answers: str, *, append_citation: bool) -> tuple[str, str]:
    """Rebuild thoughts and answer from the final stream event, adding citation links.

    Args:
        chunk: The last decoded stream event. Only an event whose ``type`` is
            ``"message_stop"`` is parsed; its ``message.content`` list is read.
        thoughts: Reasoning text accumulated while streaming.
        answers: Answer text accumulated while streaming.
        append_citation: When false, the streamed values are returned untouched.

    Returns:
        tuple[str, str]: ``(thoughts, answers)``. When the event is parsed, both are
        rebuilt from the content blocks and stripped; each citation adds an inline
        ``[[n]](url)`` marker after its text block and one numbered line at the end.
        When there is nothing to parse, the streamed values are returned as given.
    """
    if not append_citation or chunk.get("type") != "message_stop":
        return thoughts, answers
    content = (chunk.get("message") or {}).get("content") or []
    if not content:
        # Nothing to rebuild from: keep what was streamed instead of wiping it.
        return thoughts, answers

    thought_parts: list[str] = []
    text_parts: list[str] = []
    citations: dict[str, dict] = {}  # {cite_key: {index:int, title:str, url:str}}
    for item in content:
        kind = item.get("type")
        if kind == "thinking":
            thought_parts.append(item.get("thinking") or "")
        elif kind == "text":
            text_parts.append(item.get("text") or "")
            # ``citations`` may be absent or null on blocks without sources.
            for citation in item.get("citations") or []:
                title = citation.get("title") or rand_string(8)
                url = citation.get("url") or f"https://google.com/search?q=/{quote_plus(title)}"
                cite_key = hashlib.sha256(f"{title}{url}".encode()).hexdigest()
                entry = citations.get(cite_key)
                if entry is None:  # first occurrence gets the next running number
                    entry = {"index": len(citations) + 1, "title": title, "url": url}
                    citations[cite_key] = entry
                text_parts.append(f" [[{entry['index']}]]({url})")
    # append citations
    text_parts.extend(
        f"\n{number_to_emoji(entry['index'])}[{entry['title']}]({entry['url']})"
        for entry in sorted(citations.values(), key=lambda entry: entry["index"])
    )
    return "".join(thought_parts).strip(), "".join(text_parts).strip()