feat: support AnyOf and Null

This commit is contained in:
Tsukimaru Oshawott 2025-04-24 17:50:25 +08:00
parent 393f2e724b
commit 603327c853
2 changed files with 101 additions and 6 deletions

View File

@ -1,6 +1,6 @@
import pytest
from pydantic import BaseModel, Field
from typing import Any, List, Dict
from typing import Any, List, Dict, Union
from mcpo.utils.main import _process_schema_property
@ -64,6 +64,17 @@ def test_process_simple_number():
assert result_field.default == expected_field.default
def test_process_null():
schema = {"type": "null"}
expected_type = None
expected_field = Field(default=..., description="")
result_type, result_field = _process_schema_property(
_model_cache, schema, "test", "prop", True
)
assert result_type == expected_type
assert result_field.default == expected_field.default
def test_process_unknown_type():
schema = {"type": "unknown"}
expected_type = Any
@ -226,8 +237,11 @@ def test_model_caching():
assert len(_model_cache) == 2 # Only two unique models created
def test_multi_type_property():
schema = {"type": ["string", "number"], "description": "A property with multiple types"}
def test_multi_type_property_with_list():
schema = {
"type": ["string", "number"],
"description": "A property with multiple types",
}
expected_field = Field(default=..., description="A property with multiple types")
result_type, result_field = _process_schema_property(
_model_cache, schema, "test", "multi_type", True
@ -243,3 +257,56 @@ def test_multi_type_property():
# Check field properties
assert result_field.default == expected_field.default
assert result_field.description == expected_field.description
def test_multi_type_property_with_any_of():
schema = {
"anyOf": [
{
"type": "object",
"properties": {
"name": {
"type": "string",
"description": "The name of the function to call",
},
"arguments": {
"type": "string",
"description": "The arguments to pass to the function, as a JSON string",
"default": "{}",
},
},
"required": ["name"],
},
{
"type": "object",
"properties": {
"function_id": {
"type": "int",
"description": "The id of the function to call",
},
},
"required": ["function_id"],
},
{
"type": "string",
"enum": ["auto", "none"],
"description": "Control function calling behavior",
},
],
"description": "A property with multiple types",
}
result_type, result_field = _process_schema_property(
_model_cache, schema, "test", "multi_type", True
)
# Check if the resulting type is a Union
assert result_type.__origin__ == Union
# Check if the Union has the correct number of types
assert len(result_type.__args__) == 3
assert len(result_type.__args__[0].model_fields) == 2
assert len(result_type.__args__[1].model_fields) == 1
assert result_type.__args__[2] is str
# assert result_field parameter config
assert result_field.description == "A property with multiple types"

View File

@ -3,7 +3,14 @@ from pydantic import create_model, Field
from pydantic.fields import FieldInfo
from mcp import ClientSession, types
from fastapi import HTTPException
from mcp.types import CallToolResult, PARSE_ERROR, INVALID_REQUEST, METHOD_NOT_FOUND, INVALID_PARAMS, INTERNAL_ERROR
from mcp.types import (
CallToolResult,
PARSE_ERROR,
INVALID_REQUEST,
METHOD_NOT_FOUND,
INVALID_PARAMS,
INTERNAL_ERROR,
)
from mcp.shared.exceptions import McpError
import json
@ -58,6 +65,21 @@ def _process_schema_property(
default_value = ... if is_required else prop_schema.get("default", None)
pydantic_field = Field(default=default_value, description=prop_desc)
# Handle the case where prop_type is missing but 'anyOf' key exists
# In this case, use data type from 'anyOf' to determine the type hint
if "anyOf" in prop_schema:
type_hints = []
for i, schema_option in enumerate(prop_schema["anyOf"]):
type_hint, _ = _process_schema_property(
_model_cache,
schema_option,
f"{model_name_prefix}_{prop_name}",
f"choice_{i}",
False,
)
type_hints.append(type_hint)
return Union[tuple(type_hints)], pydantic_field
# Handle the case where prop_type is a list of types, e.g. ['string', 'number']
if isinstance(prop_type, list):
# Create a Union of all the types
@ -127,6 +149,8 @@ def _process_schema_property(
return bool, pydantic_field
elif prop_type == "number":
return float, pydantic_field
elif prop_type == "null":
return None, pydantic_field
else:
return Any, pydantic_field
@ -174,7 +198,9 @@ def get_tool_handler(session, endpoint_name, form_model_name, model_fields):
)
response_data = process_tool_response(result)
final_response = response_data[0] if len(response_data) == 1 else response_data
final_response = (
response_data[0] if len(response_data) == 1 else response_data
)
return final_response
except McpError as e:
@ -222,7 +248,9 @@ def get_tool_handler(session, endpoint_name, form_model_name, model_fields):
)
response_data = process_tool_response(result)
final_response = response_data[0] if len(response_data) == 1 else response_data
final_response = (
response_data[0] if len(response_data) == 1 else response_data
)
return final_response
except McpError as e: