mirror of
https://github.com/open-webui/mcpo
synced 2025-06-26 18:26:58 +00:00
feat: implement option to make authentication required to access documentation and spec
This commit is contained in:
parent
64e41d1058
commit
9d3e34bfa9
@ -26,6 +26,10 @@ def main(
|
||||
Optional[str],
|
||||
typer.Option("--api-key", "-k", help="API key for authentication"),
|
||||
] = None,
|
||||
strict_auth: Annotated[
|
||||
Optional[bool],
|
||||
typer.Option("--strict-auth", help="API key protects all endpoints and documentation"),
|
||||
] = False,
|
||||
env: Annotated[
|
||||
Optional[List[str]], typer.Option("--env", "-e", help="Environment variables")
|
||||
] = None,
|
||||
@ -116,6 +120,7 @@ def main(
|
||||
host,
|
||||
port,
|
||||
api_key=api_key,
|
||||
strict_auth=strict_auth,
|
||||
cors_allow_origins=cors_allow_origins,
|
||||
server_type=server_type,
|
||||
config_path=config_path,
|
||||
|
||||
@ -14,7 +14,7 @@ from mcp.client.stdio import stdio_client
|
||||
from mcp.client.sse import sse_client
|
||||
|
||||
from mcpo.utils.main import get_model_fields, get_tool_handler
|
||||
from mcpo.utils.auth import get_verify_api_key
|
||||
from mcpo.utils.auth import get_verify_api_key, APIKeyMiddleware
|
||||
|
||||
|
||||
async def create_dynamic_endpoints(app: FastAPI, api_dependency=None):
|
||||
@ -114,6 +114,7 @@ async def run(
|
||||
):
|
||||
# Server API Key
|
||||
api_dependency = get_verify_api_key(api_key) if api_key else None
|
||||
strict_auth = kwargs.get("strict_auth", False)
|
||||
|
||||
# MCP Server
|
||||
server_type = kwargs.get("server_type") # "stdio" or "sse" or "http"
|
||||
@ -150,6 +151,10 @@ async def run(
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Add middleware to protect also documentation and spec
|
||||
if api_key and strict_auth:
|
||||
main_app.add_middleware(APIKeyMiddleware, api_key=api_key)
|
||||
|
||||
if server_type == "sse":
|
||||
main_app.state.server_type = "sse"
|
||||
main_app.state.args = server_command[0]
|
||||
@ -197,6 +202,10 @@ async def run(
|
||||
sub_app.state.server_type = "sse"
|
||||
sub_app.state.args = server_cfg["url"]
|
||||
|
||||
# Add middleware to protect also documentation and spec
|
||||
if api_key and strict_auth:
|
||||
sub_app.add_middleware(APIKeyMiddleware, api_key=api_key)
|
||||
|
||||
sub_app.state.api_dependency = api_dependency
|
||||
|
||||
main_app.mount(f"{path_prefix}{server_name}", sub_app)
|
||||
|
||||
@ -1,5 +1,8 @@
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
from fastapi import Depends, Header, HTTPException, status
|
||||
from fastapi import Depends, Header, HTTPException, Request, status
|
||||
from fastapi.responses import JSONResponse
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
import base64
|
||||
|
||||
from passlib.context import CryptContext
|
||||
from datetime import UTC, datetime, timedelta
|
||||
@ -33,6 +36,75 @@ def get_verify_api_key(api_key: str):
|
||||
return verify_api_key
|
||||
|
||||
|
||||
class APIKeyMiddleware(BaseHTTPMiddleware):
|
||||
"""
|
||||
Middleware that enforces Basic or Bearer token authentication for all requests.
|
||||
"""
|
||||
def __init__(self, app, api_key: str):
|
||||
super().__init__(app)
|
||||
self.api_key = api_key
|
||||
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
# Skip authentication for OPTIONS requests
|
||||
if request.method == "OPTIONS":
|
||||
return await call_next(request)
|
||||
|
||||
# Get authorization header
|
||||
authorization = request.headers.get("Authorization")
|
||||
|
||||
# Verify API key
|
||||
try:
|
||||
# Use the same function that the dependency uses
|
||||
if not authorization:
|
||||
return JSONResponse(
|
||||
status_code=401,
|
||||
content={"detail": "Missing or invalid Authorization header"},
|
||||
headers={"WWW-Authenticate": "Bearer, Basic"}
|
||||
)
|
||||
|
||||
# Handle Bearer token auth
|
||||
if authorization.startswith("Bearer "):
|
||||
token = authorization[7:] # Remove "Bearer " prefix
|
||||
if token != self.api_key:
|
||||
return JSONResponse(
|
||||
status_code=403,
|
||||
content={"detail": "Invalid API key"}
|
||||
)
|
||||
# Handle Basic auth
|
||||
elif authorization.startswith("Basic "):
|
||||
# Decode the base64 credentials
|
||||
credentials = authorization[6:] # Remove "Basic " prefix
|
||||
try:
|
||||
decoded = base64.b64decode(credentials).decode('utf-8')
|
||||
# Basic auth format is username:password
|
||||
username, password = decoded.split(':', 1)
|
||||
# Any username is allowed, but password must match api_key
|
||||
if password != self.api_key:
|
||||
return JSONResponse(
|
||||
status_code=403,
|
||||
content={"detail": "Invalid credentials"}
|
||||
)
|
||||
except Exception:
|
||||
return JSONResponse(
|
||||
status_code=401,
|
||||
content={"detail": "Invalid Basic Authentication format"},
|
||||
headers={"WWW-Authenticate": "Bearer, Basic"}
|
||||
)
|
||||
else:
|
||||
return JSONResponse(
|
||||
status_code=401,
|
||||
content={"detail": "Unsupported authorization method"},
|
||||
headers={"WWW-Authenticate": "Bearer, Basic"}
|
||||
)
|
||||
|
||||
return await call_next(request)
|
||||
except Exception as e:
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={"detail": str(e)}
|
||||
)
|
||||
|
||||
|
||||
# def create_token(data: dict, expires_delta: Union[timedelta, None] = None) -> str:
|
||||
# payload = data.copy()
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user