Merge pull request #108 from open-webui/dev
Some checks failed
Release / release (push) Has been cancelled
Create and publish Docker images with specific build args / build-main-image (linux/amd64) (push) Has been cancelled
Create and publish Docker images with specific build args / build-main-image (linux/arm64) (push) Has been cancelled
Create and publish Docker images with specific build args / merge-main-images (push) Has been cancelled

0.0.13
This commit is contained in:
Tim Jaeryang Baek 2025-04-30 22:53:11 -07:00 committed by GitHub
commit e37d0ebd27
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 312 additions and 37 deletions

View File

@ -5,6 +5,22 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
# Changelog
All notable changes to this project will be documented in this file.
The format is based on Keep a Changelog,
and this project adheres to Semantic Versioning.
## [0.0.13] - 2025-05-01
### Added
- 🧪 **Support for Mixed and Union Types (anyOf/nullables)**: mcpo now accurately exposes OpenAPI schemas with anyOf compositions and nullable fields.
- 🧷 **Authentication-Required Docs Access with --strict-auth**: When enabled, the new --strict-auth option restricts access to both the tool endpoints and their interactive documentation pages—ensuring sensitive internal services aren't inadvertently exposed to unauthenticated users or LLMs.
- 🧬 **Custom Schema Definitions for Complex Models**: Developers can now register custom BaseModel schemas with arbitrary nesting and field variants, allowing precise OpenAPI representations of deeply structured payloads—ensuring crystal-clear docs and compatibility for multi-layered data workflows.
- 🔄 **Smarter Schema Inference Across Data Types**: Schema generation has been enhanced to gracefully handle nested unions, nulls, and fallback types, dramatically improving accuracy in tools using variable output formats or flexible data contracts.
## [0.0.12] - 2025-04-14
### Fixed

View File

@ -1,6 +1,6 @@
[project]
name = "mcpo"
version = "0.0.12"
version = "0.0.13"
description = "A simple, secure MCP-to-OpenAPI proxy server"
authors = [
{ name = "Timothy Jaeryang Baek", email = "tim@openwebui.com" }

View File

@ -26,6 +26,10 @@ def main(
Optional[str],
typer.Option("--api-key", "-k", help="API key for authentication"),
] = None,
strict_auth: Annotated[
Optional[bool],
typer.Option("--strict-auth", help="API key protects all endpoints and documentation"),
] = False,
env: Annotated[
Optional[List[str]], typer.Option("--env", "-e", help="Environment variables")
] = None,
@ -116,6 +120,7 @@ def main(
host,
port,
api_key=api_key,
strict_auth=strict_auth,
cors_allow_origins=cors_allow_origins,
server_type=server_type,
config_path=config_path,

View File

@ -4,17 +4,16 @@ from contextlib import AsyncExitStack, asynccontextmanager
from typing import Optional
import uvicorn
from fastapi import FastAPI, Depends
from fastapi import Depends, FastAPI
from fastapi.middleware.cors import CORSMiddleware
from mcp import ClientSession, StdioServerParameters
from mcp.client.sse import sse_client
from mcp.client.stdio import stdio_client
from starlette.routing import Mount
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client
from mcp.client.sse import sse_client
from mcpo.utils.main import get_model_fields, get_tool_handler
from mcpo.utils.auth import get_verify_api_key
from mcpo.utils.auth import get_verify_api_key, APIKeyMiddleware
async def create_dynamic_endpoints(app: FastAPI, api_dependency=None):
@ -37,21 +36,31 @@ async def create_dynamic_endpoints(app: FastAPI, api_dependency=None):
for tool in tools:
endpoint_name = tool.name
endpoint_description = tool.description
schema = tool.inputSchema
required_fields = schema.get("required", [])
properties = schema.get("properties", {})
inputSchema = tool.inputSchema
outputSchema = getattr(tool, "outputSchema", None)
form_model_name = f"{endpoint_name}_form_model"
model_fields = get_model_fields(
form_model_name,
properties,
required_fields,
form_model_fields = get_model_fields(
f"{endpoint_name}_form_model",
inputSchema.get("properties", {}),
inputSchema.get("required", []),
inputSchema.get("$defs", {}),
)
response_model_fields = None
if outputSchema:
response_model_fields = get_model_fields(
f"{endpoint_name}_response_model",
outputSchema.get("properties", {}),
outputSchema.get("required", []),
outputSchema.get("$defs", {}),
)
tool_handler = get_tool_handler(
session, endpoint_name, form_model_name, model_fields
session,
endpoint_name,
form_model_fields,
response_model_fields,
)
app.post(
@ -98,7 +107,10 @@ async def lifespan(app: FastAPI):
await create_dynamic_endpoints(app, api_dependency=api_dependency)
yield
if server_type == "sse":
async with sse_client(url=args[0], sse_read_timeout=None) as (reader, writer):
async with sse_client(url=args[0], sse_read_timeout=None) as (
reader,
writer,
):
async with ClientSession(reader, writer) as session:
app.state.session = session
await create_dynamic_endpoints(app, api_dependency=api_dependency)
@ -114,6 +126,7 @@ async def run(
):
# Server API Key
api_dependency = get_verify_api_key(api_key) if api_key else None
strict_auth = kwargs.get("strict_auth", False)
# MCP Server
server_type = kwargs.get("server_type") # "stdio" or "sse" or "http"
@ -132,7 +145,6 @@ async def run(
ssl_certfile = kwargs.get("ssl_certfile")
ssl_keyfile = kwargs.get("ssl_keyfile")
path_prefix = kwargs.get("path_prefix") or "/"
main_app = FastAPI(
title=name,
description=description,
@ -150,6 +162,10 @@ async def run(
allow_headers=["*"],
)
# Add middleware to protect also documentation and spec
if api_key and strict_auth:
main_app.add_middleware(APIKeyMiddleware, api_key=api_key)
if server_type == "sse":
main_app.state.server_type = "sse"
main_app.state.args = server_command[0]
@ -197,6 +213,10 @@ async def run(
sub_app.state.server_type = "sse"
sub_app.state.args = server_cfg["url"]
# Add middleware to protect also documentation and spec
if api_key and strict_auth:
sub_app.add_middleware(APIKeyMiddleware, api_key=api_key)
sub_app.state.api_dependency = api_dependency
main_app.mount(f"{path_prefix}{server_name}", sub_app)

View File

@ -1,6 +1,6 @@
import pytest
from pydantic import BaseModel, Field
from typing import Any, List, Dict
from typing import Any, List, Dict, Union
from mcpo.utils.main import _process_schema_property
@ -64,6 +64,17 @@ def test_process_simple_number():
assert result_field.default == expected_field.default
def test_process_null():
    """A schema of type "null" yields ``None`` as the type hint.

    The property is processed as required, so the resulting field must carry
    the same default as a required ``Field``.
    """
    schema = {"type": "null"}
    expected_field = Field(default=..., description="")
    result_type, result_field = _process_schema_property(
        _model_cache, schema, "test", "prop", True
    )
    # None is a singleton: check identity rather than equality, so an object
    # with a permissive __eq__ cannot satisfy the assertion by accident.
    assert result_type is None
    assert result_field.default == expected_field.default
def test_process_unknown_type():
schema = {"type": "unknown"}
expected_type = Any
@ -224,3 +235,78 @@ def test_model_caching():
)
assert result_type3 == result_type1 # Should be the same cached object
assert len(_model_cache) == 2 # Only two unique models created
def test_multi_type_property_with_list():
    """A list-valued "type" (here ["string", "number"]) yields a Union type hint.

    The Union must contain one member per listed type, and the field must keep
    the schema's description and behave as required.
    """
    # Local import keeps this test self-contained; the module only imports
    # Any, List, Dict and Union from typing.
    from typing import get_args, get_origin

    schema = {
        "type": ["string", "number"],
        "description": "A property with multiple types",
    }
    expected_field = Field(default=..., description="A property with multiple types")
    result_type, result_field = _process_schema_property(
        _model_cache, schema, "test", "multi_type", True
    )

    # Inspect the hint structurally instead of matching its repr: the textual
    # form of a Union is not stable across Python versions.
    assert get_origin(result_type) is Union

    # Check that both types are members of the Union
    member_types = get_args(result_type)
    assert str in member_types
    assert float in member_types

    # Check field properties
    assert result_field.default == expected_field.default
    assert result_field.description == expected_field.description
def test_multi_type_property_with_any_of():
    """An "anyOf" schema yields a Union with one member per option.

    The fixture offers two object shapes and one string enum. The two object
    options must come back as models exposing ``model_fields``, and the string
    option as plain ``str``.
    """
    # Local import keeps this test self-contained; the module only imports
    # Any, List, Dict and Union from typing.
    from typing import get_args, get_origin

    schema = {
        "anyOf": [
            {
                "type": "object",
                "properties": {
                    "name": {
                        "type": "string",
                        "description": "The name of the function to call",
                    },
                    "arguments": {
                        "type": "string",
                        "description": "The arguments to pass to the function, as a JSON string",
                        "default": "{}",
                    },
                },
                "required": ["name"],
            },
            {
                "type": "object",
                "properties": {
                    "function_id": {
                        # JSON Schema spells this type "integer"; "int" is not
                        # a valid type name.
                        "type": "integer",
                        "description": "The id of the function to call",
                    },
                },
                "required": ["function_id"],
            },
            {
                "type": "string",
                "enum": ["auto", "none"],
                "description": "Control function calling behavior",
            },
        ],
        "description": "A property with multiple types",
    }

    result_type, result_field = _process_schema_property(
        _model_cache, schema, "test", "multi_type", True
    )

    # Check that the resulting type is a Union (via the public typing helpers
    # rather than the undocumented __origin__/__args__ attributes)
    assert get_origin(result_type) is Union

    # Check that the Union has the correct number of members, in schema order
    options = get_args(result_type)
    assert len(options) == 3
    assert len(options[0].model_fields) == 2
    assert len(options[1].model_fields) == 1
    assert options[2] is str

    # The field keeps the description of the enclosing anyOf schema
    assert result_field.description == "A property with multiple types"

View File

@ -1,5 +1,8 @@
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from fastapi import Depends, Header, HTTPException, status
from fastapi import Depends, Header, HTTPException, Request, status
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
import base64
from passlib.context import CryptContext
from datetime import UTC, datetime, timedelta
@ -33,6 +36,75 @@ def get_verify_api_key(api_key: str):
return verify_api_key
class APIKeyMiddleware(BaseHTTPMiddleware):
    """
    Middleware that enforces Basic or Bearer token authentication for all requests.

    Every request other than ``OPTIONS`` must carry an ``Authorization`` header
    whose Bearer token, or whose Basic-auth password, equals the configured
    API key. Requests that fail the check are answered here and never reach
    the wrapped application.
    """

    def __init__(self, app, api_key: str):
        """
        Args:
            app: The application being wrapped.
            api_key: The secret that clients must present.
        """
        super().__init__(app)
        self.api_key = api_key

    def _key_matches(self, candidate: str) -> bool:
        """Return True if *candidate* equals the API key, compared in constant time."""
        # Function-scope import: keeps the class self-contained without
        # touching the module's import block.
        import hmac

        # Compare as bytes: compare_digest only accepts str arguments when
        # they are pure ASCII, which client-supplied input need not be.
        return hmac.compare_digest(
            candidate.encode("utf-8"), self.api_key.encode("utf-8")
        )

    @staticmethod
    def _unauthorized(detail: str):
        """Build the 401 response that advertises the accepted auth schemes."""
        return JSONResponse(
            status_code=401,
            content={"detail": detail},
            headers={"WWW-Authenticate": "Bearer, Basic"},
        )

    async def dispatch(self, request: Request, call_next):
        """
        Authenticate *request* and forward it to the next handler on success.

        Returns:
            The downstream response when the credentials are valid; otherwise
            a 401 (missing, malformed or unsupported credentials) or a 403
            (well-formed credentials that do not match the API key).
        """
        # Skip authentication for OPTIONS requests
        # (presumably so CORS preflight requests are not rejected - confirm).
        if request.method == "OPTIONS":
            return await call_next(request)

        authorization = request.headers.get("Authorization")

        # NOTE(review): this handler also covers call_next(), so an exception
        # raised downstream is returned as a 500 whose body is the exception
        # text. Kept as-is to preserve behavior for existing clients.
        try:
            if not authorization:
                return self._unauthorized("Missing or invalid Authorization header")

            if authorization.startswith("Bearer "):
                token = authorization[7:]  # Remove "Bearer " prefix
                if not self._key_matches(token):
                    return JSONResponse(
                        status_code=403,
                        content={"detail": "Invalid API key"},
                    )
            elif authorization.startswith("Basic "):
                credentials = authorization[6:]  # Remove "Basic " prefix
                try:
                    decoded = base64.b64decode(credentials).decode("utf-8")
                    # Basic auth format is username:password
                    _username, password = decoded.split(":", 1)
                except ValueError:
                    # Covers invalid base64 (binascii.Error), invalid UTF-8
                    # (UnicodeDecodeError) and a missing ':' separator; all
                    # three are ValueError subclasses or ValueError itself.
                    return self._unauthorized("Invalid Basic Authentication format")
                # Any username is allowed, but password must match api_key
                if not self._key_matches(password):
                    return JSONResponse(
                        status_code=403,
                        content={"detail": "Invalid credentials"},
                    )
            else:
                return self._unauthorized("Unsupported authorization method")

            return await call_next(request)
        except Exception as e:
            return JSONResponse(
                status_code=500,
                content={"detail": str(e)},
            )
# def create_token(data: dict, expires_delta: Union[timedelta, None] = None) -> str:
# payload = data.copy()

View File

@ -1,12 +1,22 @@
from typing import Any, Dict, List, Type, Union, ForwardRef
from pydantic import create_model, Field
from pydantic.fields import FieldInfo
from mcp import ClientSession, types
import json
from typing import Any, Dict, ForwardRef, List, Optional, Type, Union
from fastapi import HTTPException
from mcp.types import CallToolResult, PARSE_ERROR, INVALID_REQUEST, METHOD_NOT_FOUND, INVALID_PARAMS, INTERNAL_ERROR
from mcp import ClientSession, types
from mcp.types import (
CallToolResult,
PARSE_ERROR,
INVALID_REQUEST,
METHOD_NOT_FOUND,
INVALID_PARAMS,
INTERNAL_ERROR,
)
from mcp.shared.exceptions import McpError
import json
from pydantic import Field, create_model
from pydantic.fields import FieldInfo
MCP_ERROR_TO_HTTP_STATUS = {
PARSE_ERROR: 400,
@ -44,6 +54,7 @@ def _process_schema_property(
model_name_prefix: str,
prop_name: str,
is_required: bool,
schema_defs: Optional[Dict] = None,
) -> tuple[Union[Type, List, ForwardRef, Any], FieldInfo]:
"""
Recursively processes a schema property to determine its Python type hint
@ -53,11 +64,49 @@ def _process_schema_property(
A tuple containing (python_type_hint, pydantic_field).
The pydantic_field contains default value and description.
"""
if "$ref" in prop_schema:
ref = prop_schema["$ref"]
ref = ref.split("/")[-1]
assert ref in schema_defs, "Custom field not found"
prop_schema = schema_defs[ref]
prop_type = prop_schema.get("type")
prop_desc = prop_schema.get("description", "")
default_value = ... if is_required else prop_schema.get("default", None)
pydantic_field = Field(default=default_value, description=prop_desc)
# Handle the case where prop_type is missing but 'anyOf' key exists
# In this case, use data type from 'anyOf' to determine the type hint
if "anyOf" in prop_schema:
type_hints = []
for i, schema_option in enumerate(prop_schema["anyOf"]):
type_hint, _ = _process_schema_property(
_model_cache,
schema_option,
f"{model_name_prefix}_{prop_name}",
f"choice_{i}",
False,
)
type_hints.append(type_hint)
return Union[tuple(type_hints)], pydantic_field
# Handle the case where prop_type is a list of types, e.g. ['string', 'number']
if isinstance(prop_type, list):
# Create a Union of all the types
type_hints = []
for type_option in prop_type:
# Create a temporary schema with the single type and process it
temp_schema = dict(prop_schema)
temp_schema["type"] = type_option
type_hint, _ = _process_schema_property(
_model_cache, temp_schema, model_name_prefix, prop_name, False
)
type_hints.append(type_hint)
# Return a Union of all possible types
return Union[tuple(type_hints)], pydantic_field
if prop_type == "object":
nested_properties = prop_schema.get("properties", {})
nested_required = prop_schema.get("required", [])
@ -73,7 +122,12 @@ def _process_schema_property(
for name, schema in nested_properties.items():
is_nested_required = name in nested_required
nested_type_hint, nested_pydantic_field = _process_schema_property(
_model_cache, schema, nested_model_name, name, is_nested_required
_model_cache,
schema,
nested_model_name,
name,
is_nested_required,
schema_defs,
)
nested_fields[name] = (nested_type_hint, nested_pydantic_field)
@ -98,7 +152,8 @@ def _process_schema_property(
items_schema,
f"{model_name_prefix}_{prop_name}",
"item",
False, # Items aren't required at this level
False, # Items aren't required at this level,
schema_defs,
)
list_type_hint = List[item_type_hint]
return list_type_hint, pydantic_field
@ -111,11 +166,13 @@ def _process_schema_property(
return bool, pydantic_field
elif prop_type == "number":
return float, pydantic_field
elif prop_type == "null":
return None, pydantic_field
else:
return Any, pydantic_field
def get_model_fields(form_model_name, properties, required_fields):
def get_model_fields(form_model_name, properties, required_fields, schema_defs=None):
model_fields = {}
_model_cache: Dict[str, Type] = {}
@ -123,21 +180,36 @@ def get_model_fields(form_model_name, properties, required_fields):
for param_name, param_schema in properties.items():
is_required = param_name in required_fields
python_type_hint, pydantic_field_info = _process_schema_property(
_model_cache, param_schema, form_model_name, param_name, is_required
_model_cache,
param_schema,
form_model_name,
param_name,
is_required,
schema_defs,
)
# Use the generated type hint and Field info
model_fields[param_name] = (python_type_hint, pydantic_field_info)
return model_fields
def get_tool_handler(session, endpoint_name, form_model_name, model_fields):
if model_fields:
FormModel = create_model(form_model_name, **model_fields)
def get_tool_handler(
session,
endpoint_name,
form_model_fields,
response_model_fields=None,
):
if form_model_fields:
FormModel = create_model(f"{endpoint_name}_form_model", **form_model_fields)
ResponseModel = (
create_model(f"{endpoint_name}_response_model", **response_model_fields)
if response_model_fields
else Any
)
def make_endpoint_func(
endpoint_name: str, FormModel, session: ClientSession
): # Parameterized endpoint
async def tool(form_data: FormModel):
async def tool(form_data: FormModel) -> ResponseModel:
args = form_data.model_dump(exclude_none=True)
print(f"Calling endpoint: {endpoint_name}, with args: {args}")
try:
@ -158,7 +230,9 @@ def get_tool_handler(session, endpoint_name, form_model_name, model_fields):
)
response_data = process_tool_response(result)
final_response = response_data[0] if len(response_data) == 1 else response_data
final_response = (
response_data[0] if len(response_data) == 1 else response_data
)
return final_response
except McpError as e:
@ -206,7 +280,9 @@ def get_tool_handler(session, endpoint_name, form_model_name, model_fields):
)
response_data = process_tool_response(result)
final_response = response_data[0] if len(response_data) == 1 else response_data
final_response = (
response_data[0] if len(response_data) == 1 else response_data
)
return final_response
except McpError as e: