This commit is contained in:
Timothy Jaeryang Baek 2025-04-10 09:38:08 -07:00
parent 6e05f1c56f
commit 3941420630
2 changed files with 172 additions and 137 deletions

View File

@ -6,111 +6,19 @@ from typing import Dict, Any, Optional, List, Type, Union, ForwardRef
import uvicorn
from fastapi import FastAPI, Depends
from fastapi.middleware.cors import CORSMiddleware
from mcp import ClientSession, StdioServerParameters, types
from mcp.client.stdio import stdio_client
from mcp.types import CallToolResult
from mcpo.utils.auth import get_verify_api_key
from pydantic import create_model, Field
from pydantic.fields import FieldInfo
from starlette.routing import Mount
_model_cache: Dict[str, Type] = {}
def _process_schema_property(
prop_schema: Dict[str, Any],
model_name_prefix: str,
prop_name: str,
is_required: bool,
) -> tuple[Union[Type, List, ForwardRef, Any], FieldInfo]:
"""
Recursively processes a schema property to determine its Python type hint
and Pydantic Field definition.
Returns:
A tuple containing (python_type_hint, pydantic_field).
The pydantic_field contains default value and description.
"""
prop_type = prop_schema.get("type")
prop_desc = prop_schema.get("description", "")
default_value = ... if is_required else prop_schema.get("default", None)
pydantic_field = Field(default=default_value, description=prop_desc)
if prop_type == "object":
nested_properties = prop_schema.get("properties", {})
nested_required = prop_schema.get("required", [])
nested_fields = {}
nested_model_name = f"{model_name_prefix}_{prop_name}_model".replace("__", "_").rstrip('_')
if nested_model_name in _model_cache:
return _model_cache[nested_model_name], pydantic_field
for name, schema in nested_properties.items():
is_nested_required = name in nested_required
nested_type_hint, nested_pydantic_field = _process_schema_property(
schema, nested_model_name, name, is_nested_required
)
nested_fields[name] = (nested_type_hint, nested_pydantic_field)
if not nested_fields:
return Dict[str, Any], pydantic_field
NestedModel = create_model(nested_model_name, **nested_fields)
_model_cache[nested_model_name] = NestedModel
return NestedModel, pydantic_field
elif prop_type == "array":
items_schema = prop_schema.get("items")
if not items_schema:
# Default to list of anything if items schema is missing
return List[Any], pydantic_field
# Recursively determine the type of items in the array
item_type_hint, _ = _process_schema_property(
items_schema, f"{model_name_prefix}_{prop_name}", "item", False # Items aren't required at this level
)
list_type_hint = List[item_type_hint]
return list_type_hint, pydantic_field
elif prop_type == "string":
return str, pydantic_field
elif prop_type == "integer":
return int, pydantic_field
elif prop_type == "boolean":
return bool, pydantic_field
elif prop_type == "number":
return float, pydantic_field
else:
return Any, pydantic_field
from mcp import ClientSession, StdioServerParameters, types
from mcp.client.stdio import stdio_client
def process_tool_response(result: CallToolResult) -> list:
"""Universal response processor for all tool endpoints"""
response = []
for content in result.content:
if isinstance(content, types.TextContent):
text = content.text
if isinstance(text, str):
try:
text = json.loads(text)
except json.JSONDecodeError:
pass
response.append(text)
elif isinstance(content, types.ImageContent):
image_data = f"data:{content.mimeType};base64,{content.data}"
response.append(image_data)
elif isinstance(content, types.EmbeddedResource):
# TODO: Handle embedded resources
response.append("Embedded resource not supported yet.")
return response
from mcpo.utils.main import get_model_fields, get_tool_handler
from mcpo.utils.auth import get_verify_api_key
async def create_dynamic_endpoints(app: FastAPI, api_dependency=None):
session = app.state.session
session: ClientSession = app.state.session
if not session:
raise ValueError("Session is not initialized in the app state.")
@ -131,51 +39,20 @@ async def create_dynamic_endpoints(app: FastAPI, api_dependency=None):
endpoint_description = tool.description
schema = tool.inputSchema
model_fields = {}
required_fields = schema.get("required", [])
properties = schema.get("properties", {})
form_model_name = f"{endpoint_name}_form_model"
for param_name, param_schema in properties.items():
is_required = param_name in required_fields
python_type_hint, pydantic_field_info = _process_schema_property(
param_schema, form_model_name, param_name, is_required
)
# Use the generated type hint and Field info
model_fields[param_name] = (python_type_hint, pydantic_field_info)
if model_fields:
FormModel = create_model(form_model_name, **model_fields)
model_fields = get_model_fields(
form_model_name,
properties,
required_fields,
)
def make_endpoint_func(
endpoint_name: str, FormModel, session: ClientSession
): # Parameterized endpoint
async def tool(form_data: FormModel):
args = form_data.model_dump(exclude_none=True)
print(f"Calling endpoint: {endpoint_name}, with args: {args}")
result = await session.call_tool(endpoint_name, arguments=args)
return process_tool_response(result)
return tool
tool_handler = make_endpoint_func(endpoint_name, FormModel, session)
else:
def make_endpoint_func_no_args(
endpoint_name: str, session: ClientSession
): # Parameterless endpoint
async def tool(): # No parameters
print(f"Calling endpoint: {endpoint_name}, with no args")
result = await session.call_tool(
endpoint_name, arguments={}
) # Empty dict
return process_tool_response(result) # Same processor
return tool
tool_handler = make_endpoint_func_no_args(endpoint_name, session)
tool_handler = get_tool_handler(
session, endpoint_name, form_model_name, model_fields
)
app.post(
f"/{endpoint_name}",
@ -229,11 +106,13 @@ async def run(
# MCP Config
config_path = kwargs.get("config")
server_command = kwargs.get("server_command")
name = kwargs.get("name") or "MCP OpenAPI Proxy"
description = (
kwargs.get("description") or "Automatically generated API from MCP Tool Schemas"
)
version = kwargs.get("version") or "1.0"
ssl_certfile = kwargs.get("ssl_certfile")
ssl_keyfile = kwargs.get("ssl_keyfile")
path_prefix = kwargs.get("path_prefix") or "/"
@ -256,7 +135,6 @@ async def run(
)
if server_command:
main_app.state.command = server_command[0]
main_app.state.args = server_command[1:]
main_app.state.env = os.environ.copy()
@ -265,9 +143,11 @@ async def run(
elif config_path:
with open(config_path, "r") as f:
config_data = json.load(f)
mcp_servers = config_data.get("mcpServers", {})
if not mcp_servers:
raise ValueError("No 'mcpServers' found in config file.")
main_app.description += "\n\n- **available tools**"
for server_name, server_cfg in mcp_servers.items():
sub_app = FastAPI(
@ -290,6 +170,7 @@ async def run(
sub_app.state.env = {**os.environ, **server_cfg.get("env", {})}
sub_app.state.api_dependency = api_dependency
main_app.mount(f"{path_prefix}{server_name}", sub_app)
main_app.description += f"\n - [{server_name}](/{server_name}/docs)"
else:

154
src/mcpo/utils/main.py Normal file
View File

@ -0,0 +1,154 @@
from typing import Any, Dict, List, Type, Union, ForwardRef
from pydantic import create_model, Field
from pydantic.fields import FieldInfo
from mcp import ClientSession, types
from mcp.types import CallToolResult
import json
def process_tool_response(result: CallToolResult) -> list:
"""Universal response processor for all tool endpoints"""
response = []
for content in result.content:
if isinstance(content, types.TextContent):
text = content.text
if isinstance(text, str):
try:
text = json.loads(text)
except json.JSONDecodeError:
pass
response.append(text)
elif isinstance(content, types.ImageContent):
image_data = f"data:{content.mimeType};base64,{content.data}"
response.append(image_data)
elif isinstance(content, types.EmbeddedResource):
# TODO: Handle embedded resources
response.append("Embedded resource not supported yet.")
return response
def _process_schema_property(
_model_cache: Dict[str, Type],
prop_schema: Dict[str, Any],
model_name_prefix: str,
prop_name: str,
is_required: bool,
) -> tuple[Union[Type, List, ForwardRef, Any], FieldInfo]:
"""
Recursively processes a schema property to determine its Python type hint
and Pydantic Field definition.
Returns:
A tuple containing (python_type_hint, pydantic_field).
The pydantic_field contains default value and description.
"""
prop_type = prop_schema.get("type")
prop_desc = prop_schema.get("description", "")
default_value = ... if is_required else prop_schema.get("default", None)
pydantic_field = Field(default=default_value, description=prop_desc)
if prop_type == "object":
nested_properties = prop_schema.get("properties", {})
nested_required = prop_schema.get("required", [])
nested_fields = {}
nested_model_name = f"{model_name_prefix}_{prop_name}_model".replace(
"__", "_"
).rstrip("_")
if nested_model_name in _model_cache:
return _model_cache[nested_model_name], pydantic_field
for name, schema in nested_properties.items():
is_nested_required = name in nested_required
nested_type_hint, nested_pydantic_field = _process_schema_property(
schema, nested_model_name, name, is_nested_required
)
nested_fields[name] = (nested_type_hint, nested_pydantic_field)
if not nested_fields:
return Dict[str, Any], pydantic_field
NestedModel = create_model(nested_model_name, **nested_fields)
_model_cache[nested_model_name] = NestedModel
return NestedModel, pydantic_field
elif prop_type == "array":
items_schema = prop_schema.get("items")
if not items_schema:
# Default to list of anything if items schema is missing
return List[Any], pydantic_field
# Recursively determine the type of items in the array
item_type_hint, _ = _process_schema_property(
items_schema,
f"{model_name_prefix}_{prop_name}",
"item",
False, # Items aren't required at this level
)
list_type_hint = List[item_type_hint]
return list_type_hint, pydantic_field
elif prop_type == "string":
return str, pydantic_field
elif prop_type == "integer":
return int, pydantic_field
elif prop_type == "boolean":
return bool, pydantic_field
elif prop_type == "number":
return float, pydantic_field
else:
return Any, pydantic_field
def get_model_fields(form_model_name, properties, required_fields):
model_fields = {}
_model_cache: Dict[str, Type] = {}
for param_name, param_schema in properties.items():
is_required = param_name in required_fields
python_type_hint, pydantic_field_info = _process_schema_property(
_model_cache, param_schema, form_model_name, param_name, is_required
)
# Use the generated type hint and Field info
model_fields[param_name] = (python_type_hint, pydantic_field_info)
return model_fields
def get_tool_handler(session, endpoint_name, form_model_name, model_fields):
if model_fields:
FormModel = create_model(form_model_name, **model_fields)
def make_endpoint_func(
endpoint_name: str, FormModel, session: ClientSession
): # Parameterized endpoint
async def tool(form_data: FormModel):
args = form_data.model_dump(exclude_none=True)
print(f"Calling endpoint: {endpoint_name}, with args: {args}")
result = await session.call_tool(endpoint_name, arguments=args)
return process_tool_response(result)
return tool
tool_handler = make_endpoint_func(endpoint_name, FormModel, session)
else:
def make_endpoint_func_no_args(
endpoint_name: str, session: ClientSession
): # Parameterless endpoint
async def tool(): # No parameters
print(f"Calling endpoint: {endpoint_name}, with no args")
result = await session.call_tool(
endpoint_name, arguments={}
) # Empty dict
return process_tool_response(result) # Same processor
return tool
tool_handler = make_endpoint_func_no_args(endpoint_name, session)