diff --git a/CHANGELOG.md b/CHANGELOG.md index 9793d7c..d2df7f3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,22 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +# Changelog + +All notable changes to this project will be documented in this file. + +The format is based on Keep a Changelog, +and this project adheres to Semantic Versioning. + +## [0.0.13] - 2025-05-01 + +### Added + +- 🧪 **Support for Mixed and Union Types (anyOf/nullables)**: mcpo now accurately exposes OpenAPI schemas with anyOf compositions and nullable fields. +- 🧷 **Authentication-Required Docs Access with --strict-auth**: When enabled, the new --strict-auth option restricts access to both the tool endpoints and their interactive documentation pages—ensuring sensitive internal services aren’t inadvertently exposed to unauthenticated users or LLMs. +- 🧬 **Custom Schema Definitions for Complex Models**: Developers can now register custom BaseModel schemas with arbitrary nesting and field variants, allowing precise OpenAPI representations of deeply structured payloads—ensuring crystal-clear docs and compatibility for multi-layered data workflows. +- 🔄 **Smarter Schema Inference Across Data Types**: Schema generation has been enhanced to gracefully handle nested unions, nulls, and fallback types, dramatically improving accuracy in tools using variable output formats or flexible data contracts. 
+ ## [0.0.12] - 2025-04-14 ### Fixed diff --git a/pyproject.toml b/pyproject.toml index b52d4c4..a8d41a8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "mcpo" -version = "0.0.12" +version = "0.0.13" description = "A simple, secure MCP-to-OpenAPI proxy server" authors = [ { name = "Timothy Jaeryang Baek", email = "tim@openwebui.com" } diff --git a/src/mcpo/__init__.py b/src/mcpo/__init__.py index 534937a..4859e20 100644 --- a/src/mcpo/__init__.py +++ b/src/mcpo/__init__.py @@ -26,6 +26,10 @@ def main( Optional[str], typer.Option("--api-key", "-k", help="API key for authentication"), ] = None, + strict_auth: Annotated[ + Optional[bool], + typer.Option("--strict-auth", help="API key protects all endpoints and documentation"), + ] = False, env: Annotated[ Optional[List[str]], typer.Option("--env", "-e", help="Environment variables") ] = None, @@ -116,6 +120,7 @@ def main( host, port, api_key=api_key, + strict_auth=strict_auth, cors_allow_origins=cors_allow_origins, server_type=server_type, config_path=config_path, diff --git a/src/mcpo/main.py b/src/mcpo/main.py index 8fba53e..7339c71 100644 --- a/src/mcpo/main.py +++ b/src/mcpo/main.py @@ -4,17 +4,16 @@ from contextlib import AsyncExitStack, asynccontextmanager from typing import Optional import uvicorn -from fastapi import FastAPI, Depends +from fastapi import Depends, FastAPI from fastapi.middleware.cors import CORSMiddleware +from mcp import ClientSession, StdioServerParameters +from mcp.client.sse import sse_client +from mcp.client.stdio import stdio_client from starlette.routing import Mount -from mcp import ClientSession, StdioServerParameters -from mcp.client.stdio import stdio_client -from mcp.client.sse import sse_client - from mcpo.utils.main import get_model_fields, get_tool_handler -from mcpo.utils.auth import get_verify_api_key +from mcpo.utils.auth import get_verify_api_key, APIKeyMiddleware async def create_dynamic_endpoints(app: FastAPI, api_dependency=None): @@ 
-37,21 +36,31 @@ async def create_dynamic_endpoints(app: FastAPI, api_dependency=None): for tool in tools: endpoint_name = tool.name endpoint_description = tool.description - schema = tool.inputSchema - required_fields = schema.get("required", []) - properties = schema.get("properties", {}) + inputSchema = tool.inputSchema + outputSchema = getattr(tool, "outputSchema", None) - form_model_name = f"{endpoint_name}_form_model" - - model_fields = get_model_fields( - form_model_name, - properties, - required_fields, + form_model_fields = get_model_fields( + f"{endpoint_name}_form_model", + inputSchema.get("properties", {}), + inputSchema.get("required", []), + inputSchema.get("$defs", {}), ) + response_model_fields = None + if outputSchema: + response_model_fields = get_model_fields( + f"{endpoint_name}_response_model", + outputSchema.get("properties", {}), + outputSchema.get("required", []), + outputSchema.get("$defs", {}), + ) + tool_handler = get_tool_handler( - session, endpoint_name, form_model_name, model_fields + session, + endpoint_name, + form_model_fields, + response_model_fields, ) app.post( @@ -98,7 +107,10 @@ async def lifespan(app: FastAPI): await create_dynamic_endpoints(app, api_dependency=api_dependency) yield if server_type == "sse": - async with sse_client(url=args[0], sse_read_timeout=None) as (reader, writer): + async with sse_client(url=args[0], sse_read_timeout=None) as ( + reader, + writer, + ): async with ClientSession(reader, writer) as session: app.state.session = session await create_dynamic_endpoints(app, api_dependency=api_dependency) @@ -114,6 +126,7 @@ async def run( ): # Server API Key api_dependency = get_verify_api_key(api_key) if api_key else None + strict_auth = kwargs.get("strict_auth", False) # MCP Server server_type = kwargs.get("server_type") # "stdio" or "sse" or "http" @@ -132,7 +145,6 @@ async def run( ssl_certfile = kwargs.get("ssl_certfile") ssl_keyfile = kwargs.get("ssl_keyfile") path_prefix = kwargs.get("path_prefix") or 
"/" - main_app = FastAPI( title=name, description=description, @@ -150,6 +162,10 @@ async def run( allow_headers=["*"], ) + # Add middleware to protect also documentation and spec + if api_key and strict_auth: + main_app.add_middleware(APIKeyMiddleware, api_key=api_key) + if server_type == "sse": main_app.state.server_type = "sse" main_app.state.args = server_command[0] @@ -197,6 +213,10 @@ async def run( sub_app.state.server_type = "sse" sub_app.state.args = server_cfg["url"] + # Add middleware to protect also documentation and spec + if api_key and strict_auth: + sub_app.add_middleware(APIKeyMiddleware, api_key=api_key) + sub_app.state.api_dependency = api_dependency main_app.mount(f"{path_prefix}{server_name}", sub_app) diff --git a/src/mcpo/tests/test_main.py b/src/mcpo/tests/test_main.py index db10617..a7cfe5f 100644 --- a/src/mcpo/tests/test_main.py +++ b/src/mcpo/tests/test_main.py @@ -1,6 +1,6 @@ import pytest from pydantic import BaseModel, Field -from typing import Any, List, Dict +from typing import Any, List, Dict, Union from mcpo.utils.main import _process_schema_property @@ -64,6 +64,17 @@ def test_process_simple_number(): assert result_field.default == expected_field.default +def test_process_null(): + schema = {"type": "null"} + expected_type = None + expected_field = Field(default=..., description="") + result_type, result_field = _process_schema_property( + _model_cache, schema, "test", "prop", True + ) + assert result_type == expected_type + assert result_field.default == expected_field.default + + def test_process_unknown_type(): schema = {"type": "unknown"} expected_type = Any @@ -224,3 +235,78 @@ def test_model_caching(): ) assert result_type3 == result_type1 # Should be the same cached object assert len(_model_cache) == 2 # Only two unique models created + + +def test_multi_type_property_with_list(): + schema = { + "type": ["string", "number"], + "description": "A property with multiple types", + } + expected_field = Field(default=..., 
description="A property with multiple types") + result_type, result_field = _process_schema_property( + _model_cache, schema, "test", "multi_type", True + ) + + # Check if the resulting type is a Union + assert str(result_type).startswith("typing.Union[") + + # Check if both types are in the Union + assert str in result_type.__args__ + assert float in result_type.__args__ + + # Check field properties + assert result_field.default == expected_field.default + assert result_field.description == expected_field.description + + +def test_multi_type_property_with_any_of(): + schema = { + "anyOf": [ + { + "type": "object", + "properties": { + "name": { + "type": "string", + "description": "The name of the function to call", + }, + "arguments": { + "type": "string", + "description": "The arguments to pass to the function, as a JSON string", + "default": "{}", + }, + }, + "required": ["name"], + }, + { + "type": "object", + "properties": { + "function_id": { + "type": "int", + "description": "The id of the function to call", + }, + }, + "required": ["function_id"], + }, + { + "type": "string", + "enum": ["auto", "none"], + "description": "Control function calling behavior", + }, + ], + "description": "A property with multiple types", + } + result_type, result_field = _process_schema_property( + _model_cache, schema, "test", "multi_type", True + ) + + # Check if the resulting type is a Union + assert result_type.__origin__ == Union + + # Check if the Union has the correct number of types + assert len(result_type.__args__) == 3 + assert len(result_type.__args__[0].model_fields) == 2 + assert len(result_type.__args__[1].model_fields) == 1 + assert result_type.__args__[2] is str + + # assert result_field parameter config + assert result_field.description == "A property with multiple types" diff --git a/src/mcpo/utils/auth.py b/src/mcpo/utils/auth.py index d32a28e..ecdfe97 100644 --- a/src/mcpo/utils/auth.py +++ b/src/mcpo/utils/auth.py @@ -1,5 +1,8 @@ from fastapi.security 
class APIKeyMiddleware(BaseHTTPMiddleware):
    """
    Middleware that enforces Basic or Bearer token authentication for all requests.

    Accepted credentials:

    * ``Authorization: Bearer <api_key>``
    * ``Authorization: Basic <base64("username:password")>`` - any username is
      accepted, the password must equal the API key.

    ``OPTIONS`` requests are forwarded without authentication.

    Responses produced by the middleware itself:

    * 401 - header missing, unsupported scheme, or malformed Basic payload
      (sent with a ``WWW-Authenticate`` challenge)
    * 403 - well-formed credentials that do not match the API key
    * 500 - the downstream application raised while handling the request
    """

    def __init__(self, app, api_key: str):
        """Wrap *app* and remember the *api_key* every request must present."""
        super().__init__(app)
        self.api_key = api_key

    @staticmethod
    def _unauthorized(detail: str):
        """Build a 401 JSON response carrying the authentication challenge."""
        return JSONResponse(
            status_code=401,
            content={"detail": detail},
            headers={"WWW-Authenticate": "Bearer, Basic"},
        )

    def _matches_api_key(self, candidate: str) -> bool:
        """Return True when *candidate* equals the configured API key.

        Uses a constant-time comparison so response timing does not reveal how
        many leading characters of a guess were correct. Both sides are
        encoded to bytes because ``compare_digest`` rejects non-ASCII ``str``.
        """
        import secrets  # local import: keeps the block self-contained

        return secrets.compare_digest(
            candidate.encode("utf-8"), self.api_key.encode("utf-8")
        )

    def _check_authorization(self, authorization):
        """Validate an ``Authorization`` header value.

        Returns ``None`` when the credentials are acceptable, otherwise the
        ``JSONResponse`` that should be sent back to the client.
        """
        if not authorization:
            return self._unauthorized("Missing or invalid Authorization header")

        # Bearer token: the token itself must be the API key.
        if authorization.startswith("Bearer "):
            token = authorization[len("Bearer "):]
            if not self._matches_api_key(token):
                return JSONResponse(
                    status_code=403, content={"detail": "Invalid API key"}
                )
            return None

        # Basic auth: base64("username:password"); only the password is checked.
        if authorization.startswith("Basic "):
            encoded = authorization[len("Basic "):]
            try:
                decoded = base64.b64decode(encoded).decode("utf-8")
                _username, password = decoded.split(":", 1)
            except ValueError:
                # Covers bad base64 (binascii.Error), bad UTF-8
                # (UnicodeDecodeError) and a payload without ":" - all of
                # which are ValueError subclasses.
                return self._unauthorized("Invalid Basic Authentication format")
            if not self._matches_api_key(password):
                return JSONResponse(
                    status_code=403, content={"detail": "Invalid credentials"}
                )
            return None

        return self._unauthorized("Unsupported authorization method")

    async def dispatch(self, request: Request, call_next):
        """Authenticate *request* and forward it to the wrapped application."""
        # Skip authentication for OPTIONS requests
        # (presumably so CORS preflight requests, which carry no credentials,
        # still succeed - confirm against the CORS setup).
        if request.method == "OPTIONS":
            return await call_next(request)

        rejection = self._check_authorization(request.headers.get("Authorization"))
        if rejection is not None:
            return rejection

        try:
            return await call_next(request)
        except Exception:
            # Keep the JSON 500 contract, but log the traceback server-side
            # instead of echoing the exception text to the client.
            import logging

            logging.getLogger(__name__).exception(
                "Unhandled error while processing %s %s",
                request.method,
                request.url.path,
            )
            return JSONResponse(
                status_code=500, content={"detail": "Internal server error"}
            )
""" + if "$ref" in prop_schema: + ref = prop_schema["$ref"] + ref = ref.split("/")[-1] + assert ref in schema_defs, "Custom field not found" + prop_schema = schema_defs[ref] + prop_type = prop_schema.get("type") prop_desc = prop_schema.get("description", "") + default_value = ... if is_required else prop_schema.get("default", None) pydantic_field = Field(default=default_value, description=prop_desc) + # Handle the case where prop_type is missing but 'anyOf' key exists + # In this case, use data type from 'anyOf' to determine the type hint + if "anyOf" in prop_schema: + type_hints = [] + for i, schema_option in enumerate(prop_schema["anyOf"]): + type_hint, _ = _process_schema_property( + _model_cache, + schema_option, + f"{model_name_prefix}_{prop_name}", + f"choice_{i}", + False, + ) + type_hints.append(type_hint) + return Union[tuple(type_hints)], pydantic_field + + # Handle the case where prop_type is a list of types, e.g. ['string', 'number'] + if isinstance(prop_type, list): + # Create a Union of all the types + type_hints = [] + for type_option in prop_type: + # Create a temporary schema with the single type and process it + temp_schema = dict(prop_schema) + temp_schema["type"] = type_option + type_hint, _ = _process_schema_property( + _model_cache, temp_schema, model_name_prefix, prop_name, False + ) + type_hints.append(type_hint) + + # Return a Union of all possible types + return Union[tuple(type_hints)], pydantic_field + if prop_type == "object": nested_properties = prop_schema.get("properties", {}) nested_required = prop_schema.get("required", []) @@ -73,7 +122,12 @@ def _process_schema_property( for name, schema in nested_properties.items(): is_nested_required = name in nested_required nested_type_hint, nested_pydantic_field = _process_schema_property( - _model_cache, schema, nested_model_name, name, is_nested_required + _model_cache, + schema, + nested_model_name, + name, + is_nested_required, + schema_defs, ) nested_fields[name] = (nested_type_hint, 
def get_model_fields(form_model_name, properties, required_fields, schema_defs=None):
    """Translate JSON-schema properties into pydantic field definitions.

    Args:
        form_model_name: Name prefix handed to ``_process_schema_property``.
        properties: Mapping of property name to its schema.
        required_fields: Names of the properties that are required.
        schema_defs: Optional schema definitions forwarded unchanged to
            ``_process_schema_property``.

    Returns:
        A dict mapping each property name to a ``(type_hint, field_info)``
        tuple, in the iteration order of ``properties``.
    """
    # A fresh cache per call, shared by every property processed below.
    cache: Dict[str, Type] = {}

    def build_entry(name, schema):
        # A property is required exactly when its name is listed.
        type_hint, field_info = _process_schema_property(
            cache,
            schema,
            form_model_name,
            name,
            name in required_fields,
            schema_defs,
        )
        return (type_hint, field_info)

    return {name: build_entry(name, schema) for name, schema in properties.items()}
ResponseModel: args = form_data.model_dump(exclude_none=True) print(f"Calling endpoint: {endpoint_name}, with args: {args}") try: @@ -158,7 +230,9 @@ def get_tool_handler(session, endpoint_name, form_model_name, model_fields): ) response_data = process_tool_response(result) - final_response = response_data[0] if len(response_data) == 1 else response_data + final_response = ( + response_data[0] if len(response_data) == 1 else response_data + ) return final_response except McpError as e: @@ -206,7 +280,9 @@ def get_tool_handler(session, endpoint_name, form_model_name, model_fields): ) response_data = process_tool_response(result) - final_response = response_data[0] if len(response_data) == 1 else response_data + final_response = ( + response_data[0] if len(response_data) == 1 else response_data + ) return final_response except McpError as e: