"""Expose the tools of an MCP stdio server as FastAPI (OpenAPI) POST endpoints."""

import argparse
import asyncio
import json
import os
import sys
from typing import Any, Dict, Optional

import uvicorn
from fastapi import Body, FastAPI
from fastapi.middleware.cors import CORSMiddleware
from mcp import ClientSession, StdioServerParameters, types
from mcp.client.stdio import stdio_client
from pydantic import create_model

# JSON Schema "type" name -> Python type used to validate the request body.
# Expand as needed. PRs welcome!
_JSON_TYPE_TO_PYTHON = {
    "string": str,
    "integer": int,
    "boolean": bool,
    "number": float,
    "object": Dict[str, Any],
    "array": list,
}


def _python_type_for(param_schema: Dict[str, Any]) -> Any:
    """Return the Python type used to validate one tool parameter.

    A schema without a single string ``"type"`` (for example one that uses
    ``anyOf`` or a list of types) maps to ``Any`` so the value is passed
    through unvalidated instead of raising while the endpoints are built.
    An unrecognised type name falls back to ``str``.
    """
    json_type = param_schema.get("type")
    if not isinstance(json_type, str):
        return Any
    return _JSON_TYPE_TO_PYTHON.get(json_type, str)


def _make_endpoint_func(
    session: ClientSession,
    endpoint_name: str,
    form_model: Any,
    required_fields: frozenset,
):
    """Build the request handler that forwards a validated body to one tool.

    A factory is used so that every handler closes over its own tool name and
    model rather than over the loop variables of the caller.
    """

    # The annotation must be the model class itself (evaluated at definition
    # time) because FastAPI derives the request body from it.
    async def endpoint_func(form_data: form_model):
        # Leave out optional parameters the client did not send (or sent as
        # null) so the tool does not receive nulls for arguments that were
        # never given. Required parameters are always forwarded.
        provided = form_data.model_dump(exclude_unset=True)
        args = {
            name: value
            for name, value in provided.items()
            if value is not None or name in required_fields
        }
        print("Calling tool with arguments:", args)
        print("Tool name:", endpoint_name)
        result = await session.call_tool(endpoint_name, arguments=args)
        return result

    return endpoint_func


async def create_dynamic_endpoints(app: FastAPI, session: ClientSession):
    """Register one POST endpoint on *app* for every tool listed by *session*.

    For each tool a Pydantic model is generated from the tool's input schema
    (for request validation and OpenAPI output) and a ``POST /<tool name>``
    route is added whose handler calls the tool through *session*.

    Args:
        app: Application the routes are added to.
        session: MCP client session used to list and call the tools.
    """
    tools_result = await session.list_tools()

    for tool in tools_result.tools:
        print(tool)
        endpoint_name = tool.name
        schema = tool.inputSchema or {}

        # Tools without parameters may omit "properties" and "required".
        properties = schema.get("properties") or {}
        required_fields = frozenset(schema.get("required") or [])

        # Dynamically creating a Pydantic model for validation and openAPI coverage
        model_fields = {}
        for param_name, param_schema in properties.items():
            python_type = _python_type_for(param_schema)
            param_desc = param_schema.get("description", "")

            if param_name in required_fields:
                field_type, default_value = python_type, ...
            else:
                # Optional parameters default to None, so None must also be
                # an accepted value for them.
                field_type, default_value = Optional[python_type], None

            model_fields[param_name] = (
                field_type,
                Body(default_value, description=param_desc),
            )

        form_model = create_model(f"{endpoint_name}_form_model", **model_fields)
        endpoint_func = _make_endpoint_func(
            session, endpoint_name, form_model, required_fields
        )

        # Add endpoint to FastAPI with tool descriptions
        app.post(
            f"/{endpoint_name}",
            summary=endpoint_name.replace("_", " ").title(),
            description=tool.description,
        )(endpoint_func)


async def run(host: str, port: int, server_cmd: list[str]):
    """Start the MCP server process and serve its tools over HTTP.

    The command is launched through the stdio client with a copy of the
    current environment, the session is initialized, the endpoints are
    generated, and uvicorn then serves the application until it stops.

    Args:
        host: Interface uvicorn binds to.
        port: Port uvicorn listens on.
        server_cmd: MCP server command; the first item is the executable and
            the remaining items are its arguments.
    """
    server_params = StdioServerParameters(
        command=server_cmd[0],
        args=server_cmd[1:],
        env={**os.environ},
    )

    # Open connection to MCP first:
    async with stdio_client(server_params) as (read, write):
        async with ClientSession(read, write) as session:
            result = await session.initialize()

            # Name and version are optional; fall back when they are absent.
            server_info = getattr(result, "serverInfo", None)
            server_name = getattr(server_info, "name", None)
            server_version = getattr(server_info, "version", "1.0")

            if server_name:
                server_description = f"{server_name.capitalize()} MCP OpenAPI Proxy"
            else:
                server_description = (
                    "Automatically generated API endpoints based on MCP tool schemas."
                )

            app = FastAPI(
                title=server_name if server_name else "MCP OpenAPI Proxy",
                description=server_description,
                version=server_version,
            )

            # NOTE(review): any origin is allowed together with credentials;
            # consider restricting the origins for non-local deployments.
            origins = ["*"]
            app.add_middleware(
                CORSMiddleware,
                allow_origins=origins,
                allow_credentials=True,
                allow_methods=["*"],
                allow_headers=["*"],
            )

            # Dynamic endpoint creation
            await create_dynamic_endpoints(app, session)

            config = uvicorn.Config(app=app, host=host, port=port, log_level="info")
            server = uvicorn.Server(config)
            await server.serve()


def parse_args():
    """Split ``sys.argv`` at ``--`` into proxy options and the MCP command.

    Everything before ``--`` is parsed as proxy options (``--host``,
    ``--port``); everything after it is the MCP server command line.

    Returns:
        A ``(host, port, mcp_args)`` tuple.

    Exits with status 1 when ``--`` is missing or no command follows it.
    """
    # Separate user args before and after "--"
    if "--" not in sys.argv:
        print("Usage: python main.py --host 0.0.0.0 --port 8000 -- your_mcp_command")
        sys.exit(1)

    split_index = sys.argv.index("--")
    proxy_args = sys.argv[1:split_index]
    mcp_args = sys.argv[split_index + 1 :]

    parser = argparse.ArgumentParser(description="FastAPI MCP OpenAPI Proxy")
    parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to listen on")
    parser.add_argument("--port", type=int, default=8000, help="Port to listen on")
    args = parser.parse_args(proxy_args)

    if not mcp_args:
        print("Error: You must specify the MCP server command after '--'")
        sys.exit(1)

    return args.host, args.port, mcp_args


if __name__ == "__main__":
    host, port, server_cmd = parse_args()
    asyncio.run(run(host, port, server_cmd))