#!/usr/bin/env python
# -*- coding: utf-8 -*-
import asyncio
import contextlib

from glom import glom
from google import genai
from google.genai import types
from loguru import logger
from pyrogram.client import Client
from pyrogram.parser.markdown import BLOCKQUOTE_DELIM
from pyrogram.types import Message, ReplyParameters

from ai.texts.contexts import get_gemini_contexts
from ai.utils import BOT_TIPS, EMOJI_REASONING_BEGIN, EMOJI_TEXT_BOT, beautify_llm_response, literal_eval, load_skills, trim_none
from config import AI, PROXY, TEXT_LENGTH
from messages.progress import modify_progress
from messages.utils import blockquote, count_without_entities, quote, smart_split
from networking import flatten_rediercts
from utils import number_to_emoji, strings_list


async def gemini_chat_completion(
    client: Client,
    message: Message,
    *,
    prefix: str = "",
    model_id: str = AI.GEMINI_MODEL_ID,
    model_name: str = AI.GEMINI_MODEL_ID,
    gemini_base_url: str = AI.GEMINI_BASE_URL,
    gemini_api_keys: str = AI.GEMINI_API_KEYS,
    gemini_default_headers: str | dict = AI.GEMINI_DEFAULT_HEADERS,
    gemini_generate_content_config: str | dict = "",
    gemini_proxy: str | None = PROXY.GOOGLE,
    gemini_append_grounding: bool = True,
    skills: str = "",
    hide_thinking: bool = False,
    silent: bool = False,
    max_retries: int = 3,
    **kwargs,
) -> dict:
    """Get Gemini chat completions, rotating through the configured API keys.

    Returns:
        dict: {"success": True, "texts": str, "thoughts": str, "prefix": str, "model_name": str, "sent_messages": list[Message]} on success;
            {"progress": Message} or {} when every API key fails.
    """
    if not prefix:
        prefix = f"{EMOJI_TEXT_BOT}**{model_name}**:{BOT_TIPS}\n"

    if silent or not kwargs.get("show_progress"):  # noqa: SIM108
        status_msg = None
    else:
        status_msg = kwargs.get("progress") or await message.reply(f"{EMOJI_TEXT_BOT}**{model_name}**: Thinking...", quote=True)

    sent_messages = [status_msg]
    # Try each API key in random order until one yields a non-empty response.
    for api_key in strings_list(gemini_api_keys, shuffle=True):
        try:
            http_options = types.HttpOptions(base_url=gemini_base_url, async_client_args={"proxy": gemini_proxy}, headers=literal_eval(gemini_default_headers))
            gemini = genai.Client(api_key=api_key, http_options=http_options)
            params: dict = {"model": model_id, "contents": await get_gemini_contexts(client, message, gemini)}
            if skills:
                gemini_generate_content_config = literal_eval(gemini_generate_content_config) | {"system_instruction": await load_skills(skills)}
            if conf := literal_eval(gemini_generate_content_config):
                params["config"] = conf
            logger.debug(f"genai.Client().models.generate_content_stream(**{params})")
            resp = await single_api_generate_content(
                client,
                status_msg,
                gemini,
                params=params,
                prefix=prefix,
                silent=silent,
                max_retries=max_retries,
                append_grounding=gemini_append_grounding,
                hide_thinking=hide_thinking,
                **kwargs,
            )
            if resp.get("texts"):
                sent_messages.extend(resp.get("sent_messages", []))
                return {
                    "success": True,
                    "texts": resp["texts"],
                    "thoughts": resp["thoughts"],
                    "prefix": prefix,
                    "model_name": model_name,
                    "sent_messages": [m for m in sent_messages if isinstance(m, Message)],
                }
        except Exception as e:
            logger.error(f"Gemini API error: {e}")
            await modify_progress(status_msg, text=f"❌{e}", force_update=True, **kwargs)
    # Every key failed: hand the status message back so the caller can reuse it.
    return {"progress": status_msg} if isinstance(status_msg, Message) else {}


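# A minimal usage sketch (hypothetical handler: `app`, `filters`, and the
# command name are assumptions, not part of this module):
#
#     @app.on_message(filters.command("gemini"))
#     async def on_gemini(client: Client, message: Message):
#         resp = await gemini_chat_completion(client, message, show_progress=True)
#         if resp.get("success"):
#             logger.info(f"Sent {len(resp['sent_messages'])} message(s)")

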
async def single_api_generate_content(
    client: Client,
    status_msg: Message | None,
    gemini: genai.Client,
    params: dict,
    *,
    prefix: str = "",
    retry: int = 0,
    max_retries: int = 3,
    append_grounding: bool = True,
    hide_thinking: bool = False,
    silent: bool = False,
    **kwargs,
) -> dict:
    """Get Gemini chat completions via a single API key, streaming into Telegram messages.

    Returns:
        dict: {"texts": str, "thoughts": str, "sent_messages": list[Message]}
    """
    if retry > max_retries:
        return {"texts": "", "thoughts": "", "sent_messages": []}
    answers = ""  # all model responses
    thoughts = ""  # all model thoughts
    runtime_texts = ""  # text shown in the current Telegram message
    status_cid = status_msg.chat.id if isinstance(status_msg, Message) else 0
    status_mid = status_msg.id if isinstance(status_msg, Message) else 0
    sent_messages = []
    resp = {}
    try:
        reasoning_chat_flag = None  # whether this model streams reasoning at all
        is_reasoning = False  # whether we are currently inside the reasoning block
        async for chunk in await gemini.aio.models.generate_content_stream(**params):
            resp = parse_chunk(chunk)
            chunk_answer = resp.get("texts", "")
            chunk_thinking = resp.get("thinking", "")
            if chunk_thinking:
                reasoning_chat_flag = True
            if chunk_thinking and not is_reasoning:  # first reasoning chunk received
                is_reasoning = True
                runtime_texts += quote(f"{EMOJI_REASONING_BEGIN}{chunk_thinking.lstrip()}")
            elif chunk_thinking and is_reasoning:  # still inside the reasoning block
                runtime_texts += chunk_thinking.replace("\n", f"\n{BLOCKQUOTE_DELIM}")
            elif reasoning_chat_flag is True and is_reasoning:  # answer has started; close the reasoning block
                is_reasoning = False
                runtime_texts = chunk_answer.lstrip()
            else:
                runtime_texts += chunk_answer

            thoughts += chunk_thinking
            answers += chunk_answer
            if hide_thinking and is_reasoning:
                continue

            runtime_texts = beautify_llm_response(runtime_texts)
            length = await count_without_entities(prefix + runtime_texts)
            if length <= TEXT_LENGTH:
                if len(runtime_texts.removeprefix(prefix)) > 10:  # start editing once the answer is non-trivial
                    await modify_progress(message=status_msg, text=prefix + runtime_texts, detail_progress=True)
            else:  # answer is too long; split it across multiple messages
                parts = await smart_split(prefix + runtime_texts)
                if len(parts) == 1:
                    continue
                if is_reasoning:
                    runtime_texts = quote(f"{EMOJI_REASONING_BEGIN}{parts[-1].lstrip()}")  # drop the previous thinking
                    await modify_progress(message=status_msg, text=parts[0], force_update=True)  # force send the first part
                else:
                    await modify_progress(message=status_msg, text=blockquote(parts[0]), force_update=True)  # force send the first part
                    runtime_texts = parts[-1]  # keep the last part
                if not silent:
                    status_msg = await client.send_message(status_cid, text=prefix + runtime_texts, reply_parameters=ReplyParameters(message_id=status_mid))  # continue in a new message
                    sent_messages.append(status_msg)
                    status_mid = status_msg.id

        await gemini.aio.aclose()
        # All chunks are processed.
        if not answers.strip() and not thoughts.strip():  # empty response; retry
            return await single_api_generate_content(
                client,
                status_msg,
                gemini,
                params=params,
                prefix=prefix,
                silent=silent,
                retry=retry + 1,
                max_retries=max_retries,
                append_grounding=append_grounding,
                hide_thinking=hide_thinking,
                **kwargs,
            )
        if append_grounding:  # add web-search citations to the response
            answers = await add_grounding_results(answers, resp["grounding_chunks"], resp["grounding_supports"])
            runtime_texts = await add_grounding_results(runtime_texts, resp["grounding_chunks"], resp["grounding_supports"])
        if await count_without_entities(prefix + answers) <= TEXT_LENGTH - 10:  # short answer fits in a single message
            quoted = answers.strip()
            await modify_progress(message=status_msg, text=f"{prefix}{blockquote(quoted)}", force_update=True)
        else:  # total length is too long; the answer was split across multiple messages
            await modify_progress(message=status_msg, text=prefix + blockquote(runtime_texts), force_update=True)

    except Exception as e:
        error = f"{e}\n{resp}"
        logger.error(error)
        with contextlib.suppress(Exception):
            await modify_progress(message=status_msg, text=error, force_update=True)
            for msg in sent_messages:
                await modify_progress(msg, del_status=True)
        if retry + 1 < max_retries:
            return await single_api_generate_content(
                client,
                status_msg,
                gemini,
                params=params,
                prefix=prefix,
                silent=silent,
                retry=retry + 1,
                max_retries=max_retries,
                append_grounding=append_grounding,
                hide_thinking=hide_thinking,
                **kwargs,
            )
    return {"texts": answers, "thoughts": thoughts, "sent_messages": sent_messages}


def parse_chunk(chunk: types.GenerateContentResponse) -> dict:
    """Parse a Gemini stream chunk into answer text, reasoning, and web-search grounding."""
    data = trim_none(chunk.model_dump())
    data.pop("sdk_http_response", None)
    data.pop("usage_metadata", None)
    logger.trace(data)
    parts = glom(data, "candidates.0.content.parts", default=[]) or []
    # Parts flagged with "thought" are reasoning; the rest are the answer.
    texts = "".join([p.get("text", "") for p in parts if not p.get("thought")])
    thinking = "".join([p.get("text", "") for p in parts if p.get("thought")])
    return {
        "texts": beautify_llm_response(texts, newline_level=2),
        "thinking": beautify_llm_response(thinking, newline_level=2),
        "grounding_chunks": glom(data, "candidates.0.grounding_metadata.grounding_chunks", default=[]) or [],
        "grounding_supports": glom(data, "candidates.0.grounding_metadata.grounding_supports", default=[]) or [],
    }


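# Shape sketch of the dumped chunk parse_chunk reads (abridged; the glom paths
# above name the real fields, the values here are illustrative):
#
#     {"candidates": [{
#         "content": {"parts": [
#             {"text": "Let me check...", "thought": True},
#             {"text": "The answer is 42."},
#         ]},
#         "grounding_metadata": {"grounding_chunks": [...], "grounding_supports": [...]},
#     }]}

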
async def add_grounding_results(answers: str, grounding_chunks: list[dict], grounding_supports: list[dict]) -> str:
    """Append inline citation links and a numbered source list from Gemini grounding metadata."""
    urls = [glom(chunk, "web.uri", default="https://www.google.com") for chunk in grounding_chunks]
    # Resolve Google's redirect wrappers; fall back to the raw URLs on failure.
    tasks = [flatten_rediercts(url) for url in urls]
    try:
        index2url = await asyncio.gather(*tasks)
    except Exception as e:
        logger.warning(e)
        index2url = urls
    logger.trace(f"Grounding URLs: {index2url}")
    for support in grounding_supports:
        indices: list[int] = support.get("grounding_chunk_indices", [])
        logger.trace(f"Add grounding indices: {indices}")
        indices_with_url = " ".join([f"[[{idx + 1}]]({glom(index2url, str(idx), default='https://www.google.com')})" for idx in indices])
        if segment := glom(support, "segment.text", default=""):
            answers = answers.replace(segment, f"{segment}{indices_with_url}", 1)
    for idx, grounding in enumerate(grounding_chunks):
        if idx > 9:  # number_to_emoji covers at most ten footnotes
            break
        title = glom(grounding, "web.title", default="Web")
        url = glom(index2url, str(idx), default="https://www.google.com")
        if url in answers:  # only list sources that were actually cited inline
            answers += f"\n{number_to_emoji(idx + 1)}[{title}]({url})"
    return answers
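

# Data-shape sketch for the two metadata arguments (illustrative values, not a
# recorded API response):
#
#     supports = [{"segment": {"text": "Rust 1.0 shipped in 2015."}, "grounding_chunk_indices": [0]}]
#     chunks = [{"web": {"uri": "https://blog.rust-lang.org/...", "title": "Rust Blog"}}]
#
# add_grounding_results would rewrite the cited sentence as
# "Rust 1.0 shipped in 2015.[[1]](https://blog.rust-lang.org/...)" and, since
# that URL now appears in the answer, append a source line built from
# number_to_emoji(1) and "[Rust Blog](https://blog.rust-lang.org/...)".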