Merge branch 'dev' into main

This commit is contained in:
JinY0ung-Shin 2025-04-28 14:49:49 +09:00 committed by GitHub
commit 6f04974d5d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 124 additions and 2 deletions

View File

@ -26,6 +26,10 @@ def main(
Optional[str],
typer.Option("--api-key", "-k", help="API key for authentication"),
] = None,
strict_auth: Annotated[
Optional[bool],
typer.Option("--strict-auth", help="API key protects all endpoints and documentation"),
] = False,
env: Annotated[
Optional[List[str]], typer.Option("--env", "-e", help="Environment variables")
] = None,
@ -116,6 +120,7 @@ def main(
host,
port,
api_key=api_key,
strict_auth=strict_auth,
cors_allow_origins=cors_allow_origins,
server_type=server_type,
config_path=config_path,

View File

@ -11,8 +11,9 @@ from mcp.client.sse import sse_client
from mcp.client.stdio import stdio_client
from starlette.routing import Mount
from mcpo.utils.auth import get_verify_api_key
from mcpo.utils.main import get_model_fields, get_tool_handler
from mcpo.utils.auth import get_verify_api_key, APIKeyMiddleware
async def create_dynamic_endpoints(app: FastAPI, api_dependency=None):
@ -126,6 +127,7 @@ async def run(
):
# Server API Key
api_dependency = get_verify_api_key(api_key) if api_key else None
strict_auth = kwargs.get("strict_auth", False)
# MCP Server
server_type = kwargs.get("server_type") # "stdio" or "sse" or "http"
@ -161,6 +163,10 @@ async def run(
allow_headers=["*"],
)
# Add middleware to protect also documentation and spec
if api_key and strict_auth:
main_app.add_middleware(APIKeyMiddleware, api_key=api_key)
if server_type == "sse":
main_app.state.server_type = "sse"
main_app.state.args = server_command[0]
@ -208,6 +214,10 @@ async def run(
sub_app.state.server_type = "sse"
sub_app.state.args = server_cfg["url"]
# Add middleware to protect also documentation and spec
if api_key and strict_auth:
sub_app.add_middleware(APIKeyMiddleware, api_key=api_key)
sub_app.state.api_dependency = api_dependency
main_app.mount(f"{path_prefix}{server_name}", sub_app)

View File

@ -224,3 +224,22 @@ def test_model_caching():
)
assert result_type3 == result_type1 # Should be the same cached object
assert len(_model_cache) == 2 # Only two unique models created
def test_multi_type_property():
schema = {"type": ["string", "number"], "description": "A property with multiple types"}
expected_field = Field(default=..., description="A property with multiple types")
result_type, result_field = _process_schema_property(
_model_cache, schema, "test", "multi_type", True
)
# Check if the resulting type is a Union
assert str(result_type).startswith("typing.Union[")
# Check if both types are in the Union
assert str in result_type.__args__
assert float in result_type.__args__
# Check field properties
assert result_field.default == expected_field.default
assert result_field.description == expected_field.description

View File

@ -1,5 +1,8 @@
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from fastapi import Depends, Header, HTTPException, status
from fastapi import Depends, Header, HTTPException, Request, status
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
import base64
from passlib.context import CryptContext
from datetime import UTC, datetime, timedelta
@ -33,6 +36,75 @@ def get_verify_api_key(api_key: str):
return verify_api_key
class APIKeyMiddleware(BaseHTTPMiddleware):
"""
Middleware that enforces Basic or Bearer token authentication for all requests.
"""
def __init__(self, app, api_key: str):
super().__init__(app)
self.api_key = api_key
async def dispatch(self, request: Request, call_next):
# Skip authentication for OPTIONS requests
if request.method == "OPTIONS":
return await call_next(request)
# Get authorization header
authorization = request.headers.get("Authorization")
# Verify API key
try:
# Use the same function that the dependency uses
if not authorization:
return JSONResponse(
status_code=401,
content={"detail": "Missing or invalid Authorization header"},
headers={"WWW-Authenticate": "Bearer, Basic"}
)
# Handle Bearer token auth
if authorization.startswith("Bearer "):
token = authorization[7:] # Remove "Bearer " prefix
if token != self.api_key:
return JSONResponse(
status_code=403,
content={"detail": "Invalid API key"}
)
# Handle Basic auth
elif authorization.startswith("Basic "):
# Decode the base64 credentials
credentials = authorization[6:] # Remove "Basic " prefix
try:
decoded = base64.b64decode(credentials).decode('utf-8')
# Basic auth format is username:password
username, password = decoded.split(':', 1)
# Any username is allowed, but password must match api_key
if password != self.api_key:
return JSONResponse(
status_code=403,
content={"detail": "Invalid credentials"}
)
except Exception:
return JSONResponse(
status_code=401,
content={"detail": "Invalid Basic Authentication format"},
headers={"WWW-Authenticate": "Bearer, Basic"}
)
else:
return JSONResponse(
status_code=401,
content={"detail": "Unsupported authorization method"},
headers={"WWW-Authenticate": "Bearer, Basic"}
)
return await call_next(request)
except Exception as e:
return JSONResponse(
status_code=500,
content={"detail": str(e)}
)
# def create_token(data: dict, expires_delta: Union[timedelta, None] = None) -> str:
# payload = data.copy()

View File

@ -72,6 +72,22 @@ def _process_schema_property(
default_value = ... if is_required else prop_schema.get("default", None)
pydantic_field = Field(default=default_value, description=prop_desc)
# Handle the case where prop_type is a list of types, e.g. ['string', 'number']
if isinstance(prop_type, list):
# Create a Union of all the types
type_hints = []
for type_option in prop_type:
# Create a temporary schema with the single type and process it
temp_schema = dict(prop_schema)
temp_schema["type"] = type_option
type_hint, _ = _process_schema_property(
_model_cache, temp_schema, model_name_prefix, prop_name, False
)
type_hints.append(type_hint)
# Return a Union of all possible types
return Union[tuple(type_hints)], pydantic_field
if prop_type == "object":
nested_properties = prop_schema.get("properties", {})
nested_required = prop_schema.get("required", [])