mirror of
https://github.com/open-webui/mcpo
synced 2025-06-26 18:26:58 +00:00
Merge pull request #108 from open-webui/dev
Some checks failed
Release / release (push) Has been cancelled
Create and publish Docker images with specific build args / build-main-image (linux/amd64) (push) Has been cancelled
Create and publish Docker images with specific build args / build-main-image (linux/arm64) (push) Has been cancelled
Create and publish Docker images with specific build args / merge-main-images (push) Has been cancelled
Some checks failed
Release / release (push) Has been cancelled
Create and publish Docker images with specific build args / build-main-image (linux/amd64) (push) Has been cancelled
Create and publish Docker images with specific build args / build-main-image (linux/arm64) (push) Has been cancelled
Create and publish Docker images with specific build args / merge-main-images (push) Has been cancelled
0.0.13
This commit is contained in:
commit
e37d0ebd27
16
CHANGELOG.md
16
CHANGELOG.md
@ -5,6 +5,22 @@ All notable changes to this project will be documented in this file.
|
||||
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
|
||||
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
||||
|
||||
# Changelog
|
||||
|
||||
All notable changes to this project will be documented in this file.
|
||||
|
||||
The format is based on Keep a Changelog,
|
||||
and this project adheres to Semantic Versioning.
|
||||
|
||||
## [0.0.13] - 2025-05-01
|
||||
|
||||
### Added
|
||||
|
||||
- 🧪 **Support for Mixed and Union Types (anyOf/nullables)**: mcpo now accurately exposes OpenAPI schemas with anyOf compositions and nullable fields.
|
||||
- 🧷 **Authentication-Required Docs Access with --strict-auth**: When enabled, the new --strict-auth option restricts access to both the tool endpoints and their interactive documentation pages—ensuring sensitive internal services aren’t inadvertently exposed to unauthenticated users or LLMs.
|
||||
- 🧬 **Custom Schema Definitions for Complex Models**: Developers can now register custom BaseModel schemas with arbitrary nesting and field variants, allowing precise OpenAPI representations of deeply structured payloads—ensuring crystal-clear docs and compatibility for multi-layered data workflows.
|
||||
- 🔄 **Smarter Schema Inference Across Data Types**: Schema generation has been enhanced to gracefully handle nested unions, nulls, and fallback types, dramatically improving accuracy in tools using variable output formats or flexible data contracts.
|
||||
|
||||
## [0.0.12] - 2025-04-14
|
||||
|
||||
### Fixed
|
||||
|
@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "mcpo"
|
||||
version = "0.0.12"
|
||||
version = "0.0.13"
|
||||
description = "A simple, secure MCP-to-OpenAPI proxy server"
|
||||
authors = [
|
||||
{ name = "Timothy Jaeryang Baek", email = "tim@openwebui.com" }
|
||||
|
@ -26,6 +26,10 @@ def main(
|
||||
Optional[str],
|
||||
typer.Option("--api-key", "-k", help="API key for authentication"),
|
||||
] = None,
|
||||
strict_auth: Annotated[
|
||||
Optional[bool],
|
||||
typer.Option("--strict-auth", help="API key protects all endpoints and documentation"),
|
||||
] = False,
|
||||
env: Annotated[
|
||||
Optional[List[str]], typer.Option("--env", "-e", help="Environment variables")
|
||||
] = None,
|
||||
@ -116,6 +120,7 @@ def main(
|
||||
host,
|
||||
port,
|
||||
api_key=api_key,
|
||||
strict_auth=strict_auth,
|
||||
cors_allow_origins=cors_allow_origins,
|
||||
server_type=server_type,
|
||||
config_path=config_path,
|
||||
|
@ -4,17 +4,16 @@ from contextlib import AsyncExitStack, asynccontextmanager
|
||||
from typing import Optional
|
||||
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, Depends
|
||||
from fastapi import Depends, FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from mcp import ClientSession, StdioServerParameters
|
||||
from mcp.client.sse import sse_client
|
||||
from mcp.client.stdio import stdio_client
|
||||
from starlette.routing import Mount
|
||||
|
||||
|
||||
from mcp import ClientSession, StdioServerParameters
|
||||
from mcp.client.stdio import stdio_client
|
||||
from mcp.client.sse import sse_client
|
||||
|
||||
from mcpo.utils.main import get_model_fields, get_tool_handler
|
||||
from mcpo.utils.auth import get_verify_api_key
|
||||
from mcpo.utils.auth import get_verify_api_key, APIKeyMiddleware
|
||||
|
||||
|
||||
async def create_dynamic_endpoints(app: FastAPI, api_dependency=None):
|
||||
@ -37,21 +36,31 @@ async def create_dynamic_endpoints(app: FastAPI, api_dependency=None):
|
||||
for tool in tools:
|
||||
endpoint_name = tool.name
|
||||
endpoint_description = tool.description
|
||||
schema = tool.inputSchema
|
||||
|
||||
required_fields = schema.get("required", [])
|
||||
properties = schema.get("properties", {})
|
||||
inputSchema = tool.inputSchema
|
||||
outputSchema = getattr(tool, "outputSchema", None)
|
||||
|
||||
form_model_name = f"{endpoint_name}_form_model"
|
||||
|
||||
model_fields = get_model_fields(
|
||||
form_model_name,
|
||||
properties,
|
||||
required_fields,
|
||||
form_model_fields = get_model_fields(
|
||||
f"{endpoint_name}_form_model",
|
||||
inputSchema.get("properties", {}),
|
||||
inputSchema.get("required", []),
|
||||
inputSchema.get("$defs", {}),
|
||||
)
|
||||
|
||||
response_model_fields = None
|
||||
if outputSchema:
|
||||
response_model_fields = get_model_fields(
|
||||
f"{endpoint_name}_response_model",
|
||||
outputSchema.get("properties", {}),
|
||||
outputSchema.get("required", []),
|
||||
outputSchema.get("$defs", {}),
|
||||
)
|
||||
|
||||
tool_handler = get_tool_handler(
|
||||
session, endpoint_name, form_model_name, model_fields
|
||||
session,
|
||||
endpoint_name,
|
||||
form_model_fields,
|
||||
response_model_fields,
|
||||
)
|
||||
|
||||
app.post(
|
||||
@ -98,7 +107,10 @@ async def lifespan(app: FastAPI):
|
||||
await create_dynamic_endpoints(app, api_dependency=api_dependency)
|
||||
yield
|
||||
if server_type == "sse":
|
||||
async with sse_client(url=args[0], sse_read_timeout=None) as (reader, writer):
|
||||
async with sse_client(url=args[0], sse_read_timeout=None) as (
|
||||
reader,
|
||||
writer,
|
||||
):
|
||||
async with ClientSession(reader, writer) as session:
|
||||
app.state.session = session
|
||||
await create_dynamic_endpoints(app, api_dependency=api_dependency)
|
||||
@ -114,6 +126,7 @@ async def run(
|
||||
):
|
||||
# Server API Key
|
||||
api_dependency = get_verify_api_key(api_key) if api_key else None
|
||||
strict_auth = kwargs.get("strict_auth", False)
|
||||
|
||||
# MCP Server
|
||||
server_type = kwargs.get("server_type") # "stdio" or "sse" or "http"
|
||||
@ -132,7 +145,6 @@ async def run(
|
||||
ssl_certfile = kwargs.get("ssl_certfile")
|
||||
ssl_keyfile = kwargs.get("ssl_keyfile")
|
||||
path_prefix = kwargs.get("path_prefix") or "/"
|
||||
|
||||
main_app = FastAPI(
|
||||
title=name,
|
||||
description=description,
|
||||
@ -150,6 +162,10 @@ async def run(
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Add middleware to protect also documentation and spec
|
||||
if api_key and strict_auth:
|
||||
main_app.add_middleware(APIKeyMiddleware, api_key=api_key)
|
||||
|
||||
if server_type == "sse":
|
||||
main_app.state.server_type = "sse"
|
||||
main_app.state.args = server_command[0]
|
||||
@ -197,6 +213,10 @@ async def run(
|
||||
sub_app.state.server_type = "sse"
|
||||
sub_app.state.args = server_cfg["url"]
|
||||
|
||||
# Add middleware to protect also documentation and spec
|
||||
if api_key and strict_auth:
|
||||
sub_app.add_middleware(APIKeyMiddleware, api_key=api_key)
|
||||
|
||||
sub_app.state.api_dependency = api_dependency
|
||||
|
||||
main_app.mount(f"{path_prefix}{server_name}", sub_app)
|
||||
|
@ -1,6 +1,6 @@
|
||||
import pytest
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Any, List, Dict
|
||||
from typing import Any, List, Dict, Union
|
||||
|
||||
from mcpo.utils.main import _process_schema_property
|
||||
|
||||
@ -64,6 +64,17 @@ def test_process_simple_number():
|
||||
assert result_field.default == expected_field.default
|
||||
|
||||
|
||||
def test_process_null():
|
||||
schema = {"type": "null"}
|
||||
expected_type = None
|
||||
expected_field = Field(default=..., description="")
|
||||
result_type, result_field = _process_schema_property(
|
||||
_model_cache, schema, "test", "prop", True
|
||||
)
|
||||
assert result_type == expected_type
|
||||
assert result_field.default == expected_field.default
|
||||
|
||||
|
||||
def test_process_unknown_type():
|
||||
schema = {"type": "unknown"}
|
||||
expected_type = Any
|
||||
@ -224,3 +235,78 @@ def test_model_caching():
|
||||
)
|
||||
assert result_type3 == result_type1 # Should be the same cached object
|
||||
assert len(_model_cache) == 2 # Only two unique models created
|
||||
|
||||
|
||||
def test_multi_type_property_with_list():
|
||||
schema = {
|
||||
"type": ["string", "number"],
|
||||
"description": "A property with multiple types",
|
||||
}
|
||||
expected_field = Field(default=..., description="A property with multiple types")
|
||||
result_type, result_field = _process_schema_property(
|
||||
_model_cache, schema, "test", "multi_type", True
|
||||
)
|
||||
|
||||
# Check if the resulting type is a Union
|
||||
assert str(result_type).startswith("typing.Union[")
|
||||
|
||||
# Check if both types are in the Union
|
||||
assert str in result_type.__args__
|
||||
assert float in result_type.__args__
|
||||
|
||||
# Check field properties
|
||||
assert result_field.default == expected_field.default
|
||||
assert result_field.description == expected_field.description
|
||||
|
||||
|
||||
def test_multi_type_property_with_any_of():
|
||||
schema = {
|
||||
"anyOf": [
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {
|
||||
"type": "string",
|
||||
"description": "The name of the function to call",
|
||||
},
|
||||
"arguments": {
|
||||
"type": "string",
|
||||
"description": "The arguments to pass to the function, as a JSON string",
|
||||
"default": "{}",
|
||||
},
|
||||
},
|
||||
"required": ["name"],
|
||||
},
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"function_id": {
|
||||
"type": "int",
|
||||
"description": "The id of the function to call",
|
||||
},
|
||||
},
|
||||
"required": ["function_id"],
|
||||
},
|
||||
{
|
||||
"type": "string",
|
||||
"enum": ["auto", "none"],
|
||||
"description": "Control function calling behavior",
|
||||
},
|
||||
],
|
||||
"description": "A property with multiple types",
|
||||
}
|
||||
result_type, result_field = _process_schema_property(
|
||||
_model_cache, schema, "test", "multi_type", True
|
||||
)
|
||||
|
||||
# Check if the resulting type is a Union
|
||||
assert result_type.__origin__ == Union
|
||||
|
||||
# Check if the Union has the correct number of types
|
||||
assert len(result_type.__args__) == 3
|
||||
assert len(result_type.__args__[0].model_fields) == 2
|
||||
assert len(result_type.__args__[1].model_fields) == 1
|
||||
assert result_type.__args__[2] is str
|
||||
|
||||
# assert result_field parameter config
|
||||
assert result_field.description == "A property with multiple types"
|
||||
|
@ -1,5 +1,8 @@
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
from fastapi import Depends, Header, HTTPException, status
|
||||
from fastapi import Depends, Header, HTTPException, Request, status
|
||||
from fastapi.responses import JSONResponse
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
import base64
|
||||
|
||||
from passlib.context import CryptContext
|
||||
from datetime import UTC, datetime, timedelta
|
||||
@ -33,6 +36,75 @@ def get_verify_api_key(api_key: str):
|
||||
return verify_api_key
|
||||
|
||||
|
||||
class APIKeyMiddleware(BaseHTTPMiddleware):
|
||||
"""
|
||||
Middleware that enforces Basic or Bearer token authentication for all requests.
|
||||
"""
|
||||
def __init__(self, app, api_key: str):
|
||||
super().__init__(app)
|
||||
self.api_key = api_key
|
||||
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
# Skip authentication for OPTIONS requests
|
||||
if request.method == "OPTIONS":
|
||||
return await call_next(request)
|
||||
|
||||
# Get authorization header
|
||||
authorization = request.headers.get("Authorization")
|
||||
|
||||
# Verify API key
|
||||
try:
|
||||
# Use the same function that the dependency uses
|
||||
if not authorization:
|
||||
return JSONResponse(
|
||||
status_code=401,
|
||||
content={"detail": "Missing or invalid Authorization header"},
|
||||
headers={"WWW-Authenticate": "Bearer, Basic"}
|
||||
)
|
||||
|
||||
# Handle Bearer token auth
|
||||
if authorization.startswith("Bearer "):
|
||||
token = authorization[7:] # Remove "Bearer " prefix
|
||||
if token != self.api_key:
|
||||
return JSONResponse(
|
||||
status_code=403,
|
||||
content={"detail": "Invalid API key"}
|
||||
)
|
||||
# Handle Basic auth
|
||||
elif authorization.startswith("Basic "):
|
||||
# Decode the base64 credentials
|
||||
credentials = authorization[6:] # Remove "Basic " prefix
|
||||
try:
|
||||
decoded = base64.b64decode(credentials).decode('utf-8')
|
||||
# Basic auth format is username:password
|
||||
username, password = decoded.split(':', 1)
|
||||
# Any username is allowed, but password must match api_key
|
||||
if password != self.api_key:
|
||||
return JSONResponse(
|
||||
status_code=403,
|
||||
content={"detail": "Invalid credentials"}
|
||||
)
|
||||
except Exception:
|
||||
return JSONResponse(
|
||||
status_code=401,
|
||||
content={"detail": "Invalid Basic Authentication format"},
|
||||
headers={"WWW-Authenticate": "Bearer, Basic"}
|
||||
)
|
||||
else:
|
||||
return JSONResponse(
|
||||
status_code=401,
|
||||
content={"detail": "Unsupported authorization method"},
|
||||
headers={"WWW-Authenticate": "Bearer, Basic"}
|
||||
)
|
||||
|
||||
return await call_next(request)
|
||||
except Exception as e:
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={"detail": str(e)}
|
||||
)
|
||||
|
||||
|
||||
# def create_token(data: dict, expires_delta: Union[timedelta, None] = None) -> str:
|
||||
# payload = data.copy()
|
||||
|
||||
|
@ -1,12 +1,22 @@
|
||||
from typing import Any, Dict, List, Type, Union, ForwardRef
|
||||
from pydantic import create_model, Field
|
||||
from pydantic.fields import FieldInfo
|
||||
from mcp import ClientSession, types
|
||||
import json
|
||||
from typing import Any, Dict, ForwardRef, List, Optional, Type, Union
|
||||
|
||||
from fastapi import HTTPException
|
||||
from mcp.types import CallToolResult, PARSE_ERROR, INVALID_REQUEST, METHOD_NOT_FOUND, INVALID_PARAMS, INTERNAL_ERROR
|
||||
|
||||
from mcp import ClientSession, types
|
||||
from mcp.types import (
|
||||
CallToolResult,
|
||||
PARSE_ERROR,
|
||||
INVALID_REQUEST,
|
||||
METHOD_NOT_FOUND,
|
||||
INVALID_PARAMS,
|
||||
INTERNAL_ERROR,
|
||||
)
|
||||
|
||||
from mcp.shared.exceptions import McpError
|
||||
|
||||
import json
|
||||
from pydantic import Field, create_model
|
||||
from pydantic.fields import FieldInfo
|
||||
|
||||
MCP_ERROR_TO_HTTP_STATUS = {
|
||||
PARSE_ERROR: 400,
|
||||
@ -44,6 +54,7 @@ def _process_schema_property(
|
||||
model_name_prefix: str,
|
||||
prop_name: str,
|
||||
is_required: bool,
|
||||
schema_defs: Optional[Dict] = None,
|
||||
) -> tuple[Union[Type, List, ForwardRef, Any], FieldInfo]:
|
||||
"""
|
||||
Recursively processes a schema property to determine its Python type hint
|
||||
@ -53,11 +64,49 @@ def _process_schema_property(
|
||||
A tuple containing (python_type_hint, pydantic_field).
|
||||
The pydantic_field contains default value and description.
|
||||
"""
|
||||
if "$ref" in prop_schema:
|
||||
ref = prop_schema["$ref"]
|
||||
ref = ref.split("/")[-1]
|
||||
assert ref in schema_defs, "Custom field not found"
|
||||
prop_schema = schema_defs[ref]
|
||||
|
||||
prop_type = prop_schema.get("type")
|
||||
prop_desc = prop_schema.get("description", "")
|
||||
|
||||
default_value = ... if is_required else prop_schema.get("default", None)
|
||||
pydantic_field = Field(default=default_value, description=prop_desc)
|
||||
|
||||
# Handle the case where prop_type is missing but 'anyOf' key exists
|
||||
# In this case, use data type from 'anyOf' to determine the type hint
|
||||
if "anyOf" in prop_schema:
|
||||
type_hints = []
|
||||
for i, schema_option in enumerate(prop_schema["anyOf"]):
|
||||
type_hint, _ = _process_schema_property(
|
||||
_model_cache,
|
||||
schema_option,
|
||||
f"{model_name_prefix}_{prop_name}",
|
||||
f"choice_{i}",
|
||||
False,
|
||||
)
|
||||
type_hints.append(type_hint)
|
||||
return Union[tuple(type_hints)], pydantic_field
|
||||
|
||||
# Handle the case where prop_type is a list of types, e.g. ['string', 'number']
|
||||
if isinstance(prop_type, list):
|
||||
# Create a Union of all the types
|
||||
type_hints = []
|
||||
for type_option in prop_type:
|
||||
# Create a temporary schema with the single type and process it
|
||||
temp_schema = dict(prop_schema)
|
||||
temp_schema["type"] = type_option
|
||||
type_hint, _ = _process_schema_property(
|
||||
_model_cache, temp_schema, model_name_prefix, prop_name, False
|
||||
)
|
||||
type_hints.append(type_hint)
|
||||
|
||||
# Return a Union of all possible types
|
||||
return Union[tuple(type_hints)], pydantic_field
|
||||
|
||||
if prop_type == "object":
|
||||
nested_properties = prop_schema.get("properties", {})
|
||||
nested_required = prop_schema.get("required", [])
|
||||
@ -73,7 +122,12 @@ def _process_schema_property(
|
||||
for name, schema in nested_properties.items():
|
||||
is_nested_required = name in nested_required
|
||||
nested_type_hint, nested_pydantic_field = _process_schema_property(
|
||||
_model_cache, schema, nested_model_name, name, is_nested_required
|
||||
_model_cache,
|
||||
schema,
|
||||
nested_model_name,
|
||||
name,
|
||||
is_nested_required,
|
||||
schema_defs,
|
||||
)
|
||||
|
||||
nested_fields[name] = (nested_type_hint, nested_pydantic_field)
|
||||
@ -98,7 +152,8 @@ def _process_schema_property(
|
||||
items_schema,
|
||||
f"{model_name_prefix}_{prop_name}",
|
||||
"item",
|
||||
False, # Items aren't required at this level
|
||||
False, # Items aren't required at this level,
|
||||
schema_defs,
|
||||
)
|
||||
list_type_hint = List[item_type_hint]
|
||||
return list_type_hint, pydantic_field
|
||||
@ -111,11 +166,13 @@ def _process_schema_property(
|
||||
return bool, pydantic_field
|
||||
elif prop_type == "number":
|
||||
return float, pydantic_field
|
||||
elif prop_type == "null":
|
||||
return None, pydantic_field
|
||||
else:
|
||||
return Any, pydantic_field
|
||||
|
||||
|
||||
def get_model_fields(form_model_name, properties, required_fields):
|
||||
def get_model_fields(form_model_name, properties, required_fields, schema_defs=None):
|
||||
model_fields = {}
|
||||
|
||||
_model_cache: Dict[str, Type] = {}
|
||||
@ -123,21 +180,36 @@ def get_model_fields(form_model_name, properties, required_fields):
|
||||
for param_name, param_schema in properties.items():
|
||||
is_required = param_name in required_fields
|
||||
python_type_hint, pydantic_field_info = _process_schema_property(
|
||||
_model_cache, param_schema, form_model_name, param_name, is_required
|
||||
_model_cache,
|
||||
param_schema,
|
||||
form_model_name,
|
||||
param_name,
|
||||
is_required,
|
||||
schema_defs,
|
||||
)
|
||||
# Use the generated type hint and Field info
|
||||
model_fields[param_name] = (python_type_hint, pydantic_field_info)
|
||||
return model_fields
|
||||
|
||||
|
||||
def get_tool_handler(session, endpoint_name, form_model_name, model_fields):
|
||||
if model_fields:
|
||||
FormModel = create_model(form_model_name, **model_fields)
|
||||
def get_tool_handler(
|
||||
session,
|
||||
endpoint_name,
|
||||
form_model_fields,
|
||||
response_model_fields=None,
|
||||
):
|
||||
if form_model_fields:
|
||||
FormModel = create_model(f"{endpoint_name}_form_model", **form_model_fields)
|
||||
ResponseModel = (
|
||||
create_model(f"{endpoint_name}_response_model", **response_model_fields)
|
||||
if response_model_fields
|
||||
else Any
|
||||
)
|
||||
|
||||
def make_endpoint_func(
|
||||
endpoint_name: str, FormModel, session: ClientSession
|
||||
): # Parameterized endpoint
|
||||
async def tool(form_data: FormModel):
|
||||
async def tool(form_data: FormModel) -> ResponseModel:
|
||||
args = form_data.model_dump(exclude_none=True)
|
||||
print(f"Calling endpoint: {endpoint_name}, with args: {args}")
|
||||
try:
|
||||
@ -158,7 +230,9 @@ def get_tool_handler(session, endpoint_name, form_model_name, model_fields):
|
||||
)
|
||||
|
||||
response_data = process_tool_response(result)
|
||||
final_response = response_data[0] if len(response_data) == 1 else response_data
|
||||
final_response = (
|
||||
response_data[0] if len(response_data) == 1 else response_data
|
||||
)
|
||||
return final_response
|
||||
|
||||
except McpError as e:
|
||||
@ -206,7 +280,9 @@ def get_tool_handler(session, endpoint_name, form_model_name, model_fields):
|
||||
)
|
||||
|
||||
response_data = process_tool_response(result)
|
||||
final_response = response_data[0] if len(response_data) == 1 else response_data
|
||||
final_response = (
|
||||
response_data[0] if len(response_data) == 1 else response_data
|
||||
)
|
||||
return final_response
|
||||
|
||||
except McpError as e:
|
||||
|
Loading…
Reference in New Issue
Block a user