Merge pull request #105 from JinY0ung-Shin/main

feat: support cusotm BaseModel "$def" and Output schema
This commit is contained in:
Tim Jaeryang Baek 2025-04-30 02:48:08 -07:00 committed by GitHub
commit 1566a34c9c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 64 additions and 24 deletions

View File

@ -4,15 +4,14 @@ from contextlib import AsyncExitStack, asynccontextmanager
from typing import Optional
import uvicorn
from fastapi import FastAPI, Depends
from fastapi import Depends, FastAPI
from fastapi.middleware.cors import CORSMiddleware
from mcp import ClientSession, StdioServerParameters
from mcp.client.sse import sse_client
from mcp.client.stdio import stdio_client
from starlette.routing import Mount
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client
from mcp.client.sse import sse_client
from mcpo.utils.main import get_model_fields, get_tool_handler
from mcpo.utils.auth import get_verify_api_key, APIKeyMiddleware
@ -37,21 +36,32 @@ async def create_dynamic_endpoints(app: FastAPI, api_dependency=None):
for tool in tools:
endpoint_name = tool.name
endpoint_description = tool.description
schema = tool.inputSchema
required_fields = schema.get("required", [])
properties = schema.get("properties", {})
inputSchema = tool.inputSchema
outputSchema = getattr(tool, "outputSchema", None)
custom_fileds = inputSchema.get("$defs", {})
required_fields = inputSchema.get("required", [])
properties = inputSchema.get("properties", {})
form_model_name = f"{endpoint_name}_form_model"
model_fields = get_model_fields(
form_model_name,
properties,
required_fields,
custom_fileds,
)
if outputSchema:
output_model_name = f"{endpoint_name}_output_model"
output_model_fields = get_model_fields(
output_model_name,
outputSchema.get("properties", {}),
outputSchema.get("required", []),
outputSchema.get("$defs", {}),
)
else:
output_model_fields = None
tool_handler = get_tool_handler(
session, endpoint_name, form_model_name, model_fields
session, endpoint_name, form_model_name, model_fields, output_model_fields
)
app.post(
@ -98,7 +108,10 @@ async def lifespan(app: FastAPI):
await create_dynamic_endpoints(app, api_dependency=api_dependency)
yield
if server_type == "sse":
async with sse_client(url=args[0], sse_read_timeout=None) as (reader, writer):
async with sse_client(url=args[0], sse_read_timeout=None) as (
reader,
writer,
):
async with ClientSession(reader, writer) as session:
app.state.session = session
await create_dynamic_endpoints(app, api_dependency=api_dependency)
@ -133,7 +146,6 @@ async def run(
ssl_certfile = kwargs.get("ssl_certfile")
ssl_keyfile = kwargs.get("ssl_keyfile")
path_prefix = kwargs.get("path_prefix") or "/"
main_app = FastAPI(
title=name,
description=description,

View File

@ -1,8 +1,9 @@
from typing import Any, Dict, List, Type, Union, ForwardRef
from pydantic import create_model, Field
from pydantic.fields import FieldInfo
from mcp import ClientSession, types
import json
from typing import Any, Dict, ForwardRef, List, Optional, Type, Union
from fastapi import HTTPException
from mcp import ClientSession, types
from mcp.types import (
CallToolResult,
PARSE_ERROR,
@ -11,9 +12,11 @@ from mcp.types import (
INVALID_PARAMS,
INTERNAL_ERROR,
)
from mcp.shared.exceptions import McpError
import json
from pydantic import Field, create_model
from pydantic.fields import FieldInfo
MCP_ERROR_TO_HTTP_STATUS = {
PARSE_ERROR: 400,
@ -51,6 +54,7 @@ def _process_schema_property(
model_name_prefix: str,
prop_name: str,
is_required: bool,
custom_fields: Optional[Dict] = None,
) -> tuple[Union[Type, List, ForwardRef, Any], FieldInfo]:
"""
Recursively processes a schema property to determine its Python type hint
@ -60,6 +64,12 @@ def _process_schema_property(
A tuple containing (python_type_hint, pydantic_field).
The pydantic_field contains default value and description.
"""
if "$ref" in prop_schema:
ref = prop_schema["$ref"]
ref = ref.split("/")[-1]
assert ref in custom_fields, "Custom field not found"
prop_schema = custom_fields[ref]
prop_type = prop_schema.get("type")
prop_desc = prop_schema.get("description", "")
default_value = ... if is_required else prop_schema.get("default", None)
@ -111,7 +121,12 @@ def _process_schema_property(
for name, schema in nested_properties.items():
is_nested_required = name in nested_required
nested_type_hint, nested_pydantic_field = _process_schema_property(
_model_cache, schema, nested_model_name, name, is_nested_required
_model_cache,
schema,
nested_model_name,
name,
is_nested_required,
custom_fields,
)
nested_fields[name] = (nested_type_hint, nested_pydantic_field)
@ -136,7 +151,8 @@ def _process_schema_property(
items_schema,
f"{model_name_prefix}_{prop_name}",
"item",
False, # Items aren't required at this level
False, # Items aren't required at this level,
custom_fields,
)
list_type_hint = List[item_type_hint]
return list_type_hint, pydantic_field
@ -155,7 +171,7 @@ def _process_schema_property(
return Any, pydantic_field
def get_model_fields(form_model_name, properties, required_fields):
def get_model_fields(form_model_name, properties, required_fields, custom_fields=None):
model_fields = {}
_model_cache: Dict[str, Type] = {}
@ -163,21 +179,33 @@ def get_model_fields(form_model_name, properties, required_fields):
for param_name, param_schema in properties.items():
is_required = param_name in required_fields
python_type_hint, pydantic_field_info = _process_schema_property(
_model_cache, param_schema, form_model_name, param_name, is_required
_model_cache,
param_schema,
form_model_name,
param_name,
is_required,
custom_fields,
)
# Use the generated type hint and Field info
model_fields[param_name] = (python_type_hint, pydantic_field_info)
return model_fields
def get_tool_handler(session, endpoint_name, form_model_name, model_fields):
def get_tool_handler(
session, endpoint_name, form_model_name, model_fields, output_model_fileds=None
):
if model_fields:
FormModel = create_model(form_model_name, **model_fields)
OutputModel = (
create_model(f"{endpoint_name}_output_model", **output_model_fileds)
if output_model_fileds
else Any
)
def make_endpoint_func(
endpoint_name: str, FormModel, session: ClientSession
): # Parameterized endpoint
async def tool(form_data: FormModel):
async def tool(form_data: FormModel) -> OutputModel:
args = form_data.model_dump(exclude_none=True)
print(f"Calling endpoint: {endpoint_name}, with args: {args}")
try: