main
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3import asyncio
4import contextlib
5import json
6from json import JSONDecodeError
7
8from glom import glom
9from google import genai
10from google.genai import types
11from jsonschema import ValidationError, validate
12from loguru import logger
13from pyrogram.client import Client
14from pyrogram.parser.markdown import BLOCKQUOTE_DELIM
15from pyrogram.types import Message, ReplyParameters
16
17from ai.texts.contexts import get_gemini_contexts
18from ai.utils import BOT_TIPS, EMOJI_REASONING_BEGIN, EMOJI_TEXT_BOT, beautify_llm_response, literal_eval, load_skills, trim_none
19from config import AI, PROXY, TEXT_LENGTH
20from messages.progress import modify_progress
21from messages.utils import blockquote, count_without_entities, quote, smart_split
22from networking import flatten_rediercts
23from utils import number_to_emoji, strings_list
24
25
async def gemini_chat_completion(
    client: Client,
    message: Message,
    *,
    prefix: str = "",
    model_id: str = AI.GEMINI_MODEL_ID,
    model_name: str = AI.GEMINI_MODEL_ID,
    gemini_base_url: str = AI.GEMINI_BASE_URL,
    gemini_api_keys: str = AI.GEMINI_API_KEYS,
    gemini_default_headers: str | dict = AI.GEMINI_DEFAULT_HEADERS,
    gemini_generate_content_config: str | dict = "",
    gemini_proxy: str | None = PROXY.GOOGLE,
    gemini_append_grounding: bool = True,
    additional_contexts: list[dict] | None = None,  # additional contexts to append to the contexts
    skills: str = "",
    hide_thinking: bool = False,
    add_sender: bool | None = None,
    silent: bool = False,
    max_retries: int = 3,
    **kwargs,
) -> dict:
    """Get Gemini chat completions, trying each configured API key in turn.

    Keys from ``gemini_api_keys`` are tried in shuffled order. The first key whose
    response has non-empty texts and passes ``is_valid_response`` wins. A key that
    raises, returns no texts, or fails validation is skipped.

    Args:
        client: Pyrogram client, forwarded to the context builder and the streaming helper.
        message: Message the completion is generated for; also replied to when a progress message is created.
        prefix: Text placed before the answer. Built from ``model_name`` when empty.
        model_id: Model identifier sent to the API.
        model_name: Display name used in the default prefix and in the progress message.
        gemini_base_url: Base URL for the Gemini HTTP client.
        gemini_api_keys: API keys, parsed by ``strings_list``.
        gemini_default_headers: Extra HTTP headers (dict or its string form).
        gemini_generate_content_config: Generation config (dict or its string form), sent as ``config``.
        gemini_proxy: Proxy passed to the async HTTP client.
        gemini_append_grounding: Forwarded to the streaming helper as ``append_grounding``.
        additional_contexts: Extra contexts forwarded to ``get_gemini_contexts``.
        skills: When set, loaded via ``load_skills`` and used as ``system_instruction``
            (replacing any ``system_instruction`` already present in the config).
        hide_thinking: Forwarded to the streaming helper.
        add_sender: Forwarded to ``get_gemini_contexts``.
        silent: When true, no progress message is created.
        max_retries: Forwarded to the streaming helper.
        **kwargs: ``show_progress`` and ``progress`` are read here; all kwargs are also
            forwarded to the streaming helper and to ``modify_progress``.

    Returns:
        dict: On success ``{"success": True, "texts": str, "thoughts": str, "prefix": str,
        "model_name": str, "sent_messages": list[Message]}``. When every key failed:
        ``{"progress": status_msg}`` if a progress message exists, otherwise ``{}``.
    """
    if not prefix:
        prefix = f"{EMOJI_TEXT_BOT}**{model_name}**:{BOT_TIPS}\n"

    status_msg = None
    if not silent and kwargs.get("show_progress"):
        # Reuse the caller's progress message when given, otherwise create one.
        status_msg = kwargs.get("progress") or await message.reply(f"{EMOJI_TEXT_BOT}**{model_name}**: 思考中...", quote=True)

    sent_messages = [status_msg]
    for api_key in strings_list(gemini_api_keys, shuffle=True):
        try:
            http_options = types.HttpOptions(base_url=gemini_base_url, async_client_args={"proxy": gemini_proxy}, headers=literal_eval(gemini_default_headers))
            gemini = genai.Client(api_key=api_key, http_options=http_options)
            params: dict = {
                "model": model_id,
                "contents": await get_gemini_contexts(
                    client,
                    message,
                    gemini,
                    add_sender=add_sender,
                    additional_contexts=additional_contexts,
                ),
            }
            # Build the config in a local so the function argument is not rebound
            # on every key. `or {}` keeps the merge valid when the config is empty.
            config = literal_eval(gemini_generate_content_config)
            if skills:
                config = (config or {}) | {"system_instruction": await load_skills(skills)}
            if config:
                params["config"] = config
            logger.debug(f"genai.Client().models.generate_content_stream(**{params})")
            resp = await single_api_generate_content(
                client,
                status_msg,
                gemini,
                params=params,
                prefix=prefix,
                silent=silent,
                max_retries=max_retries,
                append_grounding=gemini_append_grounding,
                hide_thinking=hide_thinking,
                **kwargs,
            )
            if not resp.get("texts"):
                continue  # nothing usable from this key, try the next one
            # The SDK accepts both camelCase and snake_case config keys, so look up both spellings.
            schema = glom(params, "config.responseJsonSchema", default=None) or glom(params, "config.response_json_schema", default={})
            if not is_valid_response(resp, schema):
                continue
            sent_messages.extend(resp.get("sent_messages", []))
            return {
                "success": True,
                "texts": resp["texts"],
                "thoughts": resp["thoughts"],
                "prefix": prefix,
                "model_name": model_name,
                "sent_messages": [m for m in sent_messages if isinstance(m, Message)],
            }
        except Exception as e:
            # Best-effort boundary: report the failure and fall through to the next key.
            logger.error(f"Gemini API error: {e}")
            await modify_progress(status_msg, text=f"❌{e}", force_update=True, **kwargs)
    return {"progress": status_msg} if isinstance(status_msg, Message) else {}
108
109
async def single_api_generate_content(
    client: Client,
    status_msg: Message | None,
    gemini: genai.Client,
    params: dict,
    *,
    prefix: str = "",
    retry: int = 0,
    max_retries: int = 3,
    append_grounding: bool = True,
    hide_thinking: bool = False,
    silent: bool = False,
    **kwargs,
) -> dict:
    """Stream one Gemini completion through a single client and mirror it into Telegram.

    Chunks are parsed with ``parse_chunk`` and accumulated. The text shown in the
    status message is updated as chunks arrive. Once it no longer fits in
    ``TEXT_LENGTH`` it is split and a follow-up message is sent. An empty response
    or an exception triggers a recursive retry with ``retry + 1``.

    Args:
        client: Pyrogram client used to send follow-up messages.
        status_msg: Message that receives the streamed text; may be ``None``.
        gemini: Gemini client to stream from. Its async client is closed here once
            a non-empty response was received or empty-response retries ran out.
        params: Keyword arguments for ``generate_content_stream``.
        prefix: Text placed before the answer in every message.
        retry: Current attempt index (0 for the first attempt).
        max_retries: Retry bound. A raising attempt is retried while
            ``retry + 1 < max_retries``; an empty response is retried until
            ``retry > max_retries``.
        append_grounding: Append grounding citations to the final text.
        hide_thinking: Do not display reasoning chunks while they are streaming.
        silent: Do not send follow-up messages.
        **kwargs: Forwarded unchanged to retries.

    Returns:
        dict: ``{"texts": str, "thoughts": str, "sent_messages": list[Message]}``.
        ``texts`` is empty when empty-response retries are exhausted. ``{}`` is
        returned when the last allowed attempt raised.
    """
    if retry > max_retries:
        # Only the empty-response retry gets here, and it defers closing the client.
        with contextlib.suppress(Exception):
            await gemini.aio.aclose()
        return {"texts": "", "thoughts": "", "sent_messages": []}

    # A retry regenerates the whole answer, so it must start from the message this
    # attempt started with, not from an overflow message that is cleaned up on error.
    first_status_msg = status_msg

    async def retry_from_scratch() -> dict:
        """Run the next attempt with the same settings."""
        return await single_api_generate_content(
            client,
            first_status_msg,
            gemini,
            params=params,
            prefix=prefix,
            silent=silent,
            retry=retry + 1,
            max_retries=max_retries,
            append_grounding=append_grounding,
            hide_thinking=hide_thinking,
            **kwargs,
        )

    answers = ""  # all model responses
    thoughts = ""  # all model thoughts
    runtime_texts = ""  # for a single telegram message
    status_cid = status_msg.chat.id if isinstance(status_msg, Message) else 0
    status_mid = status_msg.id if isinstance(status_msg, Message) else 0
    sent_messages: list[Message] = []
    resp: dict = {}
    try:
        reasoning_chat_flag = None  # set once the model has emitted any reasoning
        is_reasoning = False  # True while reasoning chunks are still arriving
        async for chunk in await gemini.aio.models.generate_content_stream(**params):
            resp = parse_chunk(chunk)
            chunk_answer = resp.get("texts", "")
            chunk_thinking = resp.get("thinking", "")
            if chunk_thinking:
                reasoning_chat_flag = True
            if chunk_thinking and not is_reasoning:  # first reasoning chunk
                is_reasoning = True
                runtime_texts += quote(f"{EMOJI_REASONING_BEGIN}{chunk_thinking.lstrip()}")
            elif chunk_thinking and is_reasoning:  # reasoning continues
                runtime_texts += chunk_thinking.replace("\n", f"\n{BLOCKQUOTE_DELIM}")
            elif reasoning_chat_flag is True and is_reasoning:  # first answer chunk: drop the displayed reasoning
                is_reasoning = False
                runtime_texts = chunk_answer.lstrip()
            else:
                runtime_texts += chunk_answer

            thoughts += chunk_thinking
            answers += chunk_answer
            if hide_thinking and is_reasoning:
                continue

            runtime_texts = beautify_llm_response(runtime_texts)
            length = await count_without_entities(prefix + runtime_texts)
            if length <= TEXT_LENGTH:
                if len(runtime_texts.removeprefix(prefix)) > 10:  # start response if answer is not empty
                    await modify_progress(message=status_msg, text=prefix + runtime_texts, detail_progress=True)
            else:  # too long for one message: split it
                parts = await smart_split(prefix + runtime_texts)
                if len(parts) == 1:
                    continue
                if is_reasoning:
                    runtime_texts = quote(f"{EMOJI_REASONING_BEGIN}{parts[-1].lstrip()}")  # remove previous thinking
                    await modify_progress(message=status_msg, text=parts[0], force_update=True)  # force send the first part
                else:
                    await modify_progress(message=status_msg, text=blockquote(parts[0]), force_update=True)  # force send the first part
                    runtime_texts = parts[-1]  # keep the last part
                # Without a status message there is no chat to send the follow-up to.
                if not silent and status_cid:
                    status_msg = await client.send_message(status_cid, text=prefix + runtime_texts, reply_parameters=ReplyParameters(message_id=status_mid))  # the new message
                    sent_messages.append(status_msg)
                    status_mid = status_msg.id

        # all chunks are processed
        if not answers.strip() and not thoughts.strip():  # empty response
            # Retry BEFORE closing: the retry streams through the same client.
            return await retry_from_scratch()
        await gemini.aio.aclose()
        if append_grounding:  # add grounding to the response
            answers = await add_grounding_results(answers, resp["grounding_chunks"], resp["grounding_supports"])
        if await count_without_entities(prefix + answers) <= TEXT_LENGTH - 10:  # short answer in single msg
            quoted = answers.strip()
            await modify_progress(message=status_msg, text=f"{prefix}{blockquote(quoted)}", force_update=True)
        else:  # total length is too long, answers are split into multiple messages
            if append_grounding:
                # Only this branch shows runtime_texts, so only resolve its citations here.
                runtime_texts = await add_grounding_results(runtime_texts, resp["grounding_chunks"], resp["grounding_supports"])
            await modify_progress(message=status_msg, text=prefix + blockquote(runtime_texts), force_update=True)

    except Exception as e:
        error = f"{e}\n{resp}"
        logger.error(error)
        # Reporting and cleanup are best-effort; a failure here must not mask the retry.
        with contextlib.suppress(Exception):
            await modify_progress(message=status_msg, text=error, force_update=True)
        for overflow_msg in sent_messages:
            # NOTE(review): del_status=True presumably removes the message - confirm in modify_progress.
            with contextlib.suppress(Exception):
                await modify_progress(overflow_msg, del_status=True)
        if retry + 1 < max_retries:
            return await retry_from_scratch()
        return {}
    return {"texts": answers, "thoughts": thoughts, "sent_messages": sent_messages}
230
231
def parse_chunk(chunk: types.GenerateContentResponse) -> dict:
    """Extract answer text, reasoning text and grounding metadata from one streamed chunk.

    Only the first candidate is read. Parts flagged ``thought`` contribute to
    ``thinking``; every other part contributes to ``texts``.

    Returns:
        dict: ``{"texts": str, "thinking": str, "grounding_chunks": list, "grounding_supports": list}``
    """
    payload = trim_none(chunk.model_dump())
    # Drop the transport/usage sections before logging the payload.
    for dropped_key in ("sdk_http_response", "usage_metadata"):
        payload.pop(dropped_key, None)
    logger.trace(payload)

    answer_pieces = []
    thought_pieces = []
    for part in glom(payload, "candidates.0.content.parts", default=[]) or []:
        target = thought_pieces if part.get("thought") else answer_pieces
        target.append(part.get("text", ""))

    result = {}
    result["texts"] = beautify_llm_response("".join(answer_pieces), newline_level=2)
    result["thinking"] = beautify_llm_response("".join(thought_pieces), newline_level=2)
    for field in ("grounding_chunks", "grounding_supports"):
        result[field] = glom(payload, f"candidates.0.grounding_metadata.{field}", default=[]) or []
    return result
247
248
async def add_grounding_results(answers: str, grounding_chunks: list[dict], grounding_supports: list[dict]) -> str:
    """Attach web-grounding citations to *answers*.

    Each chunk's ``web.uri`` is passed through ``flatten_rediercts``. If a lookup
    fails, that chunk keeps its original URL and the other lookups are unaffected.
    For every support whose ``segment.text`` occurs in the answer, ``[[n]](url)``
    markers are inserted after its first occurrence. Finally a source line is
    appended for each of the first ten chunks whose URL appears in the text.

    Args:
        answers: Text to annotate.
        grounding_chunks: Grounding chunks; ``web.uri`` and ``web.title`` are read.
        grounding_supports: Grounding supports; ``grounding_chunk_indices`` and ``segment.text`` are read.

    Returns:
        str: The annotated text, or *answers* unchanged when there is nothing to cite.
    """
    if not grounding_chunks and not grounding_supports:
        return answers  # nothing to cite

    fallback_url = "https://www.google.com"
    urls = [glom(chunk, "web.uri", default=fallback_url) for chunk in grounding_chunks]
    # return_exceptions=True: one failed lookup must not discard the other resolved URLs.
    resolved = await asyncio.gather(*(flatten_rediercts(url) for url in urls), return_exceptions=True)
    index2url = []
    for url, outcome in zip(urls, resolved):
        if isinstance(outcome, BaseException):
            logger.warning(outcome)
            index2url.append(url)
        else:
            index2url.append(outcome)
    logger.trace(f"Grounding URLs: {index2url}")

    for support in grounding_supports:
        indices: list[int] = support.get("grounding_chunk_indices", [])
        logger.trace(f"Add grounding indices: {indices}")
        indices_with_url = " ".join(f"[[{idx + 1}]]({glom(index2url, str(idx), default=fallback_url)})" for idx in indices)
        if segment := glom(support, "segment.text", default=""):
            answers = answers.replace(segment, f"{segment}{indices_with_url}", 1)

    # List at most ten sources, and only those actually cited in the text.
    for idx, grounding in enumerate(grounding_chunks[:10]):
        title = glom(grounding, "web.title", default="Web")
        url = glom(index2url, str(idx), default=fallback_url)
        if url in answers:
            answers += f"\n{number_to_emoji(idx + 1)}[{title}]({url})"
    return answers
273
274
def is_valid_response(resp: dict, schema: dict) -> bool:
    """Tell whether *resp* carries a usable answer.

    A response without texts is never valid. When *schema* is given, the texts
    must also parse as JSON and validate against it.
    """
    texts = resp.get("texts")
    if not texts:
        return False
    if not schema:
        return True
    try:
        validate(instance=json.loads(texts), schema=schema)
    except (JSONDecodeError, ValidationError) as err:
        logger.error(f"Invalid JSONSchema response: {err}")
        return False
    return True
287 return True