feat: support AnyOf and Null

This commit is contained in:
Tsukimaru Oshawott 2025-04-24 17:50:25 +08:00
parent 393f2e724b
commit 603327c853
2 changed files with 101 additions and 6 deletions

View File

@ -1,6 +1,6 @@
import pytest import pytest
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from typing import Any, List, Dict from typing import Any, List, Dict, Union
from mcpo.utils.main import _process_schema_property from mcpo.utils.main import _process_schema_property
@ -64,6 +64,17 @@ def test_process_simple_number():
assert result_field.default == expected_field.default assert result_field.default == expected_field.default
def test_process_null():
schema = {"type": "null"}
expected_type = None
expected_field = Field(default=..., description="")
result_type, result_field = _process_schema_property(
_model_cache, schema, "test", "prop", True
)
assert result_type == expected_type
assert result_field.default == expected_field.default
def test_process_unknown_type(): def test_process_unknown_type():
schema = {"type": "unknown"} schema = {"type": "unknown"}
expected_type = Any expected_type = Any
@ -226,8 +237,11 @@ def test_model_caching():
assert len(_model_cache) == 2 # Only two unique models created assert len(_model_cache) == 2 # Only two unique models created
def test_multi_type_property(): def test_multi_type_property_with_list():
schema = {"type": ["string", "number"], "description": "A property with multiple types"} schema = {
"type": ["string", "number"],
"description": "A property with multiple types",
}
expected_field = Field(default=..., description="A property with multiple types") expected_field = Field(default=..., description="A property with multiple types")
result_type, result_field = _process_schema_property( result_type, result_field = _process_schema_property(
_model_cache, schema, "test", "multi_type", True _model_cache, schema, "test", "multi_type", True
@ -243,3 +257,56 @@ def test_multi_type_property():
# Check field properties # Check field properties
assert result_field.default == expected_field.default assert result_field.default == expected_field.default
assert result_field.description == expected_field.description assert result_field.description == expected_field.description
def test_multi_type_property_with_any_of():
schema = {
"anyOf": [
{
"type": "object",
"properties": {
"name": {
"type": "string",
"description": "The name of the function to call",
},
"arguments": {
"type": "string",
"description": "The arguments to pass to the function, as a JSON string",
"default": "{}",
},
},
"required": ["name"],
},
{
"type": "object",
"properties": {
"function_id": {
"type": "int",
"description": "The id of the function to call",
},
},
"required": ["function_id"],
},
{
"type": "string",
"enum": ["auto", "none"],
"description": "Control function calling behavior",
},
],
"description": "A property with multiple types",
}
result_type, result_field = _process_schema_property(
_model_cache, schema, "test", "multi_type", True
)
# Check if the resulting type is a Union
assert result_type.__origin__ == Union
# Check if the Union has the correct number of types
assert len(result_type.__args__) == 3
assert len(result_type.__args__[0].model_fields) == 2
assert len(result_type.__args__[1].model_fields) == 1
assert result_type.__args__[2] is str
# assert result_field parameter config
assert result_field.description == "A property with multiple types"

View File

@ -3,7 +3,14 @@ from pydantic import create_model, Field
from pydantic.fields import FieldInfo from pydantic.fields import FieldInfo
from mcp import ClientSession, types from mcp import ClientSession, types
from fastapi import HTTPException from fastapi import HTTPException
from mcp.types import CallToolResult, PARSE_ERROR, INVALID_REQUEST, METHOD_NOT_FOUND, INVALID_PARAMS, INTERNAL_ERROR from mcp.types import (
CallToolResult,
PARSE_ERROR,
INVALID_REQUEST,
METHOD_NOT_FOUND,
INVALID_PARAMS,
INTERNAL_ERROR,
)
from mcp.shared.exceptions import McpError from mcp.shared.exceptions import McpError
import json import json
@ -58,6 +65,21 @@ def _process_schema_property(
default_value = ... if is_required else prop_schema.get("default", None) default_value = ... if is_required else prop_schema.get("default", None)
pydantic_field = Field(default=default_value, description=prop_desc) pydantic_field = Field(default=default_value, description=prop_desc)
# Handle the case where prop_type is missing but 'anyOf' key exists
# In this case, use data type from 'anyOf' to determine the type hint
if "anyOf" in prop_schema:
type_hints = []
for i, schema_option in enumerate(prop_schema["anyOf"]):
type_hint, _ = _process_schema_property(
_model_cache,
schema_option,
f"{model_name_prefix}_{prop_name}",
f"choice_{i}",
False,
)
type_hints.append(type_hint)
return Union[tuple(type_hints)], pydantic_field
# Handle the case where prop_type is a list of types, e.g. ['string', 'number'] # Handle the case where prop_type is a list of types, e.g. ['string', 'number']
if isinstance(prop_type, list): if isinstance(prop_type, list):
# Create a Union of all the types # Create a Union of all the types
@ -127,6 +149,8 @@ def _process_schema_property(
return bool, pydantic_field return bool, pydantic_field
elif prop_type == "number": elif prop_type == "number":
return float, pydantic_field return float, pydantic_field
elif prop_type == "null":
return None, pydantic_field
else: else:
return Any, pydantic_field return Any, pydantic_field
@ -174,7 +198,9 @@ def get_tool_handler(session, endpoint_name, form_model_name, model_fields):
) )
response_data = process_tool_response(result) response_data = process_tool_response(result)
final_response = response_data[0] if len(response_data) == 1 else response_data final_response = (
response_data[0] if len(response_data) == 1 else response_data
)
return final_response return final_response
except McpError as e: except McpError as e:
@ -222,7 +248,9 @@ def get_tool_handler(session, endpoint_name, form_model_name, model_fields):
) )
response_data = process_tool_response(result) response_data = process_tool_response(result)
final_response = response_data[0] if len(response_data) == 1 else response_data final_response = (
response_data[0] if len(response_data) == 1 else response_data
)
return final_response return final_response
except McpError as e: except McpError as e: