main
  1#!/usr/bin/env python
  2# -*- coding: utf-8 -*-
  3import asyncio
  4import contextlib
  5import json
  6from json import JSONDecodeError
  7
  8from glom import glom
  9from google import genai
 10from google.genai import types
 11from jsonschema import ValidationError, validate
 12from loguru import logger
 13from pyrogram.client import Client
 14from pyrogram.parser.markdown import BLOCKQUOTE_DELIM
 15from pyrogram.types import Message, ReplyParameters
 16
 17from ai.texts.contexts import get_gemini_contexts
 18from ai.utils import BOT_TIPS, EMOJI_REASONING_BEGIN, EMOJI_TEXT_BOT, beautify_llm_response, literal_eval, load_skills, trim_none
 19from config import AI, PROXY, TEXT_LENGTH
 20from messages.progress import modify_progress
 21from messages.utils import blockquote, count_without_entities, quote, smart_split
 22from networking import flatten_rediercts
 23from utils import number_to_emoji, strings_list
 24
 25
 26async def gemini_chat_completion(
 27    client: Client,
 28    message: Message,
 29    *,
 30    prefix: str = "",
 31    model_id: str = AI.GEMINI_MODEL_ID,
 32    model_name: str = AI.GEMINI_MODEL_ID,
 33    gemini_base_url: str = AI.GEMINI_BASE_URL,
 34    gemini_api_keys: str = AI.GEMINI_API_KEYS,
 35    gemini_default_headers: str | dict = AI.GEMINI_DEFAULT_HEADERS,
 36    gemini_generate_content_config: str | dict = "",
 37    gemini_proxy: str | None = PROXY.GOOGLE,
 38    gemini_append_grounding: bool = True,
 39    additional_contexts: list[dict] | None = None,  # additional contexts to append to the contexts
 40    skills: str = "",
 41    hide_thinking: bool = False,
 42    add_sender: bool | None = None,
 43    silent: bool = False,
 44    max_retries: int = 3,
 45    **kwargs,
 46) -> dict:
 47    """Get OpenAI Chat Completions.
 48
 49    Returns:
 50            dict: {"texts": str, "thoughts": str, "prefix": str, "model_name": str, "sent_messages": list[Message]}
 51    """
 52    if not prefix:
 53        prefix = f"{EMOJI_TEXT_BOT}**{model_name}**:{BOT_TIPS}\n"
 54
 55    if silent or not kwargs.get("show_progress"):  # noqa: SIM108
 56        status_msg = None
 57    else:
 58        status_msg = kwargs.get("progress") or await message.reply(f"{EMOJI_TEXT_BOT}**{model_name}**: 思考中...", quote=True)
 59
 60    sent_messages = [status_msg]
 61    for api_key in strings_list(gemini_api_keys, shuffle=True):
 62        try:
 63            http_options = types.HttpOptions(base_url=gemini_base_url, async_client_args={"proxy": gemini_proxy}, headers=literal_eval(gemini_default_headers))
 64            gemini = genai.Client(api_key=api_key, http_options=http_options)
 65            params: dict = {
 66                "model": model_id,
 67                "contents": await get_gemini_contexts(
 68                    client,
 69                    message,
 70                    gemini,
 71                    add_sender=add_sender,
 72                    additional_contexts=additional_contexts,
 73                ),
 74            }
 75            if skills:
 76                gemini_generate_content_config = literal_eval(gemini_generate_content_config) | {"system_instruction": await load_skills(skills)}
 77            if conf := literal_eval(gemini_generate_content_config):
 78                params["config"] = conf
 79            logger.debug(f"genai.Client().models.generate_content_stream(**{params})")
 80            resp = await single_api_generate_content(
 81                client,
 82                status_msg,
 83                gemini,
 84                params=params,
 85                prefix=prefix,
 86                silent=silent,
 87                max_retries=max_retries,
 88                append_grounding=gemini_append_grounding,
 89                hide_thinking=hide_thinking,
 90                **kwargs,
 91            )
 92            if resp.get("texts"):
 93                if not is_valid_response(resp, glom(params, "config.responseJsonSchema", default={})):
 94                    continue
 95                sent_messages.extend(resp.get("sent_messages", []))
 96                return {
 97                    "success": True,
 98                    "texts": resp["texts"],
 99                    "thoughts": resp["thoughts"],
100                    "prefix": prefix,
101                    "model_name": model_name,
102                    "sent_messages": [m for m in sent_messages if isinstance(m, Message)],
103                }
104        except Exception as e:
105            logger.error(f"Gemini API error: {e}")
106            await modify_progress(status_msg, text=f"{e}", force_update=True, **kwargs)
107    return {"progress": status_msg} if isinstance(status_msg, Message) else {}
108
109
async def _close_gemini_client(gemini: genai.Client) -> None:
    """Best-effort close of the async side of ``gemini``; close errors are ignored."""
    with contextlib.suppress(Exception):
        await gemini.aio.aclose()


async def single_api_generate_content(
    client: Client,
    status_msg: Message | None,
    gemini: genai.Client,
    params: dict,
    *,
    prefix: str = "",
    retry: int = 0,
    max_retries: int = 3,
    append_grounding: bool = True,
    hide_thinking: bool = False,
    silent: bool = False,
    **kwargs,
) -> dict:
    """Stream one Gemini generation through a single client and mirror it into Telegram messages.

    The function calls itself again when the response is empty or an exception occurs.
    Note the two limits differ: the empty-response path stops once ``retry > max_retries``,
    the exception path retries only while ``retry + 1 < max_retries``.

    The client is closed only on the terminal exits (success, retries exhausted, final failure),
    never before a retry, because a retry reuses the same client and a closed client may no longer
    be able to send requests.

    Args:
        client: Pyrogram client used to send follow-up messages when the text overflows one message.
        status_msg: Progress message that is edited while streaming; may be ``None``.
        gemini: google-genai client used for the streaming request.
        params: Keyword arguments for ``generate_content_stream``.
        prefix: Text put in front of the streamed content in every message.
        retry: Number of attempts already made.
        max_retries: Retry limit (see the note above).
        append_grounding: Whether grounding citations are appended to the final text.
        hide_thinking: Skip progress updates while the model is still emitting thoughts.
        silent: Do not send new messages when the text overflows.
        **kwargs: Only forwarded to the recursive retry calls.

    Returns:
        dict: {"texts": str, "thoughts": str, "sent_messages": list[Message]};
        ``{}`` when the last attempt ended with an exception.
    """
    if retry > max_retries:
        await _close_gemini_client(gemini)
        return {"texts": "", "thoughts": "", "sent_messages": []}
    answers = ""  # all model responses
    thoughts = ""  # all model thoughts
    runtime_texts = ""  # for a single telegram message
    status_cid = status_msg.chat.id if isinstance(status_msg, Message) else 0
    status_mid = status_msg.id if isinstance(status_msg, Message) else 0
    sent_messages = []
    resp = {}
    try:
        reasoning_chat_flag = None  # becomes True once any thinking content was received
        is_reasoning = False  # whether the model is currently emitting thoughts
        async for chunk in await gemini.aio.models.generate_content_stream(**params):
            resp = parse_chunk(chunk)
            chunk_answer = resp.get("texts", "")
            chunk_thinking = resp.get("thinking", "")
            if chunk_thinking:
                reasoning_chat_flag = True
            if chunk_thinking and not is_reasoning:  # first thinking content
                is_reasoning = True
                runtime_texts += quote(f"{EMOJI_REASONING_BEGIN}{chunk_thinking.lstrip()}")
            elif chunk_thinking and is_reasoning:  # more thinking content while already reasoning
                runtime_texts += chunk_thinking.replace("\n", f"\n{BLOCKQUOTE_DELIM}")
            elif reasoning_chat_flag is True and is_reasoning:  # Receiving response, close reasoning flag
                is_reasoning = False
                runtime_texts = chunk_answer.lstrip()  # the displayed thoughts are replaced by the answer
            else:
                runtime_texts += chunk_answer

            thoughts += chunk_thinking
            answers += chunk_answer
            if hide_thinking and is_reasoning:
                continue

            runtime_texts = beautify_llm_response(runtime_texts)
            length = await count_without_entities(prefix + runtime_texts)
            if length <= TEXT_LENGTH:
                if len(runtime_texts.removeprefix(prefix)) > 10:  # start response if answer is not empty
                    await modify_progress(message=status_msg, text=prefix + runtime_texts, detail_progress=True)
            else:  # answers is too long, split it into multiple messages
                parts = await smart_split(prefix + runtime_texts)
                if len(parts) == 1:
                    continue
                # Only the first and the last part are used below.
                if is_reasoning:
                    runtime_texts = quote(f"{EMOJI_REASONING_BEGIN}{parts[-1].lstrip()}")  # remove previous thinking
                    await modify_progress(message=status_msg, text=parts[0], force_update=True)  # force send the first part
                else:
                    await modify_progress(message=status_msg, text=blockquote(parts[0]), force_update=True)  # force send the first part
                    runtime_texts = parts[-1]  # keep the last part
                    if not silent:
                        status_msg = await client.send_message(status_cid, text=prefix + runtime_texts, reply_parameters=ReplyParameters(message_id=status_mid))  # the new message
                        sent_messages.append(status_msg)
                        status_mid = status_msg.id

        # all chunks are processed
        if not answers.strip() and not thoughts.strip():  # empty response
            # Retry on the still-open client; the recursive call closes it on its own terminal exit.
            return await single_api_generate_content(
                client,
                status_msg,
                gemini,
                params=params,
                prefix=prefix,
                silent=silent,
                retry=retry + 1,
                max_retries=max_retries,
                append_grounding=append_grounding,
                hide_thinking=hide_thinking,
                **kwargs,
            )
        if append_grounding:  # add grounding to the response
            # ``resp`` is the parsed last chunk; its grounding metadata is applied to both texts.
            answers = await add_grounding_results(answers, resp["grounding_chunks"], resp["grounding_supports"])
            runtime_texts = await add_grounding_results(runtime_texts, resp["grounding_chunks"], resp["grounding_supports"])
        if await count_without_entities(prefix + answers) <= TEXT_LENGTH - 10:  # short answer in single msg
            quoted = answers.strip()
            await modify_progress(message=status_msg, text=f"{prefix}{blockquote(quoted)}", force_update=True)
        else:  # total length is too long, answers are splitted into multiple messages
            await modify_progress(message=status_msg, text=prefix + blockquote(runtime_texts), force_update=True)

    except Exception as e:
        error = f"{e}\n{resp}"
        logger.error(error)
        with contextlib.suppress(Exception):
            await modify_progress(message=status_msg, text=error, force_update=True)
            for msg in sent_messages:
                await modify_progress(msg, del_status=True)
        if retry + 1 < max_retries:
            return await single_api_generate_content(
                client,
                status_msg,
                gemini,
                params=params,
                prefix=prefix,
                silent=silent,
                retry=retry + 1,
                max_retries=max_retries,
                append_grounding=append_grounding,
                hide_thinking=hide_thinking,
                **kwargs,
            )
        await _close_gemini_client(gemini)  # final failure: release the client instead of leaking it
        return {}
    await _close_gemini_client(gemini)
    return {"texts": answers, "thoughts": thoughts, "sent_messages": sent_messages}
230
231
def parse_chunk(chunk: types.GenerateContentResponse) -> dict:
    """Extract answer text, thinking text and grounding metadata from one streamed chunk.

    Parts of the first candidate flagged with ``thought`` are collected as thinking text,
    all other parts as answer text.

    Returns:
        dict: {"texts": str, "thinking": str, "grounding_chunks": list, "grounding_supports": list}
    """
    payload = trim_none(chunk.model_dump())
    # These two entries are dropped before the payload is logged.
    for dropped_key in ("sdk_http_response", "usage_metadata"):
        payload.pop(dropped_key, None)
    logger.trace(payload)

    answer_pieces: list = []
    thought_pieces: list = []
    for part in glom(payload, "candidates.0.content.parts", default=[]) or []:
        bucket = thought_pieces if part.get("thought") else answer_pieces
        bucket.append(part.get("text", ""))

    parsed = {
        "texts": beautify_llm_response("".join(answer_pieces), newline_level=2),
        "thinking": beautify_llm_response("".join(thought_pieces), newline_level=2),
    }
    for field in ("grounding_chunks", "grounding_supports"):
        parsed[field] = glom(payload, f"candidates.0.grounding_metadata.{field}", default=[]) or []
    return parsed
247
248
async def add_grounding_results(answers: str, grounding_chunks: list[dict], grounding_supports: list[dict]) -> str:
    """Annotate ``answers`` with inline citation links and a trailing list of cited sources.

    Args:
        answers: Text to annotate.
        grounding_chunks: Source entries; ``web.uri`` and ``web.title`` are read from each.
        grounding_supports: Entries linking a ``segment.text`` to ``grounding_chunk_indices``.

    Returns:
        str: ``answers`` with ``[[n]](url)`` markers after the first occurrence of each supported
        segment, followed by one line per source (first 10 chunks only) whose URL appears in the text.
    """
    fallback_url = "https://www.google.com"
    urls = [glom(chunk, "web.uri", default=fallback_url) for chunk in grounding_chunks]
    tasks = [flatten_rediercts(url) for url in urls]
    try:
        # return_exceptions=True keeps every successful lookup: a single failing URL
        # no longer throws away the results of all the others.
        results = await asyncio.gather(*tasks, return_exceptions=True)
    except Exception as e:
        logger.warning(e)
        results = urls
    index2url: list = []
    for url, result in zip(urls, results):
        if isinstance(result, BaseException):
            logger.warning(result)
            index2url.append(url)  # fall back to the original URL for this entry only
        else:
            index2url.append(result)
    logger.trace(f"Grounding URLs: {index2url}")
    for support in grounding_supports:
        indices: list[int] = support.get("grounding_chunk_indices", [])
        logger.trace(f"Add grounding indices: {indices}")
        indices_with_url = " ".join([f"[[{idx + 1}]]({glom(index2url, str(idx), default=fallback_url)})" for idx in indices])
        if segment := glom(support, "segment.text", default=""):
            # only the first occurrence of the segment gets the citation markers
            answers = answers.replace(segment, f"{segment}{indices_with_url}", 1)
    for idx, grounding in enumerate(grounding_chunks):
        if idx > 9:  # list at most 10 sources
            break
        title = glom(grounding, "web.title", default="Web")
        url = glom(index2url, str(idx), default=fallback_url)
        if url in answers:  # list a source only when its URL is present in the text
            answers += f"\n{number_to_emoji(idx + 1)}[{title}]({url})"
    return answers
273
274
def is_valid_response(resp: dict, schema: dict) -> bool:
    """Tell whether ``resp`` carries usable text.

    Without a schema any non-empty ``texts`` value is accepted. With a schema the text
    must additionally parse as JSON and validate against that schema.
    """
    texts = resp.get("texts")
    if not texts:
        return False
    if not schema:
        return True
    try:
        validate(instance=json.loads(texts), schema=schema)
    except (JSONDecodeError, ValidationError) as e:
        logger.error(f"Invalid JSONSchema response: {e}")
        return False
    else:
        return True
287    return True