From 334089b4065d48a6c6998104ededf777637e54aa Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Sun, 6 Apr 2025 15:05:00 -0700 Subject: [PATCH] refac --- src/mcpo/main.py | 57 ++++++++++++++++++++++++++++++++++-------------- 1 file changed, 41 insertions(+), 16 deletions(-) diff --git a/src/mcpo/main.py b/src/mcpo/main.py index a9812f7..9d7131b 100644 --- a/src/mcpo/main.py +++ b/src/mcpo/main.py @@ -32,6 +32,7 @@ def get_python_type(param_type: str): return str # Fallback # Expand as needed. PRs welcome! + def process_tool_response(result: CallToolResult) -> list: """Universal response processor for all tool endpoints""" response = [] @@ -52,6 +53,7 @@ def process_tool_response(result: CallToolResult) -> list: response.append("Embedded resource not supported yet.") return response + async def create_dynamic_endpoints(app: FastAPI, api_dependency=None): session = app.state.session if not session: @@ -91,20 +93,29 @@ async def create_dynamic_endpoints(app: FastAPI, api_dependency=None): if model_fields: FormModel = create_model(f"{endpoint_name}_form_model", **model_fields) - def make_endpoint_func(endpoint_name: str, FormModel, session: ClientSession): # Parameterized endpoint - async def tool_endpoint(form_data: FormModel): + def make_endpoint_func( + endpoint_name: str, FormModel, session: ClientSession + ): # Parameterized endpoint + async def tool(form_data: FormModel): args = form_data.model_dump(exclude_none=True) result = await session.call_tool(endpoint_name, arguments=args) return process_tool_response(result) - return tool_endpoint + + return tool tool_handler = make_endpoint_func(endpoint_name, FormModel, session) else: - def make_endpoint_func_no_args(endpoint_name: str, session: ClientSession): # Parameterless endpoint - async def tool_endpoint(): # No parameters - result = await session.call_tool(endpoint_name, arguments={}) # Empty dict + + def make_endpoint_func_no_args( + endpoint_name: str, session: ClientSession + ): # Parameterless endpoint + async def tool(): # No parameters + result = await session.call_tool( + endpoint_name, arguments={} + ) # Empty dict return process_tool_response(result) # Same processor - return tool_endpoint + + return tool tool_handler = make_endpoint_func_no_args(endpoint_name, session) @@ -148,11 +159,11 @@ async def lifespan(app: FastAPI): async def run( - host: str = "127.0.0.1", - port: int = 8000, - api_key: Optional[str] = "", - cors_allow_origins=["*"], - **kwargs, + host: str = "127.0.0.1", + port: int = 8000, + api_key: Optional[str] = "", + cors_allow_origins=["*"], + **kwargs, ): # Server API Key api_dependency = get_verify_api_key(api_key) if api_key else None @@ -162,7 +173,7 @@ async def run( server_command = kwargs.get("server_command") name = kwargs.get("name") or "MCP OpenAPI Proxy" description = ( - kwargs.get("description") or "Automatically generated API from MCP Tool Schemas" + kwargs.get("description") or "Automatically generated API from MCP Tool Schemas" ) version = kwargs.get("version") or "1.0" ssl_certfile = kwargs.get("ssl_certfile") @@ -170,7 +181,12 @@ async def run( path_prefix = kwargs.get("path_prefix") or "/" main_app = FastAPI( - title=name, description=description, version=version, ssl_certfile=ssl_certfile, ssl_keyfile=ssl_keyfile, lifespan=lifespan + title=name, + description=description, + version=version, + ssl_certfile=ssl_certfile, + ssl_keyfile=ssl_keyfile, + lifespan=lifespan, ) main_app.add_middleware( @@ -217,11 +233,20 @@ async def run( sub_app.state.api_dependency = api_dependency main_app.mount(f"{path_prefix}{server_name}", sub_app) - main_app.description += f"\n - [{server_name}](http://{host}:{port}/{server_name}/docs)" + main_app.description += ( + f"\n - [{server_name}](http://{host}:{port}/{server_name}/docs)" + ) else: raise ValueError("You must provide either server_command or config.") - config = uvicorn.Config(app=main_app, host=host, port=port, ssl_certfile=ssl_certfile, ssl_keyfile=ssl_keyfile, log_level="info") + config = uvicorn.Config( + app=main_app, + host=host, + port=port, + ssl_certfile=ssl_certfile, + ssl_keyfile=ssl_keyfile, + log_level="info", + ) server = uvicorn.Server(config) await server.serve()