diff --git a/src/mcpo/tests/test_main.py b/src/mcpo/tests/test_main.py index db10617..e0b8cb3 100644 --- a/src/mcpo/tests/test_main.py +++ b/src/mcpo/tests/test_main.py @@ -224,3 +224,22 @@ def test_model_caching(): ) assert result_type3 == result_type1 # Should be the same cached object assert len(_model_cache) == 2 # Only two unique models created + + +def test_multi_type_property(): + schema = {"type": ["string", "number"], "description": "A property with multiple types"} + expected_field = Field(default=..., description="A property with multiple types") + result_type, result_field = _process_schema_property( + _model_cache, schema, "test", "multi_type", True + ) + + # Check if the resulting type is a Union + assert str(result_type).startswith("typing.Union[") + + # Check if both types are in the Union + assert str in result_type.__args__ + assert float in result_type.__args__ + + # Check field properties + assert result_field.default == expected_field.default + assert result_field.description == expected_field.description diff --git a/src/mcpo/utils/main.py b/src/mcpo/utils/main.py index 29a71dd..dfe5a92 100644 --- a/src/mcpo/utils/main.py +++ b/src/mcpo/utils/main.py @@ -58,6 +58,22 @@ def _process_schema_property( default_value = ... if is_required else prop_schema.get("default", None) pydantic_field = Field(default=default_value, description=prop_desc) + # Handle the case where prop_type is a list of types, e.g. ['string', 'number'] + if isinstance(prop_type, list): + # Create a Union of all the types + type_hints = [] + for type_option in prop_type: + # Create a temporary schema with the single type and process it + temp_schema = dict(prop_schema) + temp_schema["type"] = type_option + type_hint, _ = _process_schema_property( + _model_cache, temp_schema, model_name_prefix, prop_name, False + ) + type_hints.append(type_hint) + + # Return a Union of all possible types + return Union[tuple(type_hints)], pydantic_field + if prop_type == "object": nested_properties = prop_schema.get("properties", {}) nested_required = prop_schema.get("required", [])