feat: implement option to make authentication required to access documentation and spec

This commit is contained in:
Thomas Rehn 2025-04-17 16:23:52 +02:00
parent 64e41d1058
commit 9d3e34bfa9
3 changed files with 88 additions and 2 deletions

View File

@ -26,6 +26,10 @@ def main(
Optional[str],
typer.Option("--api-key", "-k", help="API key for authentication"),
] = None,
strict_auth: Annotated[
Optional[bool],
typer.Option("--strict-auth", help="API key protects all endpoints and documentation"),
] = False,
env: Annotated[
Optional[List[str]], typer.Option("--env", "-e", help="Environment variables")
] = None,
@ -116,6 +120,7 @@ def main(
host,
port,
api_key=api_key,
strict_auth=strict_auth,
cors_allow_origins=cors_allow_origins,
server_type=server_type,
config_path=config_path,

View File

@ -14,7 +14,7 @@ from mcp.client.stdio import stdio_client
from mcp.client.sse import sse_client
from mcpo.utils.main import get_model_fields, get_tool_handler
from mcpo.utils.auth import get_verify_api_key
from mcpo.utils.auth import get_verify_api_key, APIKeyMiddleware
async def create_dynamic_endpoints(app: FastAPI, api_dependency=None):
@ -114,6 +114,7 @@ async def run(
):
# Server API Key
api_dependency = get_verify_api_key(api_key) if api_key else None
strict_auth = kwargs.get("strict_auth", False)
# MCP Server
server_type = kwargs.get("server_type") # "stdio" or "sse" or "http"
@ -150,6 +151,10 @@ async def run(
allow_headers=["*"],
)
# Add middleware to protect also documentation and spec
if api_key and strict_auth:
main_app.add_middleware(APIKeyMiddleware, api_key=api_key)
if server_type == "sse":
main_app.state.server_type = "sse"
main_app.state.args = server_command[0]
@ -197,6 +202,10 @@ async def run(
sub_app.state.server_type = "sse"
sub_app.state.args = server_cfg["url"]
# Add middleware to protect also documentation and spec
if api_key and strict_auth:
sub_app.add_middleware(APIKeyMiddleware, api_key=api_key)
sub_app.state.api_dependency = api_dependency
main_app.mount(f"{path_prefix}{server_name}", sub_app)

View File

@ -1,5 +1,8 @@
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from fastapi import Depends, Header, HTTPException, status
from fastapi import Depends, Header, HTTPException, Request, status
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
import base64
from passlib.context import CryptContext
from datetime import UTC, datetime, timedelta
@ -33,6 +36,75 @@ def get_verify_api_key(api_key: str):
return verify_api_key
class APIKeyMiddleware(BaseHTTPMiddleware):
"""
Middleware that enforces Basic or Bearer token authentication for all requests.
"""
def __init__(self, app, api_key: str):
super().__init__(app)
self.api_key = api_key
async def dispatch(self, request: Request, call_next):
# Skip authentication for OPTIONS requests
if request.method == "OPTIONS":
return await call_next(request)
# Get authorization header
authorization = request.headers.get("Authorization")
# Verify API key
try:
# Use the same function that the dependency uses
if not authorization:
return JSONResponse(
status_code=401,
content={"detail": "Missing or invalid Authorization header"},
headers={"WWW-Authenticate": "Bearer, Basic"}
)
# Handle Bearer token auth
if authorization.startswith("Bearer "):
token = authorization[7:] # Remove "Bearer " prefix
if token != self.api_key:
return JSONResponse(
status_code=403,
content={"detail": "Invalid API key"}
)
# Handle Basic auth
elif authorization.startswith("Basic "):
# Decode the base64 credentials
credentials = authorization[6:] # Remove "Basic " prefix
try:
decoded = base64.b64decode(credentials).decode('utf-8')
# Basic auth format is username:password
username, password = decoded.split(':', 1)
# Any username is allowed, but password must match api_key
if password != self.api_key:
return JSONResponse(
status_code=403,
content={"detail": "Invalid credentials"}
)
except Exception:
return JSONResponse(
status_code=401,
content={"detail": "Invalid Basic Authentication format"},
headers={"WWW-Authenticate": "Bearer, Basic"}
)
else:
return JSONResponse(
status_code=401,
content={"detail": "Unsupported authorization method"},
headers={"WWW-Authenticate": "Bearer, Basic"}
)
return await call_next(request)
except Exception as e:
return JSONResponse(
status_code=500,
content={"detail": str(e)}
)
# def create_token(data: dict, expires_delta: Union[timedelta, None] = None) -> str:
# payload = data.copy()