From 603327c8539f173e0b2dad5c4ac459f46c2d48e3 Mon Sep 17 00:00:00 2001 From: Tsukimaru Oshawott Date: Thu, 24 Apr 2025 17:50:25 +0800 Subject: [PATCH] feat: support `AnyOf` and `Null` --- src/mcpo/tests/test_main.py | 73 +++++++++++++++++++++++++++++++++++-- src/mcpo/utils/main.py | 34 +++++++++++++++-- 2 files changed, 101 insertions(+), 6 deletions(-) diff --git a/src/mcpo/tests/test_main.py b/src/mcpo/tests/test_main.py index e0b8cb3..a7cfe5f 100644 --- a/src/mcpo/tests/test_main.py +++ b/src/mcpo/tests/test_main.py @@ -1,6 +1,6 @@ import pytest from pydantic import BaseModel, Field -from typing import Any, List, Dict +from typing import Any, List, Dict, Union from mcpo.utils.main import _process_schema_property @@ -64,6 +64,17 @@ def test_process_simple_number(): assert result_field.default == expected_field.default +def test_process_null(): + schema = {"type": "null"} + expected_type = None + expected_field = Field(default=..., description="") + result_type, result_field = _process_schema_property( + _model_cache, schema, "test", "prop", True + ) + assert result_type == expected_type + assert result_field.default == expected_field.default + + def test_process_unknown_type(): schema = {"type": "unknown"} expected_type = Any @@ -226,8 +237,11 @@ def test_model_caching(): assert len(_model_cache) == 2 # Only two unique models created -def test_multi_type_property(): - schema = {"type": ["string", "number"], "description": "A property with multiple types"} +def test_multi_type_property_with_list(): + schema = { + "type": ["string", "number"], + "description": "A property with multiple types", + } expected_field = Field(default=..., description="A property with multiple types") result_type, result_field = _process_schema_property( _model_cache, schema, "test", "multi_type", True @@ -243,3 +257,56 @@ def test_multi_type_property(): # Check field properties assert result_field.default == expected_field.default assert result_field.description == expected_field.description + + +def test_multi_type_property_with_any_of(): + schema = { + "anyOf": [ + { + "type": "object", + "properties": { + "name": { + "type": "string", + "description": "The name of the function to call", + }, + "arguments": { + "type": "string", + "description": "The arguments to pass to the function, as a JSON string", + "default": "{}", + }, + }, + "required": ["name"], + }, + { + "type": "object", + "properties": { + "function_id": { + "type": "int", + "description": "The id of the function to call", + }, + }, + "required": ["function_id"], + }, + { + "type": "string", + "enum": ["auto", "none"], + "description": "Control function calling behavior", + }, + ], + "description": "A property with multiple types", + } + result_type, result_field = _process_schema_property( + _model_cache, schema, "test", "multi_type", True + ) + + # Check if the resulting type is a Union + assert result_type.__origin__ == Union + + # Check if the Union has the correct number of types + assert len(result_type.__args__) == 3 + assert len(result_type.__args__[0].model_fields) == 2 + assert len(result_type.__args__[1].model_fields) == 1 + assert result_type.__args__[2] is str + + # assert result_field parameter config + assert result_field.description == "A property with multiple types" diff --git a/src/mcpo/utils/main.py b/src/mcpo/utils/main.py index dfe5a92..e419c64 100644 --- a/src/mcpo/utils/main.py +++ b/src/mcpo/utils/main.py @@ -3,7 +3,14 @@ from pydantic import create_model, Field from pydantic.fields import FieldInfo from mcp import ClientSession, types from fastapi import HTTPException -from mcp.types import CallToolResult, PARSE_ERROR, INVALID_REQUEST, METHOD_NOT_FOUND, INVALID_PARAMS, INTERNAL_ERROR +from mcp.types import ( + CallToolResult, + PARSE_ERROR, + INVALID_REQUEST, + METHOD_NOT_FOUND, + INVALID_PARAMS, + INTERNAL_ERROR, +) from mcp.shared.exceptions import McpError import json @@ -58,6 +65,21 @@ def _process_schema_property( default_value = ... if is_required else prop_schema.get("default", None) pydantic_field = Field(default=default_value, description=prop_desc) + # Handle the case where prop_type is missing but 'anyOf' key exists + # In this case, use data type from 'anyOf' to determine the type hint + if "anyOf" in prop_schema: + type_hints = [] + for i, schema_option in enumerate(prop_schema["anyOf"]): + type_hint, _ = _process_schema_property( + _model_cache, + schema_option, + f"{model_name_prefix}_{prop_name}", + f"choice_{i}", + False, + ) + type_hints.append(type_hint) + return Union[tuple(type_hints)], pydantic_field + # Handle the case where prop_type is a list of types, e.g. ['string', 'number'] if isinstance(prop_type, list): # Create a Union of all the types @@ -127,6 +149,8 @@ def _process_schema_property( return bool, pydantic_field elif prop_type == "number": return float, pydantic_field + elif prop_type == "null": + return None, pydantic_field else: return Any, pydantic_field @@ -174,7 +198,9 @@ def get_tool_handler(session, endpoint_name, form_model_name, model_fields): ) response_data = process_tool_response(result) - final_response = response_data[0] if len(response_data) == 1 else response_data + final_response = ( + response_data[0] if len(response_data) == 1 else response_data + ) return final_response except McpError as e: @@ -222,7 +248,9 @@ def get_tool_handler(session, endpoint_name, form_model_name, model_fields): ) response_data = process_tool_response(result) - final_response = response_data[0] if len(response_data) == 1 else response_data + final_response = ( + response_data[0] if len(response_data) == 1 else response_data + ) return final_response except McpError as e: