From 9d7037b730dc3fff09cecdc3d17d8a444cb0cebd Mon Sep 17 00:00:00 2001 From: Michael Poluektov Date: Mon, 19 Aug 2024 16:27:38 +0100 Subject: [PATCH] add pydantic model from json --- backend/utils/schemas.py | 104 +++++++++++++++++++++++++++++++++++++++ backend/utils/tools.py | 3 ++ 2 files changed, 107 insertions(+) create mode 100644 backend/utils/schemas.py diff --git a/backend/utils/schemas.py b/backend/utils/schemas.py new file mode 100644 index 000000000..09b24897b --- /dev/null +++ b/backend/utils/schemas.py @@ -0,0 +1,104 @@ +from pydantic import BaseModel, Field, create_model +from typing import Any, Optional, Type + + +def json_schema_to_model(tool_dict: dict[str, Any]) -> Type[BaseModel]: + """ + Converts a JSON schema to a Pydantic BaseModel class. + + Args: + json_schema: The JSON schema to convert. + + Returns: + A Pydantic BaseModel class. + """ + + # Extract the model name from the schema title. + model_name = tool_dict["name"] + schema = tool_dict["parameters"] + + # Extract the field definitions from the schema properties. + field_definitions = { + name: json_schema_to_pydantic_field(name, prop, schema.get("required", [])) + for name, prop in schema.get("properties", {}).items() + } + + # Create the BaseModel class using create_model(). + return create_model(model_name, **field_definitions) + + +def json_schema_to_pydantic_field( + name: str, json_schema: dict[str, Any], required: list[str] +) -> Any: + """ + Converts a JSON schema property to a Pydantic field definition. + + Args: + name: The field name. + json_schema: The JSON schema property. + + Returns: + A Pydantic field definition. + """ + + # Get the field type. + type_ = json_schema_to_pydantic_type(json_schema) + + # Get the field description. + description = json_schema.get("description") + + # Get the field examples. + examples = json_schema.get("examples") + + # Create a Field object with the type, description, and examples. + # The 'required' flag will be set later when creating the model. + return ( + type_, + Field( + description=description, + examples=examples, + default=... if name in required else None, + ), + ) + + +def json_schema_to_pydantic_type(json_schema: dict[str, Any]) -> Any: + """ + Converts a JSON schema type to a Pydantic type. + + Args: + json_schema: The JSON schema to convert. + + Returns: + A Pydantic type. + """ + + type_ = json_schema.get("type") + + if type_ == "string" or type_ == "str": + return str + elif type_ == "integer" or type_ == "int": + return int + elif type_ == "number" or type_ == "float": + return float + elif type_ == "boolean" or type_ == "bool": + return bool + elif type_ == "array": + items_schema = json_schema.get("items") + if items_schema: + item_type = json_schema_to_pydantic_type(items_schema) + return list[item_type] + else: + return list + elif type_ == "object": + # Handle nested models. + properties = json_schema.get("properties") + if properties: + nested_model = json_schema_to_model(json_schema) + return nested_model + else: + return dict + elif type_ == "null": + return Optional[Any] # Use Optional[Any] for nullable fields + else: + raise ValueError(f"Unsupported JSON schema type: {type_}") diff --git a/backend/utils/tools.py b/backend/utils/tools.py index 14519f1be..1a2fea32b 100644 --- a/backend/utils/tools.py +++ b/backend/utils/tools.py @@ -6,6 +6,8 @@ from apps.webui.models.tools import Tools from apps.webui.models.users import UserModel from apps.webui.utils import load_toolkit_module_by_id +from utils.schemas import json_schema_to_model + log = logging.getLogger(__name__) @@ -70,6 +72,7 @@ def get_tools( "toolkit_id": tool_id, "callable": callable, "spec": spec, + "pydantic_model": json_schema_to_model(spec), "file_handler": hasattr(module, "file_handler") and module.file_handler, "citation": hasattr(module, "citation") and module.citation, }