Commit 45fb9de

benny-dou <60535774+benny-dou@users.noreply.github.com>
2026-05-17 06:18:47
feat(ai): add JSON Schema validation for LLM responses
1 parent e034f4f
src/ai/texts/gemini.py
@@ -2,10 +2,13 @@
 # -*- coding: utf-8 -*-
 import asyncio
 import contextlib
+import json
+from json import JSONDecodeError
 
 from glom import glom
 from google import genai
 from google.genai import types
+from jsonschema import ValidationError, validate
 from loguru import logger
 from pyrogram.client import Client
 from pyrogram.parser.markdown import BLOCKQUOTE_DELIM
@@ -76,6 +79,8 @@ async def gemini_chat_completion(
                 **kwargs,
             )
             if resp.get("texts"):
+                if not is_valid_response(resp, glom(params, "config.responseJsonSchema", default={})):
+                    continue
                 sent_messages.extend(resp.get("sent_messages", []))
                 return {
                     "success": True,
@@ -253,3 +258,18 @@ async def add_grounding_results(answers: str, grounding_chunks: list[dict], grou
         if url in answers:
             answers += f"\n{number_to_emoji(idx + 1)}[{title}]({url})"
     return answers
+
+
+def is_valid_response(resp: dict, schema: dict) -> bool:
+    """Return whether ``resp["texts"]`` is non-empty and, if ``schema`` is set, valid JSON for it."""
+    texts = resp.get("texts")
+    if not (texts and schema):
+        # Empty text never passes; with no schema, non-empty text is all that is required.
+        return bool(texts)
+    try:
+        # Decode inline so malformed JSON and schema violations share the handler below.
+        validate(instance=json.loads(texts), schema=schema)
+    except (JSONDecodeError, ValidationError) as err:
+        logger.error(f"Invalid JSONSchema response: {err}")
+        return False
+    return True
src/ai/texts/openai_response.py
@@ -2,9 +2,12 @@
 # -*- coding: utf-8 -*-
 import contextlib
 import hashlib
+import json
+from json import JSONDecodeError
 from typing import Literal
 
 from glom import Coalesce, flatten, glom
+from jsonschema import ValidationError, validate
 from loguru import logger
 from openai import AsyncOpenAI, DefaultAsyncHttpxClient
 from pyrogram.client import Client
@@ -113,7 +116,7 @@ async def openai_responses_api(
                 max_retries=max_retries,
                 **kwargs,
             )
-            if not resp.get("texts"):
+            if not is_valid_response(resp, glom(params, "text.format.schema", default={})):
                 continue
             sent_messages.extend(resp.get("sent_messages", []))
             sent_messages = [m for m in sent_messages if isinstance(m, Message)]
@@ -323,3 +326,18 @@ def add_tool_call_results_to_response(tool_calls: list[dict], answers: str) -> s
         if link.startswith("http"):
             answers += f"\n{number_to_emoji(idx + 1)} [{title}]({link})"
     return answers.strip()
+
+
+def is_valid_response(resp: dict, schema: dict) -> bool:
+    """Return whether ``resp["texts"]`` is non-empty and, if ``schema`` is set, valid JSON for it."""
+    texts = resp.get("texts")
+    if not (texts and schema):
+        # Empty text never passes; with no schema, non-empty text is all that is required.
+        return bool(texts)
+    try:
+        # Decode inline so malformed JSON and schema violations share the handler below.
+        validate(instance=json.loads(texts), schema=schema)
+    except (JSONDecodeError, ValidationError) as err:
+        logger.error(f"Invalid JSONSchema response: {err}")
+        return False
+    return True
pyproject.toml
@@ -18,6 +18,7 @@ dependencies = [
   "httpx-aiohttp==0.1.12",
   "httpx-curl-cffi==0.1.5",
   "httpx[http2,socks]==0.28.1",
+  "jsonschema==4.26.0",
   "loguru==0.7.3",
   "markdown==3.10.2",
   "markitdown[docx,pdf,pptx,xls,xlsx]",
uv.lock
@@ -224,6 +224,7 @@ dependencies = [
     { name = "httpx", extra = ["http2", "socks"], marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'x86_64' and sys_platform == 'darwin') or (platform_machine == 'x86_64' and sys_platform == 'linux')" },
     { name = "httpx-aiohttp", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'x86_64' and sys_platform == 'darwin') or (platform_machine == 'x86_64' and sys_platform == 'linux')" },
     { name = "httpx-curl-cffi", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'x86_64' and sys_platform == 'darwin') or (platform_machine == 'x86_64' and sys_platform == 'linux')" },
+    { name = "jsonschema", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'x86_64' and sys_platform == 'darwin') or (platform_machine == 'x86_64' and sys_platform == 'linux')" },
     { name = "loguru", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'x86_64' and sys_platform == 'darwin') or (platform_machine == 'x86_64' and sys_platform == 'linux')" },
     { name = "markdown", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'x86_64' and sys_platform == 'darwin') or (platform_machine == 'x86_64' and sys_platform == 'linux')" },
     { name = "markitdown", extra = ["docx", "pdf", "pptx", "xls", "xlsx"], marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'x86_64' and sys_platform == 'darwin') or (platform_machine == 'x86_64' and sys_platform == 'linux')" },
@@ -275,6 +276,7 @@ requires-dist = [
     { name = "httpx", extras = ["http2", "socks"], specifier = "==0.28.1" },
     { name = "httpx-aiohttp", specifier = "==0.1.12" },
     { name = "httpx-curl-cffi", specifier = "==0.1.5" },
+    { name = "jsonschema", specifier = "==4.26.0" },
     { name = "loguru", specifier = "==0.7.3" },
     { name = "markdown", specifier = "==3.10.2" },
     { name = "markitdown", extras = ["docx", "pdf", "pptx", "xls", "xlsx"], git = "https://github.com/benny-dou/markitdown.git?subdirectory=packages%2Fmarkitdown" },
@@ -879,6 +881,33 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/14/2f/967ba146e6d58cf6a652da73885f52fc68001525b4197effc174321d70b4/jmespath-1.1.0-py3-none-any.whl", hash = "sha256:a5663118de4908c91729bea0acadca56526eb2698e83de10cd116ae0f4e97c64", size = 20419, upload-time = "2026-01-22T16:35:24.919Z" },
 ]
 
+[[package]]
+name = "jsonschema"
+version = "4.26.0"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "attrs", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'x86_64' and sys_platform == 'darwin') or (platform_machine == 'x86_64' and sys_platform == 'linux')" },
+    { name = "jsonschema-specifications", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'x86_64' and sys_platform == 'darwin') or (platform_machine == 'x86_64' and sys_platform == 'linux')" },
+    { name = "referencing", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'x86_64' and sys_platform == 'darwin') or (platform_machine == 'x86_64' and sys_platform == 'linux')" },
+    { name = "rpds-py", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'x86_64' and sys_platform == 'darwin') or (platform_machine == 'x86_64' and sys_platform == 'linux')" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/b3/fc/e067678238fa451312d4c62bf6e6cf5ec56375422aee02f9cb5f909b3047/jsonschema-4.26.0.tar.gz", hash = "sha256:0c26707e2efad8aa1bfc5b7ce170f3fccc2e4918ff85989ba9ffa9facb2be326", size = 366583, upload-time = "2026-01-07T13:41:07.246Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/69/90/f63fb5873511e014207a475e2bb4e8b2e570d655b00ac19a9a0ca0a385ee/jsonschema-4.26.0-py3-none-any.whl", hash = "sha256:d489f15263b8d200f8387e64b4c3a75f06629559fb73deb8fdfb525f2dab50ce", size = 90630, upload-time = "2026-01-07T13:41:05.306Z" },
+]
+
+[[package]]
+name = "jsonschema-specifications"
+version = "2025.9.1"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "referencing", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'x86_64' and sys_platform == 'darwin') or (platform_machine == 'x86_64' and sys_platform == 'linux')" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/19/74/a633ee74eb36c44aa6d1095e7cc5569bebf04342ee146178e2d36600708b/jsonschema_specifications-2025.9.1.tar.gz", hash = "sha256:b540987f239e745613c7a9176f3edb72b832a4ac465cf02712288397832b5e8d", size = 32855, upload-time = "2025-09-08T01:34:59.186Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/41/45/1a4ed80516f02155c51f51e8cedb3c1902296743db0bbc66608a0db2814f/jsonschema_specifications-2025.9.1-py3-none-any.whl", hash = "sha256:98802fee3a11ee76ecaca44429fda8a41bff98b00a0f2838151b113f210cc6fe", size = 18437, upload-time = "2025-09-08T01:34:57.871Z" },
+]
+
 [[package]]
 name = "loguru"
 version = "0.7.3"
@@ -1547,6 +1576,19 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/17/46/4d2c1ce9457ddd90913084085be5add1cb040f9bb41717adb89554a9c9d9/quickchart_io-2.0.0-py3-none-any.whl", hash = "sha256:c44b5fb4d6e957fb85db0926e691684795e9fe5d6819d33f2daea795a0f6a36b", size = 5122, upload-time = "2022-09-24T22:40:32.94Z" },
 ]
 
+[[package]]
+name = "referencing"
+version = "0.37.0"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "attrs", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'x86_64' and sys_platform == 'darwin') or (platform_machine == 'x86_64' and sys_platform == 'linux')" },
+    { name = "rpds-py", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'x86_64' and sys_platform == 'darwin') or (platform_machine == 'x86_64' and sys_platform == 'linux')" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/22/f5/df4e9027acead3ecc63e50fe1e36aca1523e1719559c499951bb4b53188f/referencing-0.37.0.tar.gz", hash = "sha256:44aefc3142c5b842538163acb373e24cce6632bd54bdb01b21ad5863489f50d8", size = 78036, upload-time = "2025-10-13T15:30:48.871Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/2c/58/ca301544e1fa93ed4f80d724bf5b194f6e4b945841c5bfd555878eea9fcb/referencing-0.37.0-py3-none-any.whl", hash = "sha256:381329a9f99628c9069361716891d34ad94af76e461dcb0335825aecc7692231", size = 26766, upload-time = "2025-10-13T15:30:47.625Z" },
+]
+
 [[package]]
 name = "requests"
 version = "2.34.0"
@@ -1575,6 +1617,22 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/82/3b/64d4899d73f91ba49a8c18a8ff3f0ea8f1c1d75481760df8c68ef5235bf5/rich-15.0.0-py3-none-any.whl", hash = "sha256:33bd4ef74232fb73fe9279a257718407f169c09b78a87ad3d296f548e27de0bb", size = 310654, upload-time = "2026-04-12T08:24:02.83Z" },
 ]
 
+[[package]]
+name = "rpds-py"
+version = "0.30.0"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/20/af/3f2f423103f1113b36230496629986e0ef7e199d2aa8392452b484b38ced/rpds_py-0.30.0.tar.gz", hash = "sha256:dd8ff7cf90014af0c0f787eea34794ebf6415242ee1d6fa91eaba725cc441e84", size = 69469, upload-time = "2025-11-30T20:24:38.837Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/ed/dc/d61221eb88ff410de3c49143407f6f3147acf2538c86f2ab7ce65ae7d5f9/rpds_py-0.30.0-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:f83424d738204d9770830d35290ff3273fbb02b41f919870479fab14b9d303b2", size = 374887, upload-time = "2025-11-30T20:22:41.812Z" },
+    { url = "https://files.pythonhosted.org/packages/fd/32/55fb50ae104061dbc564ef15cc43c013dc4a9f4527a1f4d99baddf56fe5f/rpds_py-0.30.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:e7536cd91353c5273434b4e003cbda89034d67e7710eab8761fd918ec6c69cf8", size = 358904, upload-time = "2025-11-30T20:22:43.479Z" },
+    { url = "https://files.pythonhosted.org/packages/b7/de/f7192e12b21b9e9a68a6d0f249b4af3fdcdff8418be0767a627564afa1f1/rpds_py-0.30.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9027da1ce107104c50c81383cae773ef5c24d296dd11c99e2629dbd7967a20c6", size = 394025, upload-time = "2025-11-30T20:22:50.196Z" },
+    { url = "https://files.pythonhosted.org/packages/5f/60/525a50f45b01d70005403ae0e25f43c0384369ad24ffe46e8d9068b50086/rpds_py-0.30.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:945dccface01af02675628334f7cf49c2af4c1c904748efc5cf7bbdf0b579f95", size = 563020, upload-time = "2025-11-30T20:22:58.2Z" },
+    { url = "https://files.pythonhosted.org/packages/ff/1b/b10de890a0def2a319a2626334a7f0ae388215eb60914dbac8a3bae54435/rpds_py-0.30.0-cp313-cp313t-macosx_10_12_x86_64.whl", hash = "sha256:eb0b93f2e5c2189ee831ee43f156ed34e2a89a78a66b98cadad955972548be5a", size = 364443, upload-time = "2025-11-30T20:23:04.878Z" },
+    { url = "https://files.pythonhosted.org/packages/0d/bf/27e39f5971dc4f305a4fb9c672ca06f290f7c4e261c568f3dea16a410d47/rpds_py-0.30.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:922e10f31f303c7c920da8981051ff6d8c1a56207dbdf330d9047f6d30b70e5e", size = 353375, upload-time = "2025-11-30T20:23:06.342Z" },
+    { url = "https://files.pythonhosted.org/packages/60/ca/780cf3b1a32b18c0f05c441958d3758f02544f1d613abf9488cd78876378/rpds_py-0.30.0-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:51a1234d8febafdfd33a42d97da7a43f5dcb120c1060e352a3fbc0c6d36e2083", size = 383843, upload-time = "2025-11-30T20:23:14.638Z" },
+    { url = "https://files.pythonhosted.org/packages/6d/61/21b8c41f68e60c8cc3b2e25644f0e3681926020f11d06ab0b78e3c6bbff1/rpds_py-0.30.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:4c5f36a861bc4b7da6516dbdf302c55313afa09b81931e8280361a4f6c9a2d27", size = 555806, upload-time = "2025-11-30T20:23:22.488Z" },
+]
+
 [[package]]
 name = "s3transfer"
 version = "0.14.0"