Commit 45fb9de
Changed files (4)
src/ai/texts/gemini.py
@@ -2,10 +2,13 @@
# -*- coding: utf-8 -*-
import asyncio
import contextlib
+import json
+from json import JSONDecodeError
from glom import glom
from google import genai
from google.genai import types
+from jsonschema import ValidationError, validate
from loguru import logger
from pyrogram.client import Client
from pyrogram.parser.markdown import BLOCKQUOTE_DELIM
@@ -76,6 +79,8 @@ async def gemini_chat_completion(
**kwargs,
)
if resp.get("texts"):
+ if not is_valid_response(resp, glom(params, "config.responseJsonSchema", default={})):
+ continue
sent_messages.extend(resp.get("sent_messages", []))
return {
"success": True,
@@ -253,3 +258,18 @@ async def add_grounding_results(answers: str, grounding_chunks: list[dict], grou
if url in answers:
answers += f"\n{number_to_emoji(idx + 1)}[{title}]({url})"
return answers
+
+
+def is_valid_response(resp: dict, schema: dict) -> bool:
+ """Check if the response is valid."""
+ if not schema:
+ return bool(resp.get("texts"))
+ if not resp.get("texts"):
+ return False
+ try:
+ data = json.loads(resp["texts"])
+ validate(instance=data, schema=schema)
+ except (JSONDecodeError, ValidationError) as e:
+ logger.error(f"Invalid JSONSchema response: {e}")
+ return False
+ return True
src/ai/texts/openai_response.py
@@ -2,9 +2,12 @@
# -*- coding: utf-8 -*-
import contextlib
import hashlib
+import json
+from json import JSONDecodeError
from typing import Literal
from glom import Coalesce, flatten, glom
+from jsonschema import ValidationError, validate
from loguru import logger
from openai import AsyncOpenAI, DefaultAsyncHttpxClient
from pyrogram.client import Client
@@ -113,7 +116,7 @@ async def openai_responses_api(
max_retries=max_retries,
**kwargs,
)
- if not resp.get("texts"):
+ if not is_valid_response(resp, glom(params, "text.format.schema", default={})):
continue
sent_messages.extend(resp.get("sent_messages", []))
sent_messages = [m for m in sent_messages if isinstance(m, Message)]
@@ -323,3 +326,18 @@ def add_tool_call_results_to_response(tool_calls: list[dict], answers: str) -> s
if link.startswith("http"):
answers += f"\n{number_to_emoji(idx + 1)} [{title}]({link})"
return answers.strip()
+
+
+def is_valid_response(resp: dict, schema: dict) -> bool:
+ """Check if the response is valid."""
+ if not schema:
+ return bool(resp.get("texts"))
+ if not resp.get("texts"):
+ return False
+ try:
+ data = json.loads(resp["texts"])
+ validate(instance=data, schema=schema)
+ except (JSONDecodeError, ValidationError) as e:
+ logger.error(f"Invalid JSONSchema response: {e}")
+ return False
+ return True
pyproject.toml
@@ -18,6 +18,7 @@ dependencies = [
"httpx-aiohttp==0.1.12",
"httpx-curl-cffi==0.1.5",
"httpx[http2,socks]==0.28.1",
+ "jsonschema==4.26.0",
"loguru==0.7.3",
"markdown==3.10.2",
"markitdown[docx,pdf,pptx,xls,xlsx]",
uv.lock
@@ -224,6 +224,7 @@ dependencies = [
{ name = "httpx", extra = ["http2", "socks"], marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'x86_64' and sys_platform == 'darwin') or (platform_machine == 'x86_64' and sys_platform == 'linux')" },
{ name = "httpx-aiohttp", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'x86_64' and sys_platform == 'darwin') or (platform_machine == 'x86_64' and sys_platform == 'linux')" },
{ name = "httpx-curl-cffi", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'x86_64' and sys_platform == 'darwin') or (platform_machine == 'x86_64' and sys_platform == 'linux')" },
+ { name = "jsonschema", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'x86_64' and sys_platform == 'darwin') or (platform_machine == 'x86_64' and sys_platform == 'linux')" },
{ name = "loguru", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'x86_64' and sys_platform == 'darwin') or (platform_machine == 'x86_64' and sys_platform == 'linux')" },
{ name = "markdown", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'x86_64' and sys_platform == 'darwin') or (platform_machine == 'x86_64' and sys_platform == 'linux')" },
{ name = "markitdown", extra = ["docx", "pdf", "pptx", "xls", "xlsx"], marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'x86_64' and sys_platform == 'darwin') or (platform_machine == 'x86_64' and sys_platform == 'linux')" },
@@ -275,6 +276,7 @@ requires-dist = [
{ name = "httpx", extras = ["http2", "socks"], specifier = "==0.28.1" },
{ name = "httpx-aiohttp", specifier = "==0.1.12" },
{ name = "httpx-curl-cffi", specifier = "==0.1.5" },
+ { name = "jsonschema", specifier = "==4.26.0" },
{ name = "loguru", specifier = "==0.7.3" },
{ name = "markdown", specifier = "==3.10.2" },
{ name = "markitdown", extras = ["docx", "pdf", "pptx", "xls", "xlsx"], git = "https://github.com/benny-dou/markitdown.git?subdirectory=packages%2Fmarkitdown" },
@@ -879,6 +881,33 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/14/2f/967ba146e6d58cf6a652da73885f52fc68001525b4197effc174321d70b4/jmespath-1.1.0-py3-none-any.whl", hash = "sha256:a5663118de4908c91729bea0acadca56526eb2698e83de10cd116ae0f4e97c64", size = 20419, upload-time = "2026-01-22T16:35:24.919Z" },
]
+[[package]]
+name = "jsonschema"
+version = "4.26.0"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "attrs", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'x86_64' and sys_platform == 'darwin') or (platform_machine == 'x86_64' and sys_platform == 'linux')" },
+ { name = "jsonschema-specifications", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'x86_64' and sys_platform == 'darwin') or (platform_machine == 'x86_64' and sys_platform == 'linux')" },
+ { name = "referencing", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'x86_64' and sys_platform == 'darwin') or (platform_machine == 'x86_64' and sys_platform == 'linux')" },
+ { name = "rpds-py", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'x86_64' and sys_platform == 'darwin') or (platform_machine == 'x86_64' and sys_platform == 'linux')" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/b3/fc/e067678238fa451312d4c62bf6e6cf5ec56375422aee02f9cb5f909b3047/jsonschema-4.26.0.tar.gz", hash = "sha256:0c26707e2efad8aa1bfc5b7ce170f3fccc2e4918ff85989ba9ffa9facb2be326", size = 366583, upload-time = "2026-01-07T13:41:07.246Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/69/90/f63fb5873511e014207a475e2bb4e8b2e570d655b00ac19a9a0ca0a385ee/jsonschema-4.26.0-py3-none-any.whl", hash = "sha256:d489f15263b8d200f8387e64b4c3a75f06629559fb73deb8fdfb525f2dab50ce", size = 90630, upload-time = "2026-01-07T13:41:05.306Z" },
+]
+
+[[package]]
+name = "jsonschema-specifications"
+version = "2025.9.1"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "referencing", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'x86_64' and sys_platform == 'darwin') or (platform_machine == 'x86_64' and sys_platform == 'linux')" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/19/74/a633ee74eb36c44aa6d1095e7cc5569bebf04342ee146178e2d36600708b/jsonschema_specifications-2025.9.1.tar.gz", hash = "sha256:b540987f239e745613c7a9176f3edb72b832a4ac465cf02712288397832b5e8d", size = 32855, upload-time = "2025-09-08T01:34:59.186Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/41/45/1a4ed80516f02155c51f51e8cedb3c1902296743db0bbc66608a0db2814f/jsonschema_specifications-2025.9.1-py3-none-any.whl", hash = "sha256:98802fee3a11ee76ecaca44429fda8a41bff98b00a0f2838151b113f210cc6fe", size = 18437, upload-time = "2025-09-08T01:34:57.871Z" },
+]
+
[[package]]
name = "loguru"
version = "0.7.3"
@@ -1547,6 +1576,19 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/17/46/4d2c1ce9457ddd90913084085be5add1cb040f9bb41717adb89554a9c9d9/quickchart_io-2.0.0-py3-none-any.whl", hash = "sha256:c44b5fb4d6e957fb85db0926e691684795e9fe5d6819d33f2daea795a0f6a36b", size = 5122, upload-time = "2022-09-24T22:40:32.94Z" },
]
+[[package]]
+name = "referencing"
+version = "0.37.0"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "attrs", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'x86_64' and sys_platform == 'darwin') or (platform_machine == 'x86_64' and sys_platform == 'linux')" },
+ { name = "rpds-py", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'x86_64' and sys_platform == 'darwin') or (platform_machine == 'x86_64' and sys_platform == 'linux')" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/22/f5/df4e9027acead3ecc63e50fe1e36aca1523e1719559c499951bb4b53188f/referencing-0.37.0.tar.gz", hash = "sha256:44aefc3142c5b842538163acb373e24cce6632bd54bdb01b21ad5863489f50d8", size = 78036, upload-time = "2025-10-13T15:30:48.871Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/2c/58/ca301544e1fa93ed4f80d724bf5b194f6e4b945841c5bfd555878eea9fcb/referencing-0.37.0-py3-none-any.whl", hash = "sha256:381329a9f99628c9069361716891d34ad94af76e461dcb0335825aecc7692231", size = 26766, upload-time = "2025-10-13T15:30:47.625Z" },
+]
+
[[package]]
name = "requests"
version = "2.34.0"
@@ -1575,6 +1617,22 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/82/3b/64d4899d73f91ba49a8c18a8ff3f0ea8f1c1d75481760df8c68ef5235bf5/rich-15.0.0-py3-none-any.whl", hash = "sha256:33bd4ef74232fb73fe9279a257718407f169c09b78a87ad3d296f548e27de0bb", size = 310654, upload-time = "2026-04-12T08:24:02.83Z" },
]
+[[package]]
+name = "rpds-py"
+version = "0.30.0"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/20/af/3f2f423103f1113b36230496629986e0ef7e199d2aa8392452b484b38ced/rpds_py-0.30.0.tar.gz", hash = "sha256:dd8ff7cf90014af0c0f787eea34794ebf6415242ee1d6fa91eaba725cc441e84", size = 69469, upload-time = "2025-11-30T20:24:38.837Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/ed/dc/d61221eb88ff410de3c49143407f6f3147acf2538c86f2ab7ce65ae7d5f9/rpds_py-0.30.0-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:f83424d738204d9770830d35290ff3273fbb02b41f919870479fab14b9d303b2", size = 374887, upload-time = "2025-11-30T20:22:41.812Z" },
+ { url = "https://files.pythonhosted.org/packages/fd/32/55fb50ae104061dbc564ef15cc43c013dc4a9f4527a1f4d99baddf56fe5f/rpds_py-0.30.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:e7536cd91353c5273434b4e003cbda89034d67e7710eab8761fd918ec6c69cf8", size = 358904, upload-time = "2025-11-30T20:22:43.479Z" },
+ { url = "https://files.pythonhosted.org/packages/b7/de/f7192e12b21b9e9a68a6d0f249b4af3fdcdff8418be0767a627564afa1f1/rpds_py-0.30.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9027da1ce107104c50c81383cae773ef5c24d296dd11c99e2629dbd7967a20c6", size = 394025, upload-time = "2025-11-30T20:22:50.196Z" },
+ { url = "https://files.pythonhosted.org/packages/5f/60/525a50f45b01d70005403ae0e25f43c0384369ad24ffe46e8d9068b50086/rpds_py-0.30.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:945dccface01af02675628334f7cf49c2af4c1c904748efc5cf7bbdf0b579f95", size = 563020, upload-time = "2025-11-30T20:22:58.2Z" },
+ { url = "https://files.pythonhosted.org/packages/ff/1b/b10de890a0def2a319a2626334a7f0ae388215eb60914dbac8a3bae54435/rpds_py-0.30.0-cp313-cp313t-macosx_10_12_x86_64.whl", hash = "sha256:eb0b93f2e5c2189ee831ee43f156ed34e2a89a78a66b98cadad955972548be5a", size = 364443, upload-time = "2025-11-30T20:23:04.878Z" },
+ { url = "https://files.pythonhosted.org/packages/0d/bf/27e39f5971dc4f305a4fb9c672ca06f290f7c4e261c568f3dea16a410d47/rpds_py-0.30.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:922e10f31f303c7c920da8981051ff6d8c1a56207dbdf330d9047f6d30b70e5e", size = 353375, upload-time = "2025-11-30T20:23:06.342Z" },
+ { url = "https://files.pythonhosted.org/packages/60/ca/780cf3b1a32b18c0f05c441958d3758f02544f1d613abf9488cd78876378/rpds_py-0.30.0-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:51a1234d8febafdfd33a42d97da7a43f5dcb120c1060e352a3fbc0c6d36e2083", size = 383843, upload-time = "2025-11-30T20:23:14.638Z" },
+ { url = "https://files.pythonhosted.org/packages/6d/61/21b8c41f68e60c8cc3b2e25644f0e3681926020f11d06ab0b78e3c6bbff1/rpds_py-0.30.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:4c5f36a861bc4b7da6516dbdf302c55313afa09b81931e8280361a4f6c9a2d27", size = 555806, upload-time = "2025-11-30T20:23:22.488Z" },
+]
+
[[package]]
name = "s3transfer"
version = "0.14.0"