diff --git a/src/mcpo/tests/test_main.py b/src/mcpo/tests/test_main.py index e0b8cb3..a7cfe5f 100644 --- a/src/mcpo/tests/test_main.py +++ b/src/mcpo/tests/test_main.py @@ -1,6 +1,6 @@ import pytest from pydantic import BaseModel, Field -from typing import Any, List, Dict +from typing import Any, List, Dict, Union from mcpo.utils.main import _process_schema_property @@ -64,6 +64,17 @@ def test_process_simple_number(): assert result_field.default == expected_field.default +def test_process_null(): + schema = {"type": "null"} + expected_type = None + expected_field = Field(default=..., description="") + result_type, result_field = _process_schema_property( + _model_cache, schema, "test", "prop", True + ) + assert result_type == expected_type + assert result_field.default == expected_field.default + + def test_process_unknown_type(): schema = {"type": "unknown"} expected_type = Any @@ -226,8 +237,11 @@ def test_model_caching(): assert len(_model_cache) == 2 # Only two unique models created -def test_multi_type_property(): - schema = {"type": ["string", "number"], "description": "A property with multiple types"} +def test_multi_type_property_with_list(): + schema = { + "type": ["string", "number"], + "description": "A property with multiple types", + } expected_field = Field(default=..., description="A property with multiple types") result_type, result_field = _process_schema_property( _model_cache, schema, "test", "multi_type", True @@ -243,3 +257,56 @@ def test_multi_type_property(): # Check field properties assert result_field.default == expected_field.default assert result_field.description == expected_field.description + + +def test_multi_type_property_with_any_of(): + schema = { + "anyOf": [ + { + "type": "object", + "properties": { + "name": { + "type": "string", + "description": "The name of the function to call", + }, + "arguments": { + "type": "string", + "description": "The arguments to pass to the function, as a JSON string", + "default": "{}", + }, + }, + "required": ["name"], + }, + { + "type": "object", + "properties": { + "function_id": { + "type": "int", + "description": "The id of the function to call", + }, + }, + "required": ["function_id"], + }, + { + "type": "string", + "enum": ["auto", "none"], + "description": "Control function calling behavior", + }, + ], + "description": "A property with multiple types", + } + result_type, result_field = _process_schema_property( + _model_cache, schema, "test", "multi_type", True + ) + + # Check if the resulting type is a Union + assert result_type.__origin__ == Union + + # Check if the Union has the correct number of types + assert len(result_type.__args__) == 3 + assert len(result_type.__args__[0].model_fields) == 2 + assert len(result_type.__args__[1].model_fields) == 1 + assert result_type.__args__[2] is str + + # assert result_field parameter config + assert result_field.description == "A property with multiple types" diff --git a/src/mcpo/utils/main.py b/src/mcpo/utils/main.py index f828150..e15fe7d 100644 --- a/src/mcpo/utils/main.py +++ b/src/mcpo/utils/main.py @@ -2,16 +2,19 @@ import json from typing import Any, Dict, ForwardRef, List, Optional, Type, Union from fastapi import HTTPException + from mcp import ClientSession, types -from mcp.shared.exceptions import McpError from mcp.types import ( - INTERNAL_ERROR, - INVALID_PARAMS, + CallToolResult, + PARSE_ERROR, INVALID_REQUEST, METHOD_NOT_FOUND, - PARSE_ERROR, - CallToolResult, + INVALID_PARAMS, + INTERNAL_ERROR, ) + +from mcp.shared.exceptions import McpError + from pydantic import Field, create_model from pydantic.fields import FieldInfo @@ -72,6 +75,21 @@ def _process_schema_property( default_value = ... if is_required else prop_schema.get("default", None) pydantic_field = Field(default=default_value, description=prop_desc) + # Handle the case where prop_type is missing but 'anyOf' key exists + # In this case, use data type from 'anyOf' to determine the type hint + if "anyOf" in prop_schema: + type_hints = [] + for i, schema_option in enumerate(prop_schema["anyOf"]): + type_hint, _ = _process_schema_property( + _model_cache, + schema_option, + f"{model_name_prefix}_{prop_name}", + f"choice_{i}", + False, + ) + type_hints.append(type_hint) + return Union[tuple(type_hints)], pydantic_field + # Handle the case where prop_type is a list of types, e.g. ['string', 'number'] if isinstance(prop_type, list): # Create a Union of all the types @@ -147,6 +165,8 @@ def _process_schema_property( return bool, pydantic_field elif prop_type == "number": return float, pydantic_field + elif prop_type == "null": + return None, pydantic_field else: return Any, pydantic_field