mirror of
https://github.com/open-webui/mcpo
synced 2025-06-26 18:26:58 +00:00
Merge pull request #83 from alpha-pet/feat-extend-api-key-protection
Some checks failed
Create and publish Docker images with specific build args / build-main-image (linux/amd64) (push) Has been cancelled
Create and publish Docker images with specific build args / build-main-image (linux/arm64) (push) Has been cancelled
Create and publish Docker images with specific build args / merge-main-images (push) Has been cancelled
Some checks failed
Create and publish Docker images with specific build args / build-main-image (linux/amd64) (push) Has been cancelled
Create and publish Docker images with specific build args / build-main-image (linux/arm64) (push) Has been cancelled
Create and publish Docker images with specific build args / merge-main-images (push) Has been cancelled
feat: implement option to make authentication required for everything
This commit is contained in:
commit
393f2e724b
@ -26,6 +26,10 @@ def main(
|
|||||||
Optional[str],
|
Optional[str],
|
||||||
typer.Option("--api-key", "-k", help="API key for authentication"),
|
typer.Option("--api-key", "-k", help="API key for authentication"),
|
||||||
] = None,
|
] = None,
|
||||||
|
strict_auth: Annotated[
|
||||||
|
Optional[bool],
|
||||||
|
typer.Option("--strict-auth", help="API key protects all endpoints and documentation"),
|
||||||
|
] = False,
|
||||||
env: Annotated[
|
env: Annotated[
|
||||||
Optional[List[str]], typer.Option("--env", "-e", help="Environment variables")
|
Optional[List[str]], typer.Option("--env", "-e", help="Environment variables")
|
||||||
] = None,
|
] = None,
|
||||||
@ -116,6 +120,7 @@ def main(
|
|||||||
host,
|
host,
|
||||||
port,
|
port,
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
|
strict_auth=strict_auth,
|
||||||
cors_allow_origins=cors_allow_origins,
|
cors_allow_origins=cors_allow_origins,
|
||||||
server_type=server_type,
|
server_type=server_type,
|
||||||
config_path=config_path,
|
config_path=config_path,
|
||||||
|
|||||||
@ -14,7 +14,7 @@ from mcp.client.stdio import stdio_client
|
|||||||
from mcp.client.sse import sse_client
|
from mcp.client.sse import sse_client
|
||||||
|
|
||||||
from mcpo.utils.main import get_model_fields, get_tool_handler
|
from mcpo.utils.main import get_model_fields, get_tool_handler
|
||||||
from mcpo.utils.auth import get_verify_api_key
|
from mcpo.utils.auth import get_verify_api_key, APIKeyMiddleware
|
||||||
|
|
||||||
|
|
||||||
async def create_dynamic_endpoints(app: FastAPI, api_dependency=None):
|
async def create_dynamic_endpoints(app: FastAPI, api_dependency=None):
|
||||||
@ -114,6 +114,7 @@ async def run(
|
|||||||
):
|
):
|
||||||
# Server API Key
|
# Server API Key
|
||||||
api_dependency = get_verify_api_key(api_key) if api_key else None
|
api_dependency = get_verify_api_key(api_key) if api_key else None
|
||||||
|
strict_auth = kwargs.get("strict_auth", False)
|
||||||
|
|
||||||
# MCP Server
|
# MCP Server
|
||||||
server_type = kwargs.get("server_type") # "stdio" or "sse" or "http"
|
server_type = kwargs.get("server_type") # "stdio" or "sse" or "http"
|
||||||
@ -150,6 +151,10 @@ async def run(
|
|||||||
allow_headers=["*"],
|
allow_headers=["*"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Add middleware to protect also documentation and spec
|
||||||
|
if api_key and strict_auth:
|
||||||
|
main_app.add_middleware(APIKeyMiddleware, api_key=api_key)
|
||||||
|
|
||||||
if server_type == "sse":
|
if server_type == "sse":
|
||||||
main_app.state.server_type = "sse"
|
main_app.state.server_type = "sse"
|
||||||
main_app.state.args = server_command[0]
|
main_app.state.args = server_command[0]
|
||||||
@ -197,6 +202,10 @@ async def run(
|
|||||||
sub_app.state.server_type = "sse"
|
sub_app.state.server_type = "sse"
|
||||||
sub_app.state.args = server_cfg["url"]
|
sub_app.state.args = server_cfg["url"]
|
||||||
|
|
||||||
|
# Add middleware to protect also documentation and spec
|
||||||
|
if api_key and strict_auth:
|
||||||
|
sub_app.add_middleware(APIKeyMiddleware, api_key=api_key)
|
||||||
|
|
||||||
sub_app.state.api_dependency = api_dependency
|
sub_app.state.api_dependency = api_dependency
|
||||||
|
|
||||||
main_app.mount(f"{path_prefix}{server_name}", sub_app)
|
main_app.mount(f"{path_prefix}{server_name}", sub_app)
|
||||||
|
|||||||
@ -1,5 +1,8 @@
|
|||||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||||
from fastapi import Depends, Header, HTTPException, status
|
from fastapi import Depends, Header, HTTPException, Request, status
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
|
import base64
|
||||||
|
|
||||||
from passlib.context import CryptContext
|
from passlib.context import CryptContext
|
||||||
from datetime import UTC, datetime, timedelta
|
from datetime import UTC, datetime, timedelta
|
||||||
@ -33,6 +36,75 @@ def get_verify_api_key(api_key: str):
|
|||||||
return verify_api_key
|
return verify_api_key
|
||||||
|
|
||||||
|
|
||||||
|
class APIKeyMiddleware(BaseHTTPMiddleware):
|
||||||
|
"""
|
||||||
|
Middleware that enforces Basic or Bearer token authentication for all requests.
|
||||||
|
"""
|
||||||
|
def __init__(self, app, api_key: str):
|
||||||
|
super().__init__(app)
|
||||||
|
self.api_key = api_key
|
||||||
|
|
||||||
|
async def dispatch(self, request: Request, call_next):
|
||||||
|
# Skip authentication for OPTIONS requests
|
||||||
|
if request.method == "OPTIONS":
|
||||||
|
return await call_next(request)
|
||||||
|
|
||||||
|
# Get authorization header
|
||||||
|
authorization = request.headers.get("Authorization")
|
||||||
|
|
||||||
|
# Verify API key
|
||||||
|
try:
|
||||||
|
# Use the same function that the dependency uses
|
||||||
|
if not authorization:
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=401,
|
||||||
|
content={"detail": "Missing or invalid Authorization header"},
|
||||||
|
headers={"WWW-Authenticate": "Bearer, Basic"}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Handle Bearer token auth
|
||||||
|
if authorization.startswith("Bearer "):
|
||||||
|
token = authorization[7:] # Remove "Bearer " prefix
|
||||||
|
if token != self.api_key:
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=403,
|
||||||
|
content={"detail": "Invalid API key"}
|
||||||
|
)
|
||||||
|
# Handle Basic auth
|
||||||
|
elif authorization.startswith("Basic "):
|
||||||
|
# Decode the base64 credentials
|
||||||
|
credentials = authorization[6:] # Remove "Basic " prefix
|
||||||
|
try:
|
||||||
|
decoded = base64.b64decode(credentials).decode('utf-8')
|
||||||
|
# Basic auth format is username:password
|
||||||
|
username, password = decoded.split(':', 1)
|
||||||
|
# Any username is allowed, but password must match api_key
|
||||||
|
if password != self.api_key:
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=403,
|
||||||
|
content={"detail": "Invalid credentials"}
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=401,
|
||||||
|
content={"detail": "Invalid Basic Authentication format"},
|
||||||
|
headers={"WWW-Authenticate": "Bearer, Basic"}
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=401,
|
||||||
|
content={"detail": "Unsupported authorization method"},
|
||||||
|
headers={"WWW-Authenticate": "Bearer, Basic"}
|
||||||
|
)
|
||||||
|
|
||||||
|
return await call_next(request)
|
||||||
|
except Exception as e:
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=500,
|
||||||
|
content={"detail": str(e)}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# def create_token(data: dict, expires_delta: Union[timedelta, None] = None) -> str:
|
# def create_token(data: dict, expires_delta: Union[timedelta, None] = None) -> str:
|
||||||
# payload = data.copy()
|
# payload = data.copy()
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user