From 9d3e34bfa9e92042af1725ad9b2c90ebdc062654 Mon Sep 17 00:00:00 2001 From: Thomas Rehn <271119+tremlin@users.noreply.github.com> Date: Thu, 17 Apr 2025 16:23:52 +0200 Subject: [PATCH] feat: implement option to make authentication required to access documentation and spec --- src/mcpo/__init__.py | 5 +++ src/mcpo/main.py | 11 ++++++- src/mcpo/utils/auth.py | 74 +++++++++++++++++++++++++++++++++++++++++- 3 files changed, 88 insertions(+), 2 deletions(-) diff --git a/src/mcpo/__init__.py b/src/mcpo/__init__.py index 534937a..4859e20 100644 --- a/src/mcpo/__init__.py +++ b/src/mcpo/__init__.py @@ -26,6 +26,10 @@ def main( Optional[str], typer.Option("--api-key", "-k", help="API key for authentication"), ] = None, + strict_auth: Annotated[ + Optional[bool], + typer.Option("--strict-auth", help="API key protects all endpoints and documentation"), + ] = False, env: Annotated[ Optional[List[str]], typer.Option("--env", "-e", help="Environment variables") ] = None, @@ -116,6 +120,7 @@ def main( host, port, api_key=api_key, + strict_auth=strict_auth, cors_allow_origins=cors_allow_origins, server_type=server_type, config_path=config_path, diff --git a/src/mcpo/main.py b/src/mcpo/main.py index 8fba53e..0d95840 100644 --- a/src/mcpo/main.py +++ b/src/mcpo/main.py @@ -14,7 +14,7 @@ from mcp.client.stdio import stdio_client from mcp.client.sse import sse_client from mcpo.utils.main import get_model_fields, get_tool_handler -from mcpo.utils.auth import get_verify_api_key +from mcpo.utils.auth import get_verify_api_key, APIKeyMiddleware async def create_dynamic_endpoints(app: FastAPI, api_dependency=None): @@ -114,6 +114,7 @@ async def run( ): # Server API Key api_dependency = get_verify_api_key(api_key) if api_key else None + strict_auth = kwargs.get("strict_auth", False) # MCP Server server_type = kwargs.get("server_type") # "stdio" or "sse" or "http" @@ -150,6 +151,10 @@ async def run( allow_headers=["*"], ) + # Add middleware to protect also documentation and spec + if api_key and strict_auth: + main_app.add_middleware(APIKeyMiddleware, api_key=api_key) + if server_type == "sse": main_app.state.server_type = "sse" main_app.state.args = server_command[0] @@ -197,6 +202,10 @@ async def run( sub_app.state.server_type = "sse" sub_app.state.args = server_cfg["url"] + # Add middleware to protect also documentation and spec + if api_key and strict_auth: + sub_app.add_middleware(APIKeyMiddleware, api_key=api_key) + sub_app.state.api_dependency = api_dependency main_app.mount(f"{path_prefix}{server_name}", sub_app) diff --git a/src/mcpo/utils/auth.py b/src/mcpo/utils/auth.py index d32a28e..ecdfe97 100644 --- a/src/mcpo/utils/auth.py +++ b/src/mcpo/utils/auth.py @@ -1,5 +1,8 @@ from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer -from fastapi import Depends, Header, HTTPException, status +from fastapi import Depends, Header, HTTPException, Request, status +from fastapi.responses import JSONResponse +from starlette.middleware.base import BaseHTTPMiddleware +import base64 from passlib.context import CryptContext from datetime import UTC, datetime, timedelta @@ -33,6 +36,75 @@ def get_verify_api_key(api_key: str): return verify_api_key +class APIKeyMiddleware(BaseHTTPMiddleware): + """ + Middleware that enforces Basic or Bearer token authentication for all requests. + """ + def __init__(self, app, api_key: str): + super().__init__(app) + self.api_key = api_key + + async def dispatch(self, request: Request, call_next): + # Skip authentication for OPTIONS requests + if request.method == "OPTIONS": + return await call_next(request) + + # Get authorization header + authorization = request.headers.get("Authorization") + + # Verify API key + try: + # Use the same function that the dependency uses + if not authorization: + return JSONResponse( + status_code=401, + content={"detail": "Missing or invalid Authorization header"}, + headers={"WWW-Authenticate": "Bearer, Basic"} + ) + + # Handle Bearer token auth + if authorization.startswith("Bearer "): + token = authorization[7:] # Remove "Bearer " prefix + if token != self.api_key: + return JSONResponse( + status_code=403, + content={"detail": "Invalid API key"} + ) + # Handle Basic auth + elif authorization.startswith("Basic "): + # Decode the base64 credentials + credentials = authorization[6:] # Remove "Basic " prefix + try: + decoded = base64.b64decode(credentials).decode('utf-8') + # Basic auth format is username:password + username, password = decoded.split(':', 1) + # Any username is allowed, but password must match api_key + if password != self.api_key: + return JSONResponse( + status_code=403, + content={"detail": "Invalid credentials"} + ) + except Exception: + return JSONResponse( + status_code=401, + content={"detail": "Invalid Basic Authentication format"}, + headers={"WWW-Authenticate": "Bearer, Basic"} + ) + else: + return JSONResponse( + status_code=401, + content={"detail": "Unsupported authorization method"}, + headers={"WWW-Authenticate": "Bearer, Basic"} + ) + + return await call_next(request) + except Exception as e: + return JSONResponse( + status_code=500, + content={"detail": str(e)} + ) + + # def create_token(data: dict, expires_delta: Union[timedelta, None] = None) -> str: # payload = data.copy()