mirror of
https://github.com/open-webui/mcpo
synced 2025-06-26 18:26:58 +00:00
Merge pull request #108 from open-webui/dev
Some checks failed
Release / release (push) Has been cancelled
Create and publish Docker images with specific build args / build-main-image (linux/amd64) (push) Has been cancelled
Create and publish Docker images with specific build args / build-main-image (linux/arm64) (push) Has been cancelled
Create and publish Docker images with specific build args / merge-main-images (push) Has been cancelled
Some checks failed
Release / release (push) Has been cancelled
Create and publish Docker images with specific build args / build-main-image (linux/amd64) (push) Has been cancelled
Create and publish Docker images with specific build args / build-main-image (linux/arm64) (push) Has been cancelled
Create and publish Docker images with specific build args / merge-main-images (push) Has been cancelled
0.0.13
This commit is contained in:
commit
e37d0ebd27
16
CHANGELOG.md
16
CHANGELOG.md
@ -5,6 +5,22 @@ All notable changes to this project will be documented in this file.
|
|||||||
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
|
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
|
||||||
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
||||||
|
|
||||||
|
# Changelog
|
||||||
|
|
||||||
|
All notable changes to this project will be documented in this file.
|
||||||
|
|
||||||
|
The format is based on Keep a Changelog,
|
||||||
|
and this project adheres to Semantic Versioning.
|
||||||
|
|
||||||
|
## [0.0.13] - 2025-05-01
|
||||||
|
|
||||||
|
### Added
|
||||||
|
|
||||||
|
- 🧪 **Support for Mixed and Union Types (anyOf/nullables)**: mcpo now accurately exposes OpenAPI schemas with anyOf compositions and nullable fields.
|
||||||
|
- 🧷 **Authentication-Required Docs Access with --strict-auth**: When enabled, the new --strict-auth option restricts access to both the tool endpoints and their interactive documentation pages—ensuring sensitive internal services aren’t inadvertently exposed to unauthenticated users or LLMs.
|
||||||
|
- 🧬 **Custom Schema Definitions for Complex Models**: Developers can now register custom BaseModel schemas with arbitrary nesting and field variants, allowing precise OpenAPI representations of deeply structured payloads—ensuring crystal-clear docs and compatibility for multi-layered data workflows.
|
||||||
|
- 🔄 **Smarter Schema Inference Across Data Types**: Schema generation has been enhanced to gracefully handle nested unions, nulls, and fallback types, dramatically improving accuracy in tools using variable output formats or flexible data contracts.
|
||||||
|
|
||||||
## [0.0.12] - 2025-04-14
|
## [0.0.12] - 2025-04-14
|
||||||
|
|
||||||
### Fixed
|
### Fixed
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "mcpo"
|
name = "mcpo"
|
||||||
version = "0.0.12"
|
version = "0.0.13"
|
||||||
description = "A simple, secure MCP-to-OpenAPI proxy server"
|
description = "A simple, secure MCP-to-OpenAPI proxy server"
|
||||||
authors = [
|
authors = [
|
||||||
{ name = "Timothy Jaeryang Baek", email = "tim@openwebui.com" }
|
{ name = "Timothy Jaeryang Baek", email = "tim@openwebui.com" }
|
||||||
|
@ -26,6 +26,10 @@ def main(
|
|||||||
Optional[str],
|
Optional[str],
|
||||||
typer.Option("--api-key", "-k", help="API key for authentication"),
|
typer.Option("--api-key", "-k", help="API key for authentication"),
|
||||||
] = None,
|
] = None,
|
||||||
|
strict_auth: Annotated[
|
||||||
|
Optional[bool],
|
||||||
|
typer.Option("--strict-auth", help="API key protects all endpoints and documentation"),
|
||||||
|
] = False,
|
||||||
env: Annotated[
|
env: Annotated[
|
||||||
Optional[List[str]], typer.Option("--env", "-e", help="Environment variables")
|
Optional[List[str]], typer.Option("--env", "-e", help="Environment variables")
|
||||||
] = None,
|
] = None,
|
||||||
@ -116,6 +120,7 @@ def main(
|
|||||||
host,
|
host,
|
||||||
port,
|
port,
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
|
strict_auth=strict_auth,
|
||||||
cors_allow_origins=cors_allow_origins,
|
cors_allow_origins=cors_allow_origins,
|
||||||
server_type=server_type,
|
server_type=server_type,
|
||||||
config_path=config_path,
|
config_path=config_path,
|
||||||
|
@ -4,17 +4,16 @@ from contextlib import AsyncExitStack, asynccontextmanager
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import uvicorn
|
import uvicorn
|
||||||
from fastapi import FastAPI, Depends
|
from fastapi import Depends, FastAPI
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
from mcp import ClientSession, StdioServerParameters
|
||||||
|
from mcp.client.sse import sse_client
|
||||||
|
from mcp.client.stdio import stdio_client
|
||||||
from starlette.routing import Mount
|
from starlette.routing import Mount
|
||||||
|
|
||||||
|
|
||||||
from mcp import ClientSession, StdioServerParameters
|
|
||||||
from mcp.client.stdio import stdio_client
|
|
||||||
from mcp.client.sse import sse_client
|
|
||||||
|
|
||||||
from mcpo.utils.main import get_model_fields, get_tool_handler
|
from mcpo.utils.main import get_model_fields, get_tool_handler
|
||||||
from mcpo.utils.auth import get_verify_api_key
|
from mcpo.utils.auth import get_verify_api_key, APIKeyMiddleware
|
||||||
|
|
||||||
|
|
||||||
async def create_dynamic_endpoints(app: FastAPI, api_dependency=None):
|
async def create_dynamic_endpoints(app: FastAPI, api_dependency=None):
|
||||||
@ -37,21 +36,31 @@ async def create_dynamic_endpoints(app: FastAPI, api_dependency=None):
|
|||||||
for tool in tools:
|
for tool in tools:
|
||||||
endpoint_name = tool.name
|
endpoint_name = tool.name
|
||||||
endpoint_description = tool.description
|
endpoint_description = tool.description
|
||||||
schema = tool.inputSchema
|
|
||||||
|
|
||||||
required_fields = schema.get("required", [])
|
inputSchema = tool.inputSchema
|
||||||
properties = schema.get("properties", {})
|
outputSchema = getattr(tool, "outputSchema", None)
|
||||||
|
|
||||||
form_model_name = f"{endpoint_name}_form_model"
|
form_model_fields = get_model_fields(
|
||||||
|
f"{endpoint_name}_form_model",
|
||||||
model_fields = get_model_fields(
|
inputSchema.get("properties", {}),
|
||||||
form_model_name,
|
inputSchema.get("required", []),
|
||||||
properties,
|
inputSchema.get("$defs", {}),
|
||||||
required_fields,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
response_model_fields = None
|
||||||
|
if outputSchema:
|
||||||
|
response_model_fields = get_model_fields(
|
||||||
|
f"{endpoint_name}_response_model",
|
||||||
|
outputSchema.get("properties", {}),
|
||||||
|
outputSchema.get("required", []),
|
||||||
|
outputSchema.get("$defs", {}),
|
||||||
|
)
|
||||||
|
|
||||||
tool_handler = get_tool_handler(
|
tool_handler = get_tool_handler(
|
||||||
session, endpoint_name, form_model_name, model_fields
|
session,
|
||||||
|
endpoint_name,
|
||||||
|
form_model_fields,
|
||||||
|
response_model_fields,
|
||||||
)
|
)
|
||||||
|
|
||||||
app.post(
|
app.post(
|
||||||
@ -98,7 +107,10 @@ async def lifespan(app: FastAPI):
|
|||||||
await create_dynamic_endpoints(app, api_dependency=api_dependency)
|
await create_dynamic_endpoints(app, api_dependency=api_dependency)
|
||||||
yield
|
yield
|
||||||
if server_type == "sse":
|
if server_type == "sse":
|
||||||
async with sse_client(url=args[0], sse_read_timeout=None) as (reader, writer):
|
async with sse_client(url=args[0], sse_read_timeout=None) as (
|
||||||
|
reader,
|
||||||
|
writer,
|
||||||
|
):
|
||||||
async with ClientSession(reader, writer) as session:
|
async with ClientSession(reader, writer) as session:
|
||||||
app.state.session = session
|
app.state.session = session
|
||||||
await create_dynamic_endpoints(app, api_dependency=api_dependency)
|
await create_dynamic_endpoints(app, api_dependency=api_dependency)
|
||||||
@ -114,6 +126,7 @@ async def run(
|
|||||||
):
|
):
|
||||||
# Server API Key
|
# Server API Key
|
||||||
api_dependency = get_verify_api_key(api_key) if api_key else None
|
api_dependency = get_verify_api_key(api_key) if api_key else None
|
||||||
|
strict_auth = kwargs.get("strict_auth", False)
|
||||||
|
|
||||||
# MCP Server
|
# MCP Server
|
||||||
server_type = kwargs.get("server_type") # "stdio" or "sse" or "http"
|
server_type = kwargs.get("server_type") # "stdio" or "sse" or "http"
|
||||||
@ -132,7 +145,6 @@ async def run(
|
|||||||
ssl_certfile = kwargs.get("ssl_certfile")
|
ssl_certfile = kwargs.get("ssl_certfile")
|
||||||
ssl_keyfile = kwargs.get("ssl_keyfile")
|
ssl_keyfile = kwargs.get("ssl_keyfile")
|
||||||
path_prefix = kwargs.get("path_prefix") or "/"
|
path_prefix = kwargs.get("path_prefix") or "/"
|
||||||
|
|
||||||
main_app = FastAPI(
|
main_app = FastAPI(
|
||||||
title=name,
|
title=name,
|
||||||
description=description,
|
description=description,
|
||||||
@ -150,6 +162,10 @@ async def run(
|
|||||||
allow_headers=["*"],
|
allow_headers=["*"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Add middleware to protect also documentation and spec
|
||||||
|
if api_key and strict_auth:
|
||||||
|
main_app.add_middleware(APIKeyMiddleware, api_key=api_key)
|
||||||
|
|
||||||
if server_type == "sse":
|
if server_type == "sse":
|
||||||
main_app.state.server_type = "sse"
|
main_app.state.server_type = "sse"
|
||||||
main_app.state.args = server_command[0]
|
main_app.state.args = server_command[0]
|
||||||
@ -197,6 +213,10 @@ async def run(
|
|||||||
sub_app.state.server_type = "sse"
|
sub_app.state.server_type = "sse"
|
||||||
sub_app.state.args = server_cfg["url"]
|
sub_app.state.args = server_cfg["url"]
|
||||||
|
|
||||||
|
# Add middleware to protect also documentation and spec
|
||||||
|
if api_key and strict_auth:
|
||||||
|
sub_app.add_middleware(APIKeyMiddleware, api_key=api_key)
|
||||||
|
|
||||||
sub_app.state.api_dependency = api_dependency
|
sub_app.state.api_dependency = api_dependency
|
||||||
|
|
||||||
main_app.mount(f"{path_prefix}{server_name}", sub_app)
|
main_app.mount(f"{path_prefix}{server_name}", sub_app)
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
import pytest
|
import pytest
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from typing import Any, List, Dict
|
from typing import Any, List, Dict, Union
|
||||||
|
|
||||||
from mcpo.utils.main import _process_schema_property
|
from mcpo.utils.main import _process_schema_property
|
||||||
|
|
||||||
@ -64,6 +64,17 @@ def test_process_simple_number():
|
|||||||
assert result_field.default == expected_field.default
|
assert result_field.default == expected_field.default
|
||||||
|
|
||||||
|
|
||||||
|
def test_process_null():
|
||||||
|
schema = {"type": "null"}
|
||||||
|
expected_type = None
|
||||||
|
expected_field = Field(default=..., description="")
|
||||||
|
result_type, result_field = _process_schema_property(
|
||||||
|
_model_cache, schema, "test", "prop", True
|
||||||
|
)
|
||||||
|
assert result_type == expected_type
|
||||||
|
assert result_field.default == expected_field.default
|
||||||
|
|
||||||
|
|
||||||
def test_process_unknown_type():
|
def test_process_unknown_type():
|
||||||
schema = {"type": "unknown"}
|
schema = {"type": "unknown"}
|
||||||
expected_type = Any
|
expected_type = Any
|
||||||
@ -224,3 +235,78 @@ def test_model_caching():
|
|||||||
)
|
)
|
||||||
assert result_type3 == result_type1 # Should be the same cached object
|
assert result_type3 == result_type1 # Should be the same cached object
|
||||||
assert len(_model_cache) == 2 # Only two unique models created
|
assert len(_model_cache) == 2 # Only two unique models created
|
||||||
|
|
||||||
|
|
||||||
|
def test_multi_type_property_with_list():
|
||||||
|
schema = {
|
||||||
|
"type": ["string", "number"],
|
||||||
|
"description": "A property with multiple types",
|
||||||
|
}
|
||||||
|
expected_field = Field(default=..., description="A property with multiple types")
|
||||||
|
result_type, result_field = _process_schema_property(
|
||||||
|
_model_cache, schema, "test", "multi_type", True
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if the resulting type is a Union
|
||||||
|
assert str(result_type).startswith("typing.Union[")
|
||||||
|
|
||||||
|
# Check if both types are in the Union
|
||||||
|
assert str in result_type.__args__
|
||||||
|
assert float in result_type.__args__
|
||||||
|
|
||||||
|
# Check field properties
|
||||||
|
assert result_field.default == expected_field.default
|
||||||
|
assert result_field.description == expected_field.description
|
||||||
|
|
||||||
|
|
||||||
|
def test_multi_type_property_with_any_of():
|
||||||
|
schema = {
|
||||||
|
"anyOf": [
|
||||||
|
{
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"name": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The name of the function to call",
|
||||||
|
},
|
||||||
|
"arguments": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The arguments to pass to the function, as a JSON string",
|
||||||
|
"default": "{}",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["name"],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"function_id": {
|
||||||
|
"type": "int",
|
||||||
|
"description": "The id of the function to call",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["function_id"],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "string",
|
||||||
|
"enum": ["auto", "none"],
|
||||||
|
"description": "Control function calling behavior",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"description": "A property with multiple types",
|
||||||
|
}
|
||||||
|
result_type, result_field = _process_schema_property(
|
||||||
|
_model_cache, schema, "test", "multi_type", True
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if the resulting type is a Union
|
||||||
|
assert result_type.__origin__ == Union
|
||||||
|
|
||||||
|
# Check if the Union has the correct number of types
|
||||||
|
assert len(result_type.__args__) == 3
|
||||||
|
assert len(result_type.__args__[0].model_fields) == 2
|
||||||
|
assert len(result_type.__args__[1].model_fields) == 1
|
||||||
|
assert result_type.__args__[2] is str
|
||||||
|
|
||||||
|
# assert result_field parameter config
|
||||||
|
assert result_field.description == "A property with multiple types"
|
||||||
|
@ -1,5 +1,8 @@
|
|||||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||||
from fastapi import Depends, Header, HTTPException, status
|
from fastapi import Depends, Header, HTTPException, Request, status
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
|
import base64
|
||||||
|
|
||||||
from passlib.context import CryptContext
|
from passlib.context import CryptContext
|
||||||
from datetime import UTC, datetime, timedelta
|
from datetime import UTC, datetime, timedelta
|
||||||
@ -33,6 +36,75 @@ def get_verify_api_key(api_key: str):
|
|||||||
return verify_api_key
|
return verify_api_key
|
||||||
|
|
||||||
|
|
||||||
|
class APIKeyMiddleware(BaseHTTPMiddleware):
|
||||||
|
"""
|
||||||
|
Middleware that enforces Basic or Bearer token authentication for all requests.
|
||||||
|
"""
|
||||||
|
def __init__(self, app, api_key: str):
|
||||||
|
super().__init__(app)
|
||||||
|
self.api_key = api_key
|
||||||
|
|
||||||
|
async def dispatch(self, request: Request, call_next):
|
||||||
|
# Skip authentication for OPTIONS requests
|
||||||
|
if request.method == "OPTIONS":
|
||||||
|
return await call_next(request)
|
||||||
|
|
||||||
|
# Get authorization header
|
||||||
|
authorization = request.headers.get("Authorization")
|
||||||
|
|
||||||
|
# Verify API key
|
||||||
|
try:
|
||||||
|
# Use the same function that the dependency uses
|
||||||
|
if not authorization:
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=401,
|
||||||
|
content={"detail": "Missing or invalid Authorization header"},
|
||||||
|
headers={"WWW-Authenticate": "Bearer, Basic"}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Handle Bearer token auth
|
||||||
|
if authorization.startswith("Bearer "):
|
||||||
|
token = authorization[7:] # Remove "Bearer " prefix
|
||||||
|
if token != self.api_key:
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=403,
|
||||||
|
content={"detail": "Invalid API key"}
|
||||||
|
)
|
||||||
|
# Handle Basic auth
|
||||||
|
elif authorization.startswith("Basic "):
|
||||||
|
# Decode the base64 credentials
|
||||||
|
credentials = authorization[6:] # Remove "Basic " prefix
|
||||||
|
try:
|
||||||
|
decoded = base64.b64decode(credentials).decode('utf-8')
|
||||||
|
# Basic auth format is username:password
|
||||||
|
username, password = decoded.split(':', 1)
|
||||||
|
# Any username is allowed, but password must match api_key
|
||||||
|
if password != self.api_key:
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=403,
|
||||||
|
content={"detail": "Invalid credentials"}
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=401,
|
||||||
|
content={"detail": "Invalid Basic Authentication format"},
|
||||||
|
headers={"WWW-Authenticate": "Bearer, Basic"}
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=401,
|
||||||
|
content={"detail": "Unsupported authorization method"},
|
||||||
|
headers={"WWW-Authenticate": "Bearer, Basic"}
|
||||||
|
)
|
||||||
|
|
||||||
|
return await call_next(request)
|
||||||
|
except Exception as e:
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=500,
|
||||||
|
content={"detail": str(e)}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# def create_token(data: dict, expires_delta: Union[timedelta, None] = None) -> str:
|
# def create_token(data: dict, expires_delta: Union[timedelta, None] = None) -> str:
|
||||||
# payload = data.copy()
|
# payload = data.copy()
|
||||||
|
|
||||||
|
@ -1,12 +1,22 @@
|
|||||||
from typing import Any, Dict, List, Type, Union, ForwardRef
|
import json
|
||||||
from pydantic import create_model, Field
|
from typing import Any, Dict, ForwardRef, List, Optional, Type, Union
|
||||||
from pydantic.fields import FieldInfo
|
|
||||||
from mcp import ClientSession, types
|
|
||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
from mcp.types import CallToolResult, PARSE_ERROR, INVALID_REQUEST, METHOD_NOT_FOUND, INVALID_PARAMS, INTERNAL_ERROR
|
|
||||||
|
from mcp import ClientSession, types
|
||||||
|
from mcp.types import (
|
||||||
|
CallToolResult,
|
||||||
|
PARSE_ERROR,
|
||||||
|
INVALID_REQUEST,
|
||||||
|
METHOD_NOT_FOUND,
|
||||||
|
INVALID_PARAMS,
|
||||||
|
INTERNAL_ERROR,
|
||||||
|
)
|
||||||
|
|
||||||
from mcp.shared.exceptions import McpError
|
from mcp.shared.exceptions import McpError
|
||||||
|
|
||||||
import json
|
from pydantic import Field, create_model
|
||||||
|
from pydantic.fields import FieldInfo
|
||||||
|
|
||||||
MCP_ERROR_TO_HTTP_STATUS = {
|
MCP_ERROR_TO_HTTP_STATUS = {
|
||||||
PARSE_ERROR: 400,
|
PARSE_ERROR: 400,
|
||||||
@ -44,6 +54,7 @@ def _process_schema_property(
|
|||||||
model_name_prefix: str,
|
model_name_prefix: str,
|
||||||
prop_name: str,
|
prop_name: str,
|
||||||
is_required: bool,
|
is_required: bool,
|
||||||
|
schema_defs: Optional[Dict] = None,
|
||||||
) -> tuple[Union[Type, List, ForwardRef, Any], FieldInfo]:
|
) -> tuple[Union[Type, List, ForwardRef, Any], FieldInfo]:
|
||||||
"""
|
"""
|
||||||
Recursively processes a schema property to determine its Python type hint
|
Recursively processes a schema property to determine its Python type hint
|
||||||
@ -53,11 +64,49 @@ def _process_schema_property(
|
|||||||
A tuple containing (python_type_hint, pydantic_field).
|
A tuple containing (python_type_hint, pydantic_field).
|
||||||
The pydantic_field contains default value and description.
|
The pydantic_field contains default value and description.
|
||||||
"""
|
"""
|
||||||
|
if "$ref" in prop_schema:
|
||||||
|
ref = prop_schema["$ref"]
|
||||||
|
ref = ref.split("/")[-1]
|
||||||
|
assert ref in schema_defs, "Custom field not found"
|
||||||
|
prop_schema = schema_defs[ref]
|
||||||
|
|
||||||
prop_type = prop_schema.get("type")
|
prop_type = prop_schema.get("type")
|
||||||
prop_desc = prop_schema.get("description", "")
|
prop_desc = prop_schema.get("description", "")
|
||||||
|
|
||||||
default_value = ... if is_required else prop_schema.get("default", None)
|
default_value = ... if is_required else prop_schema.get("default", None)
|
||||||
pydantic_field = Field(default=default_value, description=prop_desc)
|
pydantic_field = Field(default=default_value, description=prop_desc)
|
||||||
|
|
||||||
|
# Handle the case where prop_type is missing but 'anyOf' key exists
|
||||||
|
# In this case, use data type from 'anyOf' to determine the type hint
|
||||||
|
if "anyOf" in prop_schema:
|
||||||
|
type_hints = []
|
||||||
|
for i, schema_option in enumerate(prop_schema["anyOf"]):
|
||||||
|
type_hint, _ = _process_schema_property(
|
||||||
|
_model_cache,
|
||||||
|
schema_option,
|
||||||
|
f"{model_name_prefix}_{prop_name}",
|
||||||
|
f"choice_{i}",
|
||||||
|
False,
|
||||||
|
)
|
||||||
|
type_hints.append(type_hint)
|
||||||
|
return Union[tuple(type_hints)], pydantic_field
|
||||||
|
|
||||||
|
# Handle the case where prop_type is a list of types, e.g. ['string', 'number']
|
||||||
|
if isinstance(prop_type, list):
|
||||||
|
# Create a Union of all the types
|
||||||
|
type_hints = []
|
||||||
|
for type_option in prop_type:
|
||||||
|
# Create a temporary schema with the single type and process it
|
||||||
|
temp_schema = dict(prop_schema)
|
||||||
|
temp_schema["type"] = type_option
|
||||||
|
type_hint, _ = _process_schema_property(
|
||||||
|
_model_cache, temp_schema, model_name_prefix, prop_name, False
|
||||||
|
)
|
||||||
|
type_hints.append(type_hint)
|
||||||
|
|
||||||
|
# Return a Union of all possible types
|
||||||
|
return Union[tuple(type_hints)], pydantic_field
|
||||||
|
|
||||||
if prop_type == "object":
|
if prop_type == "object":
|
||||||
nested_properties = prop_schema.get("properties", {})
|
nested_properties = prop_schema.get("properties", {})
|
||||||
nested_required = prop_schema.get("required", [])
|
nested_required = prop_schema.get("required", [])
|
||||||
@ -73,7 +122,12 @@ def _process_schema_property(
|
|||||||
for name, schema in nested_properties.items():
|
for name, schema in nested_properties.items():
|
||||||
is_nested_required = name in nested_required
|
is_nested_required = name in nested_required
|
||||||
nested_type_hint, nested_pydantic_field = _process_schema_property(
|
nested_type_hint, nested_pydantic_field = _process_schema_property(
|
||||||
_model_cache, schema, nested_model_name, name, is_nested_required
|
_model_cache,
|
||||||
|
schema,
|
||||||
|
nested_model_name,
|
||||||
|
name,
|
||||||
|
is_nested_required,
|
||||||
|
schema_defs,
|
||||||
)
|
)
|
||||||
|
|
||||||
nested_fields[name] = (nested_type_hint, nested_pydantic_field)
|
nested_fields[name] = (nested_type_hint, nested_pydantic_field)
|
||||||
@ -98,7 +152,8 @@ def _process_schema_property(
|
|||||||
items_schema,
|
items_schema,
|
||||||
f"{model_name_prefix}_{prop_name}",
|
f"{model_name_prefix}_{prop_name}",
|
||||||
"item",
|
"item",
|
||||||
False, # Items aren't required at this level
|
False, # Items aren't required at this level,
|
||||||
|
schema_defs,
|
||||||
)
|
)
|
||||||
list_type_hint = List[item_type_hint]
|
list_type_hint = List[item_type_hint]
|
||||||
return list_type_hint, pydantic_field
|
return list_type_hint, pydantic_field
|
||||||
@ -111,11 +166,13 @@ def _process_schema_property(
|
|||||||
return bool, pydantic_field
|
return bool, pydantic_field
|
||||||
elif prop_type == "number":
|
elif prop_type == "number":
|
||||||
return float, pydantic_field
|
return float, pydantic_field
|
||||||
|
elif prop_type == "null":
|
||||||
|
return None, pydantic_field
|
||||||
else:
|
else:
|
||||||
return Any, pydantic_field
|
return Any, pydantic_field
|
||||||
|
|
||||||
|
|
||||||
def get_model_fields(form_model_name, properties, required_fields):
|
def get_model_fields(form_model_name, properties, required_fields, schema_defs=None):
|
||||||
model_fields = {}
|
model_fields = {}
|
||||||
|
|
||||||
_model_cache: Dict[str, Type] = {}
|
_model_cache: Dict[str, Type] = {}
|
||||||
@ -123,21 +180,36 @@ def get_model_fields(form_model_name, properties, required_fields):
|
|||||||
for param_name, param_schema in properties.items():
|
for param_name, param_schema in properties.items():
|
||||||
is_required = param_name in required_fields
|
is_required = param_name in required_fields
|
||||||
python_type_hint, pydantic_field_info = _process_schema_property(
|
python_type_hint, pydantic_field_info = _process_schema_property(
|
||||||
_model_cache, param_schema, form_model_name, param_name, is_required
|
_model_cache,
|
||||||
|
param_schema,
|
||||||
|
form_model_name,
|
||||||
|
param_name,
|
||||||
|
is_required,
|
||||||
|
schema_defs,
|
||||||
)
|
)
|
||||||
# Use the generated type hint and Field info
|
# Use the generated type hint and Field info
|
||||||
model_fields[param_name] = (python_type_hint, pydantic_field_info)
|
model_fields[param_name] = (python_type_hint, pydantic_field_info)
|
||||||
return model_fields
|
return model_fields
|
||||||
|
|
||||||
|
|
||||||
def get_tool_handler(session, endpoint_name, form_model_name, model_fields):
|
def get_tool_handler(
|
||||||
if model_fields:
|
session,
|
||||||
FormModel = create_model(form_model_name, **model_fields)
|
endpoint_name,
|
||||||
|
form_model_fields,
|
||||||
|
response_model_fields=None,
|
||||||
|
):
|
||||||
|
if form_model_fields:
|
||||||
|
FormModel = create_model(f"{endpoint_name}_form_model", **form_model_fields)
|
||||||
|
ResponseModel = (
|
||||||
|
create_model(f"{endpoint_name}_response_model", **response_model_fields)
|
||||||
|
if response_model_fields
|
||||||
|
else Any
|
||||||
|
)
|
||||||
|
|
||||||
def make_endpoint_func(
|
def make_endpoint_func(
|
||||||
endpoint_name: str, FormModel, session: ClientSession
|
endpoint_name: str, FormModel, session: ClientSession
|
||||||
): # Parameterized endpoint
|
): # Parameterized endpoint
|
||||||
async def tool(form_data: FormModel):
|
async def tool(form_data: FormModel) -> ResponseModel:
|
||||||
args = form_data.model_dump(exclude_none=True)
|
args = form_data.model_dump(exclude_none=True)
|
||||||
print(f"Calling endpoint: {endpoint_name}, with args: {args}")
|
print(f"Calling endpoint: {endpoint_name}, with args: {args}")
|
||||||
try:
|
try:
|
||||||
@ -158,7 +230,9 @@ def get_tool_handler(session, endpoint_name, form_model_name, model_fields):
|
|||||||
)
|
)
|
||||||
|
|
||||||
response_data = process_tool_response(result)
|
response_data = process_tool_response(result)
|
||||||
final_response = response_data[0] if len(response_data) == 1 else response_data
|
final_response = (
|
||||||
|
response_data[0] if len(response_data) == 1 else response_data
|
||||||
|
)
|
||||||
return final_response
|
return final_response
|
||||||
|
|
||||||
except McpError as e:
|
except McpError as e:
|
||||||
@ -206,7 +280,9 @@ def get_tool_handler(session, endpoint_name, form_model_name, model_fields):
|
|||||||
)
|
)
|
||||||
|
|
||||||
response_data = process_tool_response(result)
|
response_data = process_tool_response(result)
|
||||||
final_response = response_data[0] if len(response_data) == 1 else response_data
|
final_response = (
|
||||||
|
response_data[0] if len(response_data) == 1 else response_data
|
||||||
|
)
|
||||||
return final_response
|
return final_response
|
||||||
|
|
||||||
except McpError as e:
|
except McpError as e:
|
||||||
|
Loading…
Reference in New Issue
Block a user