Support Custom basemode and output schema

This commit is contained in:
JinY0ung-Shin 2025-04-28 14:29:17 +09:00 committed by GitHub
parent e392df0763
commit 464d886673
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,12 +1,19 @@
from typing import Any, Dict, List, Type, Union, ForwardRef
from pydantic import create_model, Field
from pydantic.fields import FieldInfo
from mcp import ClientSession, types
from fastapi import HTTPException
from mcp.types import CallToolResult, PARSE_ERROR, INVALID_REQUEST, METHOD_NOT_FOUND, INVALID_PARAMS, INTERNAL_ERROR
from mcp.shared.exceptions import McpError
import json
from typing import Any, Dict, ForwardRef, List, Optional, Type, Union
from fastapi import HTTPException
from mcp import ClientSession, types
from mcp.shared.exceptions import McpError
from mcp.types import (
INTERNAL_ERROR,
INVALID_PARAMS,
INVALID_REQUEST,
METHOD_NOT_FOUND,
PARSE_ERROR,
CallToolResult,
)
from pydantic import Field, create_model
from pydantic.fields import FieldInfo
MCP_ERROR_TO_HTTP_STATUS = {
PARSE_ERROR: 400,
@ -44,6 +51,7 @@ def _process_schema_property(
model_name_prefix: str,
prop_name: str,
is_required: bool,
custom_fields: Optional[Dict] = None,
) -> tuple[Union[Type, List, ForwardRef, Any], FieldInfo]:
"""
Recursively processes a schema property to determine its Python type hint
@ -53,6 +61,12 @@ def _process_schema_property(
A tuple containing (python_type_hint, pydantic_field).
The pydantic_field contains default value and description.
"""
if "$ref" in prop_schema:
ref = prop_schema["$ref"]
ref = ref.split("/")[-1]
assert ref in custom_fields, "Custom field not found"
prop_schema = custom_fields[ref]
prop_type = prop_schema.get("type")
prop_desc = prop_schema.get("description", "")
default_value = ... if is_required else prop_schema.get("default", None)
@ -73,7 +87,12 @@ def _process_schema_property(
for name, schema in nested_properties.items():
is_nested_required = name in nested_required
nested_type_hint, nested_pydantic_field = _process_schema_property(
_model_cache, schema, nested_model_name, name, is_nested_required
_model_cache,
schema,
nested_model_name,
name,
is_nested_required,
custom_fields,
)
nested_fields[name] = (nested_type_hint, nested_pydantic_field)
@ -98,7 +117,8 @@ def _process_schema_property(
items_schema,
f"{model_name_prefix}_{prop_name}",
"item",
False, # Items aren't required at this level
False, # Items aren't required at this level,
custom_fields,
)
list_type_hint = List[item_type_hint]
return list_type_hint, pydantic_field
@ -115,7 +135,7 @@ def _process_schema_property(
return Any, pydantic_field
def get_model_fields(form_model_name, properties, required_fields):
def get_model_fields(form_model_name, properties, required_fields, custom_fields=None):
model_fields = {}
_model_cache: Dict[str, Type] = {}
@ -123,21 +143,33 @@ def get_model_fields(form_model_name, properties, required_fields):
for param_name, param_schema in properties.items():
is_required = param_name in required_fields
python_type_hint, pydantic_field_info = _process_schema_property(
_model_cache, param_schema, form_model_name, param_name, is_required
_model_cache,
param_schema,
form_model_name,
param_name,
is_required,
custom_fields,
)
# Use the generated type hint and Field info
model_fields[param_name] = (python_type_hint, pydantic_field_info)
return model_fields
def get_tool_handler(session, endpoint_name, form_model_name, model_fields):
def get_tool_handler(
session, endpoint_name, form_model_name, model_fields, output_model_fileds=None
):
if model_fields:
FormModel = create_model(form_model_name, **model_fields)
OutputModel = (
create_model(f"{endpoint_name}_output_model", **output_model_fileds)
if output_model_fileds
else Any
)
def make_endpoint_func(
endpoint_name: str, FormModel, session: ClientSession
): # Parameterized endpoint
async def tool(form_data: FormModel):
async def tool(form_data: FormModel) -> OutputModel:
args = form_data.model_dump(exclude_none=True)
print(f"Calling endpoint: {endpoint_name}, with args: {args}")
try:
@ -158,7 +190,9 @@ def get_tool_handler(session, endpoint_name, form_model_name, model_fields):
)
response_data = process_tool_response(result)
final_response = response_data[0] if len(response_data) == 1 else response_data
final_response = (
response_data[0] if len(response_data) == 1 else response_data
)
return final_response
except McpError as e:
@ -206,7 +240,9 @@ def get_tool_handler(session, endpoint_name, form_model_name, model_fields):
)
response_data = process_tool_response(result)
final_response = response_data[0] if len(response_data) == 1 else response_data
final_response = (
response_data[0] if len(response_data) == 1 else response_data
)
return final_response
except McpError as e: