#!/usr/bin/env python
# -*- coding: utf-8 -*-
import asyncio
import contextlib

from glom import glom
from google import genai
from google.genai import types
from loguru import logger
from pyrogram.client import Client
from pyrogram.parser.markdown import BLOCKQUOTE_DELIM
from pyrogram.types import Message, ReplyParameters

from ai.texts.contexts import get_gemini_contexts
from ai.utils import BOT_TIPS, EMOJI_REASONING_BEGIN, EMOJI_TEXT_BOT, beautify_llm_response, literal_eval, load_skills, trim_none
from config import AI, PROXY, TEXT_LENGTH
from messages.progress import modify_progress
from messages.utils import blockquote, count_without_entities, quote, smart_split
from networking import flatten_rediercts
from utils import number_to_emoji, strings_list


async def gemini_chat_completion(
    client: Client,
    message: Message,
    *,
    prefix: str = "",
    model_id: str = AI.GEMINI_MODEL_ID,
    model_name: str = AI.GEMINI_MODEL_ID,
    gemini_base_url: str = AI.GEMINI_BASE_URL,
    gemini_api_keys: str = AI.GEMINI_API_KEYS,
    gemini_default_headers: str | dict = AI.GEMINI_DEFAULT_HEADERS,
    gemini_generate_content_config: str | dict = "",
    gemini_proxy: str | None = PROXY.GOOGLE,
    gemini_append_grounding: bool = True,
    skills: str = "",
    hide_thinking: bool = False,
    silent: bool = False,
    max_retries: int = 3,
    **kwargs,
) -> dict:
 42    """Get OpenAI Chat Completions.
 43
 44    Returns:
 45            dict: {"texts": str, "thoughts": str, "prefix": str, "model_name": str, "sent_messages": list[Message]}
 46    """
    if not prefix:
        prefix = f"{EMOJI_TEXT_BOT}**{model_name}**:{BOT_TIPS}\n"

    if silent or not kwargs.get("show_progress"):  # noqa: SIM108
        status_msg = None
    else:
        status_msg = kwargs.get("progress") or await message.reply(f"{EMOJI_TEXT_BOT}**{model_name}**: Thinking...", quote=True)

    sent_messages = [status_msg]
    for api_key in strings_list(gemini_api_keys, shuffle=True):
        try:
            http_options = types.HttpOptions(base_url=gemini_base_url, async_client_args={"proxy": gemini_proxy}, headers=literal_eval(gemini_default_headers))
            gemini = genai.Client(api_key=api_key, http_options=http_options)
            params: dict = {"model": model_id, "contents": await get_gemini_contexts(client, message, gemini)}
            conf = literal_eval(gemini_generate_content_config) or {}  # evaluate the config once; it may be a dict already or a printable literal
            if skills:
                conf |= {"system_instruction": await load_skills(skills)}
            if conf:
                params["config"] = conf
            logger.debug(f"genai.Client().models.generate_content_stream(**{params})")
            resp = await single_api_generate_content(
                client,
                status_msg,
                gemini,
                params=params,
                prefix=prefix,
                silent=silent,
                max_retries=max_retries,
                append_grounding=gemini_append_grounding,
                hide_thinking=hide_thinking,
                **kwargs,
            )
            if resp.get("texts"):
                sent_messages.extend(resp.get("sent_messages", []))
                return {
                    "success": True,
                    "texts": resp["texts"],
                    "thoughts": resp["thoughts"],
                    "prefix": prefix,
                    "model_name": model_name,
                    "sent_messages": [m for m in sent_messages if isinstance(m, Message)],
                }
        except Exception as e:
            logger.error(f"Gemini API error: {e}")
            await modify_progress(status_msg, text=f"{e}", force_update=True, **kwargs)
    return {"progress": status_msg} if isinstance(status_msg, Message) else {}
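
# A hedged usage sketch (not part of this module's API surface): how a
# pyrogram handler might call gemini_chat_completion(). The command filter and
# handler name are assumptions; show_progress is the kwarg read above.
#
#     @Client.on_message(filters.command("ask"))
#     async def ask_handler(client: Client, message: Message):
#         resp = await gemini_chat_completion(client, message, show_progress=True)
#         if resp.get("success"):
#             logger.info(f"Replied with {len(resp['sent_messages'])} message(s)")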


async def single_api_generate_content(
    client: Client,
    status_msg: Message | None,
    gemini: genai.Client,
    params: dict,
    *,
    prefix: str = "",
    retry: int = 0,
    max_retries: int = 3,
    append_grounding: bool = True,
    hide_thinking: bool = False,
    silent: bool = False,
    **kwargs,
) -> dict:
108    """Get Gemini Chat Completions via single API.
109
110    Returns:
111        dict: {"texts": str, "thoughts": str, "sent_messages": list[Message]}
112    """
    if retry > max_retries:
        return {"texts": "", "thoughts": "", "sent_messages": []}
    answers = ""  # all model responses
    thoughts = ""  # all model thoughts
    runtime_texts = ""  # text accumulated for the current Telegram message
    status_cid = status_msg.chat.id if isinstance(status_msg, Message) else 0
    status_mid = status_msg.id if isinstance(status_msg, Message) else 0
    sent_messages = []
    resp = {}
    try:
        reasoning_chat_flag = None  # whether this model emits reasoning at all
        is_reasoning = False  # whether we are currently inside the reasoning stream
        async for chunk in await gemini.aio.models.generate_content_stream(**params):
            resp = parse_chunk(chunk)
            chunk_answer = resp.get("texts", "")
            chunk_thinking = resp.get("thinking", "")
            if chunk_thinking:
                reasoning_chat_flag = True
            if chunk_thinking and not is_reasoning:  # first reasoning chunk: open a blockquote
                is_reasoning = True
                runtime_texts += quote(f"{EMOJI_REASONING_BEGIN}{chunk_thinking.lstrip()}")
            elif chunk_thinking and is_reasoning:  # still reasoning: continue the blockquote
                runtime_texts += chunk_thinking.replace("\n", f"\n{BLOCKQUOTE_DELIM}")
            elif reasoning_chat_flag is True and is_reasoning:  # answer begins: close the reasoning block and drop it from display
                is_reasoning = False
                runtime_texts = chunk_answer.lstrip()
            else:
                runtime_texts += chunk_answer

            thoughts += chunk_thinking
            answers += chunk_answer
            if hide_thinking and is_reasoning:
                continue

            runtime_texts = beautify_llm_response(runtime_texts)
            length = await count_without_entities(prefix + runtime_texts)
            if length <= TEXT_LENGTH:
                if len(runtime_texts.removeprefix(prefix)) > 10:  # start editing once a non-trivial amount of text has arrived
                    await modify_progress(message=status_msg, text=prefix + runtime_texts, detail_progress=True)
            else:  # the running text is too long; split it across multiple messages
                parts = await smart_split(prefix + runtime_texts)
                if len(parts) == 1:
                    continue
                if is_reasoning:
                    runtime_texts = quote(f"{EMOJI_REASONING_BEGIN}{parts[-1].lstrip()}")  # keep only the latest thinking in the running message
                    await modify_progress(message=status_msg, text=parts[0], force_update=True)  # force send the first part
                else:
                    await modify_progress(message=status_msg, text=blockquote(parts[0]), force_update=True)  # force send the first part
                    runtime_texts = parts[-1]  # keep the last part
                    if not silent:
                        status_msg = await client.send_message(status_cid, text=prefix + runtime_texts, reply_parameters=ReplyParameters(message_id=status_mid))  # continue in a new message
                        sent_messages.append(status_msg)
                        status_mid = status_msg.id

        # all chunks are processed
        if not answers.strip() and not thoughts.strip():  # empty response: retry before closing the client
            return await single_api_generate_content(
                client,
                status_msg,
                gemini,
                params=params,
                prefix=prefix,
                silent=silent,
                retry=retry + 1,
                max_retries=max_retries,
                append_grounding=append_grounding,
                hide_thinking=hide_thinking,
                **kwargs,
            )
        await gemini.aio.aclose()  # streaming finished with no retry needed; release the HTTP client
        if append_grounding:  # add grounding citations to the response
            answers = await add_grounding_results(answers, resp["grounding_chunks"], resp["grounding_supports"])
            runtime_texts = await add_grounding_results(runtime_texts, resp["grounding_chunks"], resp["grounding_supports"])
        if await count_without_entities(prefix + answers) <= TEXT_LENGTH - 10:  # a short answer fits in a single message
            quoted = answers.strip()
            await modify_progress(message=status_msg, text=f"{prefix}{blockquote(quoted)}", force_update=True)
        else:  # the full answer was split across multiple messages; show the final part
            await modify_progress(message=status_msg, text=prefix + blockquote(runtime_texts), force_update=True)

    except Exception as e:
        error = f"{e}\n{resp}"
        logger.error(error)
        with contextlib.suppress(Exception):
            await modify_progress(message=status_msg, text=error, force_update=True)
            for msg in sent_messages:
                await modify_progress(msg, del_status=True)
        if retry + 1 < max_retries:
            return await single_api_generate_content(
                client,
                status_msg,
                gemini,
                params=params,
                prefix=prefix,
                silent=silent,
                retry=retry + 1,
                max_retries=max_retries,
                append_grounding=append_grounding,
                hide_thinking=hide_thinking,
                **kwargs,
            )
    return {"texts": answers, "thoughts": thoughts, "sent_messages": sent_messages}
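
# A rough sketch of the streaming render above, assuming quote() prefixes each
# line with pyrogram's BLOCKQUOTE_DELIM and EMOJI_REASONING_BEGIN is a
# "thinking" marker: the first thinking chunk opens a blockquote, later
# thinking chunks extend it line by line,
#
#     >🤔 Comparing both options...
#     >the second one is cheaper, so...
#
# and the first answer chunk resets runtime_texts so the live message switches
# from the quoted reasoning to the plain answer text.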


def parse_chunk(chunk: types.GenerateContentResponse) -> dict:
    """Parse a Gemini streaming chunk into answer text, thinking text and web-search grounding metadata."""
    data = trim_none(chunk.model_dump())
    data.pop("sdk_http_response", None)
    data.pop("usage_metadata", None)
    logger.trace(data)
    parts = glom(data, "candidates.0.content.parts", default=[]) or []
    texts = "".join([p.get("text", "") for p in parts if not p.get("thought")])
    thinking = "".join([p.get("text", "") for p in parts if p.get("thought")])
    return {
        "texts": beautify_llm_response(texts, newline_level=2),
        "thinking": beautify_llm_response(thinking, newline_level=2),
        "grounding_chunks": glom(data, "candidates.0.grounding_metadata.grounding_chunks", default=[]) or [],
        "grounding_supports": glom(data, "candidates.0.grounding_metadata.grounding_supports", default=[]) or [],
    }
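
# A minimal sketch of what parse_chunk() extracts, assuming a chunk whose first
# candidate carries one thought part and one answer part (the dict mirrors the
# google-genai pydantic models dumped above; output shown modulo whatever
# beautify_llm_response() normalizes):
#
#     chunk = types.GenerateContentResponse.model_validate(
#         {"candidates": [{"content": {"parts": [
#             {"text": "Let me check the docs. ", "thought": True},
#             {"text": "The limit is 4096 characters."},
#         ]}}]}
#     )
#     parse_chunk(chunk)
#     # -> {"texts": "The limit is 4096 characters.", "thinking": "Let me check the docs.",
#     #     "grounding_chunks": [], "grounding_supports": []}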


async def add_grounding_results(answers: str, grounding_chunks: list[dict], grounding_supports: list[dict]) -> str:
    """Insert inline citation links after each grounded segment and append a numbered source list."""
    urls = [glom(chunk, "web.uri", default="https://www.google.com") for chunk in grounding_chunks]
    tasks = [flatten_rediercts(url) for url in urls]
    try:
        index2url = await asyncio.gather(*tasks)  # resolve redirect URLs concurrently
    except Exception as e:
        logger.warning(e)
        index2url = urls  # fall back to the unresolved URLs
    logger.trace(f"Grounding URLs: {index2url}")
    for support in grounding_supports:
        indices: list[int] = support.get("grounding_chunk_indices", [])
        logger.trace(f"Add grounding indices: {indices}")
        indices_with_url = " ".join([f"[[{idx + 1}]]({glom(index2url, str(idx), default='https://www.google.com')})" for idx in indices])
        if segment := glom(support, "segment.text", default=""):
            answers = answers.replace(segment, f"{segment}{indices_with_url}", 1)
    for idx, grounding in enumerate(grounding_chunks):
        if idx > 9:  # cap the source list at ten entries
            break
        title = glom(grounding, "web.title", default="Web")
        url = glom(index2url, str(idx), default="https://www.google.com")
        if url in answers:  # only list sources that are actually cited inline
            answers += f"\n{number_to_emoji(idx + 1)}[{title}]({url})"
    return answers
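
# A hedged usage sketch for add_grounding_results(); the payloads mimic the
# grounding_metadata shape parsed in parse_chunk(), flatten_rediercts() is
# assumed to return each URL unchanged here, and number_to_emoji(1) is assumed
# to render 1️⃣:
#
#     answers = "Paris is the capital of France."
#     chunks = [{"web": {"uri": "https://example.com", "title": "Example"}}]
#     supports = [{"grounding_chunk_indices": [0], "segment": {"text": "capital of France"}}]
#     await add_grounding_results(answers, chunks, supports)
#     # -> "Paris is the capital of France[[1]](https://example.com)."
#     #    "\n1️⃣[Example](https://example.com)"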