main
  1#!/usr/bin/env python
  2# -*- coding: utf-8 -*-
  3import contextlib
  4import hashlib
  5from typing import Literal
  6from urllib.parse import quote_plus
  7
  8from anthropic import AsyncAnthropic, DefaultAioHttpClient
  9from glom import Coalesce, glom
 10from loguru import logger
 11from pyrogram.client import Client
 12from pyrogram.parser.markdown import BLOCKQUOTE_DELIM
 13from pyrogram.types import Message, ReplyParameters
 14
 15from ai.texts.contexts import get_anthropic_contexts
 16from ai.utils import BOT_TIPS, EMOJI_REASONING_BEGIN, EMOJI_TEXT_BOT, beautify_llm_response, literal_eval, load_skills, trim_none
 17from config import AI, PROXY, TEXT_LENGTH
 18from messages.progress import modify_progress
 19from messages.utils import blockquote, count_without_entities, delete_message, quote, smart_split
 20from utils import number_to_emoji, rand_string, strings_list
 21
 22
 23async def anthropic_responses(
 24    client: Client,
 25    message: Message,
 26    *,
 27    prefix: str = "",
 28    model_id: str = AI.ANTHROPIC_MODEL_ID,
 29    model_name: str = AI.ANTHROPIC_MODEL_ID,
 30    anthropic_base_url: str = AI.ANTHROPIC_BASE_URL,
 31    anthropic_api_keys: str = AI.ANTHROPIC_API_KEYS,
 32    anthropic_client_config: str | dict = "",
 33    anthropic_default_headers: str | dict = "",
 34    anthropic_responses_config: str | dict = "",
 35    anthropic_proxy: str | None = PROXY.ANTHROPIC,
 36    cache_response_ttl: int = 0,
 37    anthropic_media_send_as: Literal["base64", "file_id"] = "base64",
 38    anthropic_append_citation: bool = True,
 39    skills: str = "",
 40    hide_thinking: bool = False,
 41    add_sender: bool | None = None,
 42    silent: bool = False,
 43    max_retries: int = 3,
 44    **kwargs,
 45) -> dict:
 46    """Get Anthropic Responses.
 47
 48    Returns:
 49        dict: {"texts": str, "thoughts": str,  "prefix": str, "model_name": str, "sent_messages": list[Message]}
 50    """
 51    if not prefix:
 52        prefix = f"{EMOJI_TEXT_BOT}**{model_name}**:{BOT_TIPS}\n"
 53
 54    if silent or not kwargs.get("show_progress"):  # noqa: SIM108
 55        status_msg = None
 56    else:
 57        status_msg = kwargs.get("progress") or await message.reply(f"{EMOJI_TEXT_BOT}**{model_name}**: 思考中...", quote=True)
 58
 59    sent_messages = [status_msg]
 60    cache_hour = round(cache_response_ttl // 3600)
 61    try:
 62        anthropic_client = {}
 63        if literal_eval(anthropic_client_config):
 64            anthropic_client |= literal_eval(anthropic_client_config)
 65        if literal_eval(anthropic_default_headers):
 66            anthropic_client |= {"default_headers": literal_eval(anthropic_default_headers)}
 67        if anthropic_proxy:
 68            anthropic_client |= {"http_client": DefaultAioHttpClient(proxy=anthropic_proxy)}
 69    except Exception as e:
 70        logger.error(f"Anthropic client setup error: {e}")
 71        return {"progress": status_msg} if isinstance(status_msg, Message) else {}
 72    for api_key in strings_list(anthropic_api_keys, shuffle=True):
 73        try:
 74            anthropic_client |= {"base_url": anthropic_base_url, "api_key": api_key}
 75            logger.trace(f"AsyncAnthropic(**{anthropic_client})")
 76            anthropic = AsyncAnthropic(**anthropic_client)
 77            params: dict = {
 78                "model": model_id,
 79                "max_tokens": 4096,
 80                "messages": await get_anthropic_contexts(
 81                    client,
 82                    message,
 83                    anthropic=anthropic,
 84                    cache_hour=cache_hour,
 85                    media_send_as=anthropic_media_send_as,
 86                    add_sender=add_sender,
 87                ),
 88            }
 89            if literal_eval(anthropic_responses_config):
 90                params |= literal_eval(anthropic_responses_config)
 91            if skills:
 92                params |= {"system": await load_skills(skills)}
 93            logger.debug(f"anthropic.messages.create(**{params})")
 94            resp = await single_api_response(
 95                client,
 96                status_msg,
 97                anthropic,
 98                params=params,
 99                prefix=prefix,
100                hide_thinking=hide_thinking,
101                silent=silent,
102                max_retries=max_retries,
103                append_citation=anthropic_append_citation,
104                **kwargs,
105            )
106            if not resp.get("texts"):
107                continue
108            sent_messages.extend(resp.get("sent_messages", []))
109            return {
110                "success": True,
111                "texts": resp["texts"],
112                "thoughts": resp["thoughts"],
113                "prefix": prefix,
114                "model_name": model_name,
115                "sent_messages": [m for m in sent_messages if isinstance(m, Message)],
116            }
117        except Exception as e:
118            logger.error(f"Anthropic API error: {e}")
119            await modify_progress(status_msg, text=f"{e}", force_update=True, **kwargs)
120    return {"progress": status_msg} if isinstance(status_msg, Message) else {}
121
122
async def single_api_response(
    client: Client,
    status_msg: Message | None,
    anthropic: AsyncAnthropic,
    params: dict,
    *,
    prefix: str = "",
    append_citation: bool = True,
    hide_thinking: bool = False,
    silent: bool = False,
    retry: int = 0,
    max_retries: int = 3,
    **kwargs,
) -> dict:
    """Stream one Anthropic Messages request and mirror it into Telegram messages.

    Streamed reasoning/answer text is written into ``status_msg``. When it outgrows ``TEXT_LENGTH``
    the first part is frozen in the current message and, unless ``silent`` (or there is no status
    message), a continuation message is sent as a reply and becomes the new edit target.

    Args:
        client: Pyrogram client used to send continuation messages.
        status_msg: Message edited with the live output, or None.
        anthropic: SDK client used for ``beta.messages.stream``.
        params: Keyword arguments for ``beta.messages.stream``.
        prefix: Header prepended to every message.
        append_citation: Rebuild the final answer with citation links (see ``parse_final_block``).
        hide_thinking: Skip message updates while the model is reasoning.
        silent: Never send continuation messages.
        retry: Zero-based attempt number (internal; used by the recursive retries).
        max_retries: Total attempts for both errors and empty responses (at least one is made).
        **kwargs: Forwarded to ``modify_progress`` when reporting errors, and to retries.

    Returns:
        dict: {"texts": str, "thoughts": str, "sent_messages": list[Message]}; texts/thoughts are
        empty when every attempt returned nothing, and ``{}`` is returned when the last attempt raised.
    """
    empty_result: dict = {"texts": "", "thoughts": "", "sent_messages": []}
    if retry > max_retries:
        return empty_result
    # Continuation messages replace ``status_msg`` below; every retry restarts from the original one.
    origin_status_msg = status_msg

    async def retry_request() -> dict:
        """Start the next attempt with identical options from the original status message."""
        return await single_api_response(
            client,
            origin_status_msg,
            anthropic,
            params=params,
            prefix=prefix,
            append_citation=append_citation,
            hide_thinking=hide_thinking,
            silent=silent,
            retry=retry + 1,
            max_retries=max_retries,
            **kwargs,
        )

    answers = ""  # all model responses
    thoughts = ""  # all model thoughts
    runtime_texts = ""  # text of the single Telegram message currently being edited
    status_cid = status_msg.chat.id if isinstance(status_msg, Message) else 0
    status_mid = status_msg.id if isinstance(status_msg, Message) else 0
    sent_messages: list[Message] = []
    resp: dict | None = None  # last streamed event; used for the final block and error reports
    try:
        is_reasoning = False
        async with anthropic.beta.messages.stream(**params) as stream:
            async for chunk in stream:
                resp = trim_none(chunk.model_dump())
                logger.trace(resp)
                # ``content_block_stop`` events carry ``content_block.type`` too; only a block *start*
                # may open the reasoning quote or reset the running text.
                block_start = resp.get("type") == "content_block_start"
                response_type = glom(resp, Coalesce("delta.type", "content_block.type"), default="") or ""
                chunk_answer = glom(resp, "delta.text", default="") or ""
                chunk_thinking = glom(resp, "delta.thinking", default="") or ""
                # Track whether the model is currently reasoning.
                if response_type == "thinking_delta":
                    is_reasoning = True
                elif response_type == "text_delta":
                    is_reasoning = False

                if block_start and response_type == "thinking" and not thoughts:  # first reasoning block
                    runtime_texts += quote(f"{EMOJI_REASONING_BEGIN}{chunk_thinking.lstrip()}")
                elif chunk_thinking:  # keep follow-up reasoning lines inside the blockquote
                    runtime_texts += chunk_thinking.replace("\n", f"\n{BLOCKQUOTE_DELIM}")

                if block_start and response_type == "text":  # answer begins: drop the streamed reasoning
                    runtime_texts = chunk_answer.lstrip()
                else:
                    runtime_texts += chunk_answer

                if not chunk_answer and not chunk_thinking:
                    continue
                thoughts += chunk_thinking
                answers += chunk_answer
                if hide_thinking and is_reasoning:
                    continue
                runtime_texts = beautify_llm_response(runtime_texts)
                length = await count_without_entities(prefix + runtime_texts)
                if length <= TEXT_LENGTH - 10:  # leave some flexibility
                    if len(runtime_texts.removeprefix(prefix)) > 10:  # start response if answer is not empty
                        await modify_progress(message=status_msg, text=prefix + runtime_texts, detail_progress=True)
                    continue
                # The running text is too long for one message: split it.
                parts = await smart_split(prefix + runtime_texts)
                if len(parts) == 1:
                    continue
                if is_reasoning:
                    runtime_texts = quote(f"{EMOJI_REASONING_BEGIN}{parts[-1].lstrip()}")  # remove previous thinking
                    await modify_progress(message=status_msg, text=parts[0], force_update=True)  # force send the first part
                    continue
                await modify_progress(message=status_msg, text=blockquote(parts[0]), force_update=True)  # force send the first part
                runtime_texts = parts[-1]  # keep the last part
                # Continue in a new message only when there is a real chat to reply in.
                if not silent and status_cid:
                    status_msg = await client.send_message(status_cid, text=prefix + runtime_texts, reply_parameters=ReplyParameters(message_id=status_mid))
                    sent_messages.append(status_msg)
                    status_mid = status_msg.id

        # all chunks are processed
        if not answers.strip() and not thoughts.strip():  # empty response
            if retry + 1 < max_retries:
                return await retry_request()
            return empty_result
        thoughts, answers = parse_final_block(resp or {}, thoughts, answers, append_citation=append_citation)
        if await count_without_entities(prefix + answers) <= TEXT_LENGTH - 10:  # short answer in single msg
            await modify_progress(message=status_msg, text=f"{prefix}{blockquote(answers.strip())}", force_update=True)
        else:  # the answer was already split across messages; finalize the last one
            await modify_progress(message=status_msg, text=prefix + blockquote(runtime_texts), force_update=True)

    except Exception as e:
        error = f"{EMOJI_TEXT_BOT}BOT请求失败, 重试次数: {retry + 1}/{max_retries}\n{e}"
        if resp is not None:
            error += f"\n{resp}"
        logger.error(error)
        # Remove partial continuation messages first (best effort, one by one), then report on the
        # original status message: ``status_msg`` may point at one of the messages just deleted.
        for msg in sent_messages:
            with contextlib.suppress(Exception):
                await delete_message(msg)
        with contextlib.suppress(Exception):
            await modify_progress(origin_status_msg, text=error, force_update=True, **kwargs)
        if retry + 1 < max_retries:
            return await retry_request()
        return {}
    return {
        "texts": answers,
        "thoughts": thoughts,
        "sent_messages": [m for m in sent_messages if isinstance(m, Message)],
    }
250
251
def parse_final_block(chunk: dict, thoughts: str, answers: str, *, append_citation: bool) -> tuple[str, str]:
    """Rebuild thoughts and answers from the final ``message_stop`` event, adding citation links.

    Args:
        chunk: The last streamed event, already converted to a plain dict.
        thoughts: Reasoning text accumulated while streaming.
        answers: Answer text accumulated while streaming.
        append_citation: When False, the streamed values are returned unchanged.

    Returns:
        tuple[str, str]: ``(thoughts, answers)``. For a ``message_stop`` event both are rebuilt from
        ``message.content`` and stripped; every citation adds an inline ``[[n]](url)`` link, and a
        numbered source list is appended to the answer. Any other event returns the inputs as-is.
    """
    if not append_citation or chunk.get("type") != "message_stop":
        return thoughts, answers
    thought_parts: list[str] = []
    text_parts: list[str] = []
    citations: dict[str, dict] = {}  # {cite_key: {"index": int, "title": str, "url": str}}
    message = chunk.get("message") or {}
    for item in message.get("content") or []:
        item_type = item.get("type")
        if item_type == "thinking":
            thought_parts.append(item.get("thinking", ""))
        elif item_type == "text":
            text_parts.append(item.get("text", ""))
            for citation in item.get("citations") or []:
                # Web-search citations carry ``title``; document citations carry ``document_title``.
                title = citation.get("title") or citation.get("document_title") or rand_string(8)
                url = citation.get("url") or f"https://google.com/search?q={quote_plus(title)}"
                cite_key = hashlib.sha256(f"{title}{url}".encode()).hexdigest()
                if cite_key not in citations:  # the same source reuses its first index
                    citations[cite_key] = {"index": len(citations) + 1, "title": title, "url": url}
                text_parts.append(f" [[{citations[cite_key]['index']}]]({url})")
    # append citations (indexes follow insertion order, so dict order is already sorted)
    text_parts.extend(f"\n{number_to_emoji(x['index'])}[{x['title']}]({x['url']})" for x in citations.values())
    return "".join(thought_parts).strip(), "".join(text_parts).strip()