This commit is contained in:
bmen25124 2025-06-07 01:03:09 +08:00 committed by GitHub
commit 4de89ed0cd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 150 additions and 14 deletions

View File

@ -53,7 +53,8 @@ async def create_dynamic_endpoints(app: FastAPI, api_dependency=None):
f"{endpoint_name}_form_model",
inputSchema.get("properties", {}),
inputSchema.get("required", []),
inputSchema.get("$defs", {}),
schema_defs=inputSchema.get("$defs", {}),
root_schema=inputSchema,
)
response_model_fields = None
@ -62,7 +63,8 @@ async def create_dynamic_endpoints(app: FastAPI, api_dependency=None):
f"{endpoint_name}_response_model",
outputSchema.get("properties", {}),
outputSchema.get("required", []),
outputSchema.get("$defs", {}),
schema_defs=outputSchema.get("$defs", {}),
root_schema=outputSchema,
)
tool_handler = get_tool_handler(

View File

@ -1,11 +1,11 @@
import pytest
from pydantic import BaseModel, Field
from typing import Any, List, Dict, Union
from typing import Any, List, Dict, Type, Union
from mcpo.utils.main import _process_schema_property
_model_cache = {}
_model_cache: Dict[str, Type] = {}
@pytest.fixture(autouse=True)
@ -310,3 +310,96 @@ def test_multi_type_property_with_any_of():
# assert result_field parameter config
assert result_field.description == "A property with multiple types"
def test_process_property_reference():
schema = {
"type": "object",
"properties": {
"start_time": {
"type": "string",
"format": "date-time",
"description": "Start time in ISO 8601 format",
},
"end_time": {
"$ref": "#/properties/start_time",
"description": "End time in ISO 8601 format",
},
},
"required": ["start_time"],
}
# First process the start_time property to ensure reference target exists
result_type, result_field = _process_schema_property(
_model_cache,
schema,
"test",
"prop",
True,
schema_defs=None,
root_schema=schema,
)
assert issubclass(result_type, BaseModel)
model_fields = result_type.model_fields
# Check that both fields have the same type (string)
assert model_fields["start_time"].annotation is str
assert model_fields["end_time"].annotation is str
# Check descriptions are preserved
assert model_fields["start_time"].description == "Start time in ISO 8601 format"
assert model_fields["end_time"].description == "End time in ISO 8601 format"
def test_process_invalid_property_reference():
schema = {
"type": "object",
"properties": {"invalid_ref": {"$ref": "#/properties/nonexistent"}},
}
with pytest.raises(
ValueError, match="Reference not found: #/properties/nonexistent"
):
_process_schema_property(
_model_cache,
schema,
"test",
"prop",
True,
schema_defs=None,
root_schema=schema,
)
def test_process_nested_property_reference():
schema = {
"type": "object",
"properties": {
"user": {
"type": "object",
"properties": {
"created_at": {"type": "string", "format": "date-time"},
"updated_at": {"$ref": "#/properties/user/properties/created_at"},
},
}
},
}
result_type, _ = _process_schema_property(
_model_cache,
schema,
"test",
"prop",
True,
schema_defs=None,
root_schema=schema,
)
assert issubclass(result_type, BaseModel)
user_field = result_type.model_fields["user"]
user_model = user_field.annotation
# Both timestamps should be strings
assert user_model.model_fields["created_at"].annotation is str
assert user_model.model_fields["updated_at"].annotation is str

View File

@ -81,8 +81,9 @@ def _process_schema_property(
model_name_prefix: str,
prop_name: str,
is_required: bool,
schema_defs: Optional[Dict] = None,
) -> tuple[Union[Type, List, ForwardRef, Any], FieldInfo]:
schema_defs: Optional[Dict[str, Any]] = None,
root_schema: Optional[Dict[str, Any]] = None,
) -> tuple[Union[Type, List[Any], ForwardRef, Any], FieldInfo]:
"""
Recursively processes a schema property to determine its Python type hint
and Pydantic Field definition.
@ -91,11 +92,34 @@ def _process_schema_property(
A tuple containing (python_type_hint, pydantic_field).
The pydantic_field contains default value and description.
"""
original_schema = prop_schema.copy()
if "$ref" in prop_schema:
ref = prop_schema["$ref"]
ref = ref.split("/")[-1]
assert ref in schema_defs, "Custom field not found"
prop_schema = schema_defs[ref]
ref_parts = ref.split("/")[1:] # Skip the '#' at the start
# Start from the root schema
current: Optional[Dict[str, Any]] = None
if ref_parts[0] in ["definitions", "$defs"] and schema_defs is not None:
current = schema_defs
elif ref_parts[0] == "properties" and root_schema is not None:
current = root_schema.get("properties", {})
if current is None:
raise ValueError(f"Cannot resolve reference: {ref}")
# Navigate through the reference path
for part in ref_parts[1:]: # Skip the first part since we already used it
if not isinstance(current, dict):
raise ValueError(f"Invalid reference path: {ref}")
current = current.get(part)
if current is None:
raise ValueError(f"Reference not found: {ref}")
# Merge referenced schema while preserving local overrides
prop_schema = {
**current,
**{k: v for k, v in original_schema.items() if k != "$ref"},
}
prop_type = prop_schema.get("type")
prop_desc = prop_schema.get("description", "")
@ -114,6 +138,8 @@ def _process_schema_property(
f"{model_name_prefix}_{prop_name}",
f"choice_{i}",
False,
schema_defs=schema_defs,
root_schema=root_schema,
)
type_hints.append(type_hint)
return Union[tuple(type_hints)], pydantic_field
@ -127,7 +153,13 @@ def _process_schema_property(
temp_schema = dict(prop_schema)
temp_schema["type"] = type_option
type_hint, _ = _process_schema_property(
_model_cache, temp_schema, model_name_prefix, prop_name, False
_model_cache,
temp_schema,
model_name_prefix,
prop_name,
False,
schema_defs=schema_defs,
root_schema=root_schema,
)
type_hints.append(type_hint)
@ -154,7 +186,8 @@ def _process_schema_property(
nested_model_name,
name,
is_nested_required,
schema_defs,
schema_defs=schema_defs,
root_schema=root_schema,
)
if name_needs_alias(name):
@ -190,7 +223,8 @@ def _process_schema_property(
f"{model_name_prefix}_{prop_name}",
"item",
False, # Items aren't required at this level,
schema_defs,
schema_defs=schema_defs,
root_schema=root_schema,
)
list_type_hint = List[item_type_hint]
return list_type_hint, pydantic_field
@ -209,7 +243,13 @@ def _process_schema_property(
return Any, pydantic_field
def get_model_fields(form_model_name, properties, required_fields, schema_defs=None):
def get_model_fields(
form_model_name: str,
properties: Dict[str, Any],
required_fields: List[str],
schema_defs: Optional[Dict[str, Any]] = None,
root_schema: Optional[Dict[str, Any]] = None,
) -> Dict[str, tuple[Union[Type, List[Any], ForwardRef, Any], FieldInfo]]:
model_fields = {}
_model_cache: Dict[str, Type] = {}
@ -222,7 +262,8 @@ def get_model_fields(form_model_name, properties, required_fields, schema_defs=N
form_model_name,
param_name,
is_required,
schema_defs,
schema_defs=schema_defs,
root_schema=root_schema,
)
# Handle parameter names with leading underscores (e.g., __top, __filter) which Pydantic v2 does not allow