mirror of
https://github.com/open-webui/mcpo
synced 2025-06-26 18:26:58 +00:00
Merge branch 'dev' into main
This commit is contained in:
commit
6f04974d5d
@ -26,6 +26,10 @@ def main(
|
||||
Optional[str],
|
||||
typer.Option("--api-key", "-k", help="API key for authentication"),
|
||||
] = None,
|
||||
strict_auth: Annotated[
|
||||
Optional[bool],
|
||||
typer.Option("--strict-auth", help="API key protects all endpoints and documentation"),
|
||||
] = False,
|
||||
env: Annotated[
|
||||
Optional[List[str]], typer.Option("--env", "-e", help="Environment variables")
|
||||
] = None,
|
||||
@ -116,6 +120,7 @@ def main(
|
||||
host,
|
||||
port,
|
||||
api_key=api_key,
|
||||
strict_auth=strict_auth,
|
||||
cors_allow_origins=cors_allow_origins,
|
||||
server_type=server_type,
|
||||
config_path=config_path,
|
||||
|
||||
@ -11,8 +11,9 @@ from mcp.client.sse import sse_client
|
||||
from mcp.client.stdio import stdio_client
|
||||
from starlette.routing import Mount
|
||||
|
||||
from mcpo.utils.auth import get_verify_api_key
|
||||
|
||||
from mcpo.utils.main import get_model_fields, get_tool_handler
|
||||
from mcpo.utils.auth import get_verify_api_key, APIKeyMiddleware
|
||||
|
||||
|
||||
async def create_dynamic_endpoints(app: FastAPI, api_dependency=None):
|
||||
@ -126,6 +127,7 @@ async def run(
|
||||
):
|
||||
# Server API Key
|
||||
api_dependency = get_verify_api_key(api_key) if api_key else None
|
||||
strict_auth = kwargs.get("strict_auth", False)
|
||||
|
||||
# MCP Server
|
||||
server_type = kwargs.get("server_type") # "stdio" or "sse" or "http"
|
||||
@ -161,6 +163,10 @@ async def run(
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Add middleware to protect also documentation and spec
|
||||
if api_key and strict_auth:
|
||||
main_app.add_middleware(APIKeyMiddleware, api_key=api_key)
|
||||
|
||||
if server_type == "sse":
|
||||
main_app.state.server_type = "sse"
|
||||
main_app.state.args = server_command[0]
|
||||
@ -208,6 +214,10 @@ async def run(
|
||||
sub_app.state.server_type = "sse"
|
||||
sub_app.state.args = server_cfg["url"]
|
||||
|
||||
# Add middleware to protect also documentation and spec
|
||||
if api_key and strict_auth:
|
||||
sub_app.add_middleware(APIKeyMiddleware, api_key=api_key)
|
||||
|
||||
sub_app.state.api_dependency = api_dependency
|
||||
|
||||
main_app.mount(f"{path_prefix}{server_name}", sub_app)
|
||||
|
||||
@ -224,3 +224,22 @@ def test_model_caching():
|
||||
)
|
||||
assert result_type3 == result_type1 # Should be the same cached object
|
||||
assert len(_model_cache) == 2 # Only two unique models created
|
||||
|
||||
|
||||
def test_multi_type_property():
|
||||
schema = {"type": ["string", "number"], "description": "A property with multiple types"}
|
||||
expected_field = Field(default=..., description="A property with multiple types")
|
||||
result_type, result_field = _process_schema_property(
|
||||
_model_cache, schema, "test", "multi_type", True
|
||||
)
|
||||
|
||||
# Check if the resulting type is a Union
|
||||
assert str(result_type).startswith("typing.Union[")
|
||||
|
||||
# Check if both types are in the Union
|
||||
assert str in result_type.__args__
|
||||
assert float in result_type.__args__
|
||||
|
||||
# Check field properties
|
||||
assert result_field.default == expected_field.default
|
||||
assert result_field.description == expected_field.description
|
||||
|
||||
@ -1,5 +1,8 @@
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
from fastapi import Depends, Header, HTTPException, status
|
||||
from fastapi import Depends, Header, HTTPException, Request, status
|
||||
from fastapi.responses import JSONResponse
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
import base64
|
||||
|
||||
from passlib.context import CryptContext
|
||||
from datetime import UTC, datetime, timedelta
|
||||
@ -33,6 +36,75 @@ def get_verify_api_key(api_key: str):
|
||||
return verify_api_key
|
||||
|
||||
|
||||
class APIKeyMiddleware(BaseHTTPMiddleware):
|
||||
"""
|
||||
Middleware that enforces Basic or Bearer token authentication for all requests.
|
||||
"""
|
||||
def __init__(self, app, api_key: str):
|
||||
super().__init__(app)
|
||||
self.api_key = api_key
|
||||
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
# Skip authentication for OPTIONS requests
|
||||
if request.method == "OPTIONS":
|
||||
return await call_next(request)
|
||||
|
||||
# Get authorization header
|
||||
authorization = request.headers.get("Authorization")
|
||||
|
||||
# Verify API key
|
||||
try:
|
||||
# Use the same function that the dependency uses
|
||||
if not authorization:
|
||||
return JSONResponse(
|
||||
status_code=401,
|
||||
content={"detail": "Missing or invalid Authorization header"},
|
||||
headers={"WWW-Authenticate": "Bearer, Basic"}
|
||||
)
|
||||
|
||||
# Handle Bearer token auth
|
||||
if authorization.startswith("Bearer "):
|
||||
token = authorization[7:] # Remove "Bearer " prefix
|
||||
if token != self.api_key:
|
||||
return JSONResponse(
|
||||
status_code=403,
|
||||
content={"detail": "Invalid API key"}
|
||||
)
|
||||
# Handle Basic auth
|
||||
elif authorization.startswith("Basic "):
|
||||
# Decode the base64 credentials
|
||||
credentials = authorization[6:] # Remove "Basic " prefix
|
||||
try:
|
||||
decoded = base64.b64decode(credentials).decode('utf-8')
|
||||
# Basic auth format is username:password
|
||||
username, password = decoded.split(':', 1)
|
||||
# Any username is allowed, but password must match api_key
|
||||
if password != self.api_key:
|
||||
return JSONResponse(
|
||||
status_code=403,
|
||||
content={"detail": "Invalid credentials"}
|
||||
)
|
||||
except Exception:
|
||||
return JSONResponse(
|
||||
status_code=401,
|
||||
content={"detail": "Invalid Basic Authentication format"},
|
||||
headers={"WWW-Authenticate": "Bearer, Basic"}
|
||||
)
|
||||
else:
|
||||
return JSONResponse(
|
||||
status_code=401,
|
||||
content={"detail": "Unsupported authorization method"},
|
||||
headers={"WWW-Authenticate": "Bearer, Basic"}
|
||||
)
|
||||
|
||||
return await call_next(request)
|
||||
except Exception as e:
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={"detail": str(e)}
|
||||
)
|
||||
|
||||
|
||||
# def create_token(data: dict, expires_delta: Union[timedelta, None] = None) -> str:
|
||||
# payload = data.copy()
|
||||
|
||||
|
||||
@ -72,6 +72,22 @@ def _process_schema_property(
|
||||
default_value = ... if is_required else prop_schema.get("default", None)
|
||||
pydantic_field = Field(default=default_value, description=prop_desc)
|
||||
|
||||
# Handle the case where prop_type is a list of types, e.g. ['string', 'number']
|
||||
if isinstance(prop_type, list):
|
||||
# Create a Union of all the types
|
||||
type_hints = []
|
||||
for type_option in prop_type:
|
||||
# Create a temporary schema with the single type and process it
|
||||
temp_schema = dict(prop_schema)
|
||||
temp_schema["type"] = type_option
|
||||
type_hint, _ = _process_schema_property(
|
||||
_model_cache, temp_schema, model_name_prefix, prop_name, False
|
||||
)
|
||||
type_hints.append(type_hint)
|
||||
|
||||
# Return a Union of all possible types
|
||||
return Union[tuple(type_hints)], pydantic_field
|
||||
|
||||
if prop_type == "object":
|
||||
nested_properties = prop_schema.get("properties", {})
|
||||
nested_required = prop_schema.get("required", [])
|
||||
|
||||
Loading…
Reference in New Issue
Block a user