This commit is contained in:
Timothy Jaeryang Baek 2025-04-06 15:05:00 -07:00
parent 615c333e7f
commit 334089b406

View File

@ -32,6 +32,7 @@ def get_python_type(param_type: str):
return str # Fallback return str # Fallback
# Expand as needed. PRs welcome! # Expand as needed. PRs welcome!
def process_tool_response(result: CallToolResult) -> list: def process_tool_response(result: CallToolResult) -> list:
"""Universal response processor for all tool endpoints""" """Universal response processor for all tool endpoints"""
response = [] response = []
@ -52,6 +53,7 @@ def process_tool_response(result: CallToolResult) -> list:
response.append("Embedded resource not supported yet.") response.append("Embedded resource not supported yet.")
return response return response
async def create_dynamic_endpoints(app: FastAPI, api_dependency=None): async def create_dynamic_endpoints(app: FastAPI, api_dependency=None):
session = app.state.session session = app.state.session
if not session: if not session:
@ -91,20 +93,29 @@ async def create_dynamic_endpoints(app: FastAPI, api_dependency=None):
if model_fields: if model_fields:
FormModel = create_model(f"{endpoint_name}_form_model", **model_fields) FormModel = create_model(f"{endpoint_name}_form_model", **model_fields)
def make_endpoint_func(endpoint_name: str, FormModel, session: ClientSession): # Parameterized endpoint def make_endpoint_func(
async def tool_endpoint(form_data: FormModel): endpoint_name: str, FormModel, session: ClientSession
): # Parameterized endpoint
async def tool(form_data: FormModel):
args = form_data.model_dump(exclude_none=True) args = form_data.model_dump(exclude_none=True)
result = await session.call_tool(endpoint_name, arguments=args) result = await session.call_tool(endpoint_name, arguments=args)
return process_tool_response(result) return process_tool_response(result)
return tool_endpoint
return tool
tool_handler = make_endpoint_func(endpoint_name, FormModel, session) tool_handler = make_endpoint_func(endpoint_name, FormModel, session)
else: else:
def make_endpoint_func_no_args(endpoint_name: str, session: ClientSession): # Parameterless endpoint
async def tool_endpoint(): # No parameters def make_endpoint_func_no_args(
result = await session.call_tool(endpoint_name, arguments={}) # Empty dict endpoint_name: str, session: ClientSession
): # Parameterless endpoint
async def tool(): # No parameters
result = await session.call_tool(
endpoint_name, arguments={}
) # Empty dict
return process_tool_response(result) # Same processor return process_tool_response(result) # Same processor
return tool_endpoint
return tool
tool_handler = make_endpoint_func_no_args(endpoint_name, session) tool_handler = make_endpoint_func_no_args(endpoint_name, session)
@ -148,11 +159,11 @@ async def lifespan(app: FastAPI):
async def run( async def run(
host: str = "127.0.0.1", host: str = "127.0.0.1",
port: int = 8000, port: int = 8000,
api_key: Optional[str] = "", api_key: Optional[str] = "",
cors_allow_origins=["*"], cors_allow_origins=["*"],
**kwargs, **kwargs,
): ):
# Server API Key # Server API Key
api_dependency = get_verify_api_key(api_key) if api_key else None api_dependency = get_verify_api_key(api_key) if api_key else None
@ -162,7 +173,7 @@ async def run(
server_command = kwargs.get("server_command") server_command = kwargs.get("server_command")
name = kwargs.get("name") or "MCP OpenAPI Proxy" name = kwargs.get("name") or "MCP OpenAPI Proxy"
description = ( description = (
kwargs.get("description") or "Automatically generated API from MCP Tool Schemas" kwargs.get("description") or "Automatically generated API from MCP Tool Schemas"
) )
version = kwargs.get("version") or "1.0" version = kwargs.get("version") or "1.0"
ssl_certfile = kwargs.get("ssl_certfile") ssl_certfile = kwargs.get("ssl_certfile")
@ -170,7 +181,12 @@ async def run(
path_prefix = kwargs.get("path_prefix") or "/" path_prefix = kwargs.get("path_prefix") or "/"
main_app = FastAPI( main_app = FastAPI(
title=name, description=description, version=version, ssl_certfile=ssl_certfile, ssl_keyfile=ssl_keyfile, lifespan=lifespan title=name,
description=description,
version=version,
ssl_certfile=ssl_certfile,
ssl_keyfile=ssl_keyfile,
lifespan=lifespan,
) )
main_app.add_middleware( main_app.add_middleware(
@ -217,11 +233,20 @@ async def run(
sub_app.state.api_dependency = api_dependency sub_app.state.api_dependency = api_dependency
main_app.mount(f"{path_prefix}{server_name}", sub_app) main_app.mount(f"{path_prefix}{server_name}", sub_app)
main_app.description += f"\n - [{server_name}](http://{host}:{port}/{server_name}/docs)" main_app.description += (
f"\n - [{server_name}](http://{host}:{port}/{server_name}/docs)"
)
else: else:
raise ValueError("You must provide either server_command or config.") raise ValueError("You must provide either server_command or config.")
config = uvicorn.Config(app=main_app, host=host, port=port, ssl_certfile=ssl_certfile, ssl_keyfile=ssl_keyfile, log_level="info") config = uvicorn.Config(
app=main_app,
host=host,
port=port,
ssl_certfile=ssl_certfile,
ssl_keyfile=ssl_keyfile,
log_level="info",
)
server = uvicorn.Server(config) server = uvicorn.Server(config)
await server.serve() await server.serve()