Merge pull request #83 from alpha-pet/feat-extend-api-key-protection
Some checks failed
Create and publish Docker images with specific build args / build-main-image (linux/amd64) (push) Has been cancelled
Create and publish Docker images with specific build args / build-main-image (linux/arm64) (push) Has been cancelled
Create and publish Docker images with specific build args / merge-main-images (push) Has been cancelled

feat: implement option to make authentication required for everything
This commit is contained in:
Tim Jaeryang Baek 2025-04-19 11:38:40 -07:00 committed by GitHub
commit 393f2e724b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 88 additions and 2 deletions

View File

@ -26,6 +26,10 @@ def main(
Optional[str], Optional[str],
typer.Option("--api-key", "-k", help="API key for authentication"), typer.Option("--api-key", "-k", help="API key for authentication"),
] = None, ] = None,
strict_auth: Annotated[
Optional[bool],
typer.Option("--strict-auth", help="API key protects all endpoints and documentation"),
] = False,
env: Annotated[ env: Annotated[
Optional[List[str]], typer.Option("--env", "-e", help="Environment variables") Optional[List[str]], typer.Option("--env", "-e", help="Environment variables")
] = None, ] = None,
@ -116,6 +120,7 @@ def main(
host, host,
port, port,
api_key=api_key, api_key=api_key,
strict_auth=strict_auth,
cors_allow_origins=cors_allow_origins, cors_allow_origins=cors_allow_origins,
server_type=server_type, server_type=server_type,
config_path=config_path, config_path=config_path,

View File

@ -14,7 +14,7 @@ from mcp.client.stdio import stdio_client
from mcp.client.sse import sse_client from mcp.client.sse import sse_client
from mcpo.utils.main import get_model_fields, get_tool_handler from mcpo.utils.main import get_model_fields, get_tool_handler
from mcpo.utils.auth import get_verify_api_key from mcpo.utils.auth import get_verify_api_key, APIKeyMiddleware
async def create_dynamic_endpoints(app: FastAPI, api_dependency=None): async def create_dynamic_endpoints(app: FastAPI, api_dependency=None):
@ -114,6 +114,7 @@ async def run(
): ):
# Server API Key # Server API Key
api_dependency = get_verify_api_key(api_key) if api_key else None api_dependency = get_verify_api_key(api_key) if api_key else None
strict_auth = kwargs.get("strict_auth", False)
# MCP Server # MCP Server
server_type = kwargs.get("server_type") # "stdio" or "sse" or "http" server_type = kwargs.get("server_type") # "stdio" or "sse" or "http"
@ -150,6 +151,10 @@ async def run(
allow_headers=["*"], allow_headers=["*"],
) )
# Add middleware to protect also documentation and spec
if api_key and strict_auth:
main_app.add_middleware(APIKeyMiddleware, api_key=api_key)
if server_type == "sse": if server_type == "sse":
main_app.state.server_type = "sse" main_app.state.server_type = "sse"
main_app.state.args = server_command[0] main_app.state.args = server_command[0]
@ -197,6 +202,10 @@ async def run(
sub_app.state.server_type = "sse" sub_app.state.server_type = "sse"
sub_app.state.args = server_cfg["url"] sub_app.state.args = server_cfg["url"]
# Add middleware to protect also documentation and spec
if api_key and strict_auth:
sub_app.add_middleware(APIKeyMiddleware, api_key=api_key)
sub_app.state.api_dependency = api_dependency sub_app.state.api_dependency = api_dependency
main_app.mount(f"{path_prefix}{server_name}", sub_app) main_app.mount(f"{path_prefix}{server_name}", sub_app)

View File

@ -1,5 +1,8 @@
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from fastapi import Depends, Header, HTTPException, status from fastapi import Depends, Header, HTTPException, Request, status
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
import base64
from passlib.context import CryptContext from passlib.context import CryptContext
from datetime import UTC, datetime, timedelta from datetime import UTC, datetime, timedelta
@ -33,6 +36,75 @@ def get_verify_api_key(api_key: str):
return verify_api_key return verify_api_key
class APIKeyMiddleware(BaseHTTPMiddleware):
"""
Middleware that enforces Basic or Bearer token authentication for all requests.
"""
def __init__(self, app, api_key: str):
super().__init__(app)
self.api_key = api_key
async def dispatch(self, request: Request, call_next):
# Skip authentication for OPTIONS requests
if request.method == "OPTIONS":
return await call_next(request)
# Get authorization header
authorization = request.headers.get("Authorization")
# Verify API key
try:
# Use the same function that the dependency uses
if not authorization:
return JSONResponse(
status_code=401,
content={"detail": "Missing or invalid Authorization header"},
headers={"WWW-Authenticate": "Bearer, Basic"}
)
# Handle Bearer token auth
if authorization.startswith("Bearer "):
token = authorization[7:] # Remove "Bearer " prefix
if token != self.api_key:
return JSONResponse(
status_code=403,
content={"detail": "Invalid API key"}
)
# Handle Basic auth
elif authorization.startswith("Basic "):
# Decode the base64 credentials
credentials = authorization[6:] # Remove "Basic " prefix
try:
decoded = base64.b64decode(credentials).decode('utf-8')
# Basic auth format is username:password
username, password = decoded.split(':', 1)
# Any username is allowed, but password must match api_key
if password != self.api_key:
return JSONResponse(
status_code=403,
content={"detail": "Invalid credentials"}
)
except Exception:
return JSONResponse(
status_code=401,
content={"detail": "Invalid Basic Authentication format"},
headers={"WWW-Authenticate": "Bearer, Basic"}
)
else:
return JSONResponse(
status_code=401,
content={"detail": "Unsupported authorization method"},
headers={"WWW-Authenticate": "Bearer, Basic"}
)
return await call_next(request)
except Exception as e:
return JSONResponse(
status_code=500,
content={"detail": str(e)}
)
# def create_token(data: dict, expires_delta: Union[timedelta, None] = None) -> str: # def create_token(data: dict, expires_delta: Union[timedelta, None] = None) -> str:
# payload = data.copy() # payload = data.copy()