mirror of
https://github.com/open-webui/mcpo
synced 2025-06-26 18:26:58 +00:00
12
CHANGELOG.md
12
CHANGELOG.md
@@ -5,6 +5,18 @@ All notable changes to this project will be documented in this file.
|
||||
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
|
||||
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
||||
|
||||
## [0.0.10] - 2025-04-10
|
||||
|
||||
### Added
|
||||
|
||||
- 📦 **Support for --env-path to Load Environment Variables from File**: Use the new --env-path flag to securely pass environment variables via a .env-style file—making it easier than ever to manage secrets and config without cluttering your CLI or exposing sensitive data.
|
||||
- 🧪 **Enhanced Support for Nested Object and Array Types in OpenAPI Schema**: Tools with complex input/output structures (e.g., JSON payloads with arrays or nested fields) are now correctly interpreted and exposed with accurate OpenAPI documentation—making form-based testing in the UI smoother and integrations far more predictable.
|
||||
- 🛑 **Smart HTTP Exceptions for Better Debugging**: Clear, structured HTTP error responses are now automatically returned for bad requests or internal tool errors—helping users immediately understand what went wrong without digging through raw traces.
|
||||
|
||||
### Fixed
|
||||
|
||||
- 🪛 **Fixed --env Flag Behavior for Inline Environment Variables**: Resolved issues where the --env CLI flag silently failed or misbehaved—environment injection is now consistent and reliable whether passed inline with --env or via --env-path.
|
||||
|
||||
## [0.0.9] - 2025-04-06
|
||||
|
||||
### Added
|
||||
|
||||
20
README.md
20
README.md
@@ -94,6 +94,26 @@ Each with a dedicated OpenAPI schema and proxy handler. Access full schema UI at
|
||||
- Python 3.8+
|
||||
- uv (optional, but highly recommended for performance + packaging)
|
||||
|
||||
## 🛠️ Development & Testing
|
||||
|
||||
To contribute or run tests locally:
|
||||
|
||||
1. **Set up the environment:**
|
||||
```bash
|
||||
# Clone the repository
|
||||
git clone https://github.com/open-webui/mcpo.git
|
||||
cd mcpo
|
||||
|
||||
# Install dependencies (including dev dependencies)
|
||||
uv sync --dev
|
||||
```
|
||||
|
||||
2. **Run tests:**
|
||||
```bash
|
||||
pytest
|
||||
```
|
||||
|
||||
|
||||
## 🪪 License
|
||||
|
||||
MIT
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "mcpo"
|
||||
version = "0.0.9"
|
||||
version = "0.0.10"
|
||||
description = "A simple, secure MCP-to-OpenAPI proxy server"
|
||||
authors = [
|
||||
{ name = "Timothy Jaeryang Baek", email = "tim@openwebui.com" }
|
||||
@@ -14,6 +14,7 @@ dependencies = [
|
||||
"passlib[bcrypt]>=1.7.4",
|
||||
"pydantic>=2.11.1",
|
||||
"pyjwt[crypto]>=2.10.1",
|
||||
"python-dotenv>=1.1.0",
|
||||
"typer>=0.15.2",
|
||||
"uvicorn>=0.34.0",
|
||||
]
|
||||
@@ -24,3 +25,8 @@ mcpo = "mcpo:app"
|
||||
[build-system]
|
||||
requires = ["hatchling"]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
[dependency-groups]
|
||||
dev = [
|
||||
"pytest>=8.3.5",
|
||||
]
|
||||
|
||||
@@ -2,7 +2,7 @@ import sys
|
||||
import asyncio
|
||||
import typer
|
||||
import os
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from typing_extensions import Annotated
|
||||
from typing import Optional, List
|
||||
@@ -29,6 +29,10 @@ def main(
|
||||
env: Annotated[
|
||||
Optional[List[str]], typer.Option("--env", "-e", help="Environment variables")
|
||||
] = None,
|
||||
env_path: Annotated[
|
||||
Optional[str],
|
||||
typer.Option("--env-path", help="Path to environment variables file"),
|
||||
] = None,
|
||||
config: Annotated[
|
||||
Optional[str], typer.Option("--config", "-c", help="Config file path")
|
||||
] = None,
|
||||
@@ -74,15 +78,23 @@ def main(
|
||||
f"Starting MCP OpenAPI Proxy on {host}:{port} with command: {' '.join(server_command)}"
|
||||
)
|
||||
|
||||
env_dict = {}
|
||||
if env:
|
||||
for var in env:
|
||||
key, value = env.split("=", 1)
|
||||
env_dict[key] = value
|
||||
try:
|
||||
env_dict = {}
|
||||
if env:
|
||||
for var in env:
|
||||
key, value = var.split("=", 1)
|
||||
env_dict[key] = value
|
||||
|
||||
# Set environment variables
|
||||
for key, value in env_dict.items():
|
||||
os.environ[key] = value
|
||||
if env_path:
|
||||
# Load environment variables from the specified file
|
||||
load_dotenv(env_path)
|
||||
env_dict.update(dict(os.environ))
|
||||
|
||||
# Set environment variables
|
||||
for key, value in env_dict.items():
|
||||
os.environ[key] = value
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
# Whatever the prefix is, make sure it starts and ends with a /
|
||||
if path_prefix is None:
|
||||
|
||||
110
src/mcpo/main.py
110
src/mcpo/main.py
@@ -1,61 +1,24 @@
|
||||
import json
|
||||
import os
|
||||
from contextlib import AsyncExitStack, asynccontextmanager
|
||||
from typing import Dict, Any, Optional
|
||||
from typing import Optional
|
||||
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, Body, Depends
|
||||
from fastapi import FastAPI, Depends
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from mcp import ClientSession, StdioServerParameters, types
|
||||
from mcp.client.stdio import stdio_client
|
||||
from mcp.types import CallToolResult
|
||||
|
||||
from mcpo.utils.auth import get_verify_api_key
|
||||
from pydantic import create_model
|
||||
from starlette.routing import Mount
|
||||
|
||||
|
||||
def get_python_type(param_type: str):
|
||||
if param_type == "string":
|
||||
return str
|
||||
elif param_type == "integer":
|
||||
return int
|
||||
elif param_type == "boolean":
|
||||
return bool
|
||||
elif param_type == "number":
|
||||
return float
|
||||
elif param_type == "object":
|
||||
return Dict[str, Any]
|
||||
elif param_type == "array":
|
||||
return list
|
||||
else:
|
||||
return str # Fallback
|
||||
# Expand as needed. PRs welcome!
|
||||
from mcp import ClientSession, StdioServerParameters
|
||||
from mcp.client.stdio import stdio_client
|
||||
|
||||
|
||||
def process_tool_response(result: CallToolResult) -> list:
|
||||
"""Universal response processor for all tool endpoints"""
|
||||
response = []
|
||||
for content in result.content:
|
||||
if isinstance(content, types.TextContent):
|
||||
text = content.text
|
||||
if isinstance(text, str):
|
||||
try:
|
||||
text = json.loads(text)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
response.append(text)
|
||||
elif isinstance(content, types.ImageContent):
|
||||
image_data = f"data:{content.mimeType};base64,{content.data}"
|
||||
response.append(image_data)
|
||||
elif isinstance(content, types.EmbeddedResource):
|
||||
# TODO: Handle embedded resources
|
||||
response.append("Embedded resource not supported yet.")
|
||||
return response
|
||||
from mcpo.utils.main import get_model_fields, get_tool_handler
|
||||
from mcpo.utils.auth import get_verify_api_key
|
||||
|
||||
|
||||
async def create_dynamic_endpoints(app: FastAPI, api_dependency=None):
|
||||
session = app.state.session
|
||||
session: ClientSession = app.state.session
|
||||
if not session:
|
||||
raise ValueError("Session is not initialized in the app state.")
|
||||
|
||||
@@ -76,53 +39,26 @@ async def create_dynamic_endpoints(app: FastAPI, api_dependency=None):
|
||||
endpoint_description = tool.description
|
||||
schema = tool.inputSchema
|
||||
|
||||
model_fields = {}
|
||||
required_fields = schema.get("required", [])
|
||||
properties = schema.get("properties", {})
|
||||
|
||||
for param_name, param_schema in properties.items():
|
||||
param_type = param_schema.get("type", "string")
|
||||
param_desc = param_schema.get("description", "")
|
||||
python_type = get_python_type(param_type)
|
||||
default_value = ... if param_name in required_fields else None
|
||||
model_fields[param_name] = (
|
||||
python_type,
|
||||
Body(default_value, description=param_desc),
|
||||
)
|
||||
form_model_name = f"{endpoint_name}_form_model"
|
||||
|
||||
if model_fields:
|
||||
FormModel = create_model(f"{endpoint_name}_form_model", **model_fields)
|
||||
model_fields = get_model_fields(
|
||||
form_model_name,
|
||||
properties,
|
||||
required_fields,
|
||||
)
|
||||
|
||||
def make_endpoint_func(
|
||||
endpoint_name: str, FormModel, session: ClientSession
|
||||
): # Parameterized endpoint
|
||||
async def tool(form_data: FormModel):
|
||||
args = form_data.model_dump(exclude_none=True)
|
||||
result = await session.call_tool(endpoint_name, arguments=args)
|
||||
return process_tool_response(result)
|
||||
|
||||
return tool
|
||||
|
||||
tool_handler = make_endpoint_func(endpoint_name, FormModel, session)
|
||||
else:
|
||||
|
||||
def make_endpoint_func_no_args(
|
||||
endpoint_name: str, session: ClientSession
|
||||
): # Parameterless endpoint
|
||||
async def tool(): # No parameters
|
||||
result = await session.call_tool(
|
||||
endpoint_name, arguments={}
|
||||
) # Empty dict
|
||||
return process_tool_response(result) # Same processor
|
||||
|
||||
return tool
|
||||
|
||||
tool_handler = make_endpoint_func_no_args(endpoint_name, session)
|
||||
tool_handler = get_tool_handler(
|
||||
session, endpoint_name, form_model_name, model_fields
|
||||
)
|
||||
|
||||
app.post(
|
||||
f"/{endpoint_name}",
|
||||
summary=endpoint_name.replace("_", " ").title(),
|
||||
description=endpoint_description,
|
||||
response_model_exclude_none=True,
|
||||
dependencies=[Depends(api_dependency)] if api_dependency else [],
|
||||
)(tool_handler)
|
||||
|
||||
@@ -171,11 +107,13 @@ async def run(
|
||||
# MCP Config
|
||||
config_path = kwargs.get("config")
|
||||
server_command = kwargs.get("server_command")
|
||||
|
||||
name = kwargs.get("name") or "MCP OpenAPI Proxy"
|
||||
description = (
|
||||
kwargs.get("description") or "Automatically generated API from MCP Tool Schemas"
|
||||
)
|
||||
version = kwargs.get("version") or "1.0"
|
||||
|
||||
ssl_certfile = kwargs.get("ssl_certfile")
|
||||
ssl_keyfile = kwargs.get("ssl_keyfile")
|
||||
path_prefix = kwargs.get("path_prefix") or "/"
|
||||
@@ -198,7 +136,6 @@ async def run(
|
||||
)
|
||||
|
||||
if server_command:
|
||||
|
||||
main_app.state.command = server_command[0]
|
||||
main_app.state.args = server_command[1:]
|
||||
main_app.state.env = os.environ.copy()
|
||||
@@ -207,14 +144,16 @@ async def run(
|
||||
elif config_path:
|
||||
with open(config_path, "r") as f:
|
||||
config_data = json.load(f)
|
||||
|
||||
mcp_servers = config_data.get("mcpServers", {})
|
||||
if not mcp_servers:
|
||||
raise ValueError("No 'mcpServers' found in config file.")
|
||||
|
||||
main_app.description += "\n\n- **available tools**:"
|
||||
for server_name, server_cfg in mcp_servers.items():
|
||||
sub_app = FastAPI(
|
||||
title=f"{server_name}",
|
||||
description=f"{server_name} MCP Server\n\n- [back to tool list](http://{host}:{port}/docs)",
|
||||
description=f"{server_name} MCP Server\n\n- [back to tool list](/docs)",
|
||||
version="1.0",
|
||||
lifespan=lifespan,
|
||||
)
|
||||
@@ -232,10 +171,9 @@ async def run(
|
||||
sub_app.state.env = {**os.environ, **server_cfg.get("env", {})}
|
||||
|
||||
sub_app.state.api_dependency = api_dependency
|
||||
|
||||
main_app.mount(f"{path_prefix}{server_name}", sub_app)
|
||||
main_app.description += (
|
||||
f"\n - [{server_name}](http://{host}:{port}/{server_name}/docs)"
|
||||
)
|
||||
main_app.description += f"\n - [{server_name}](/{server_name}/docs)"
|
||||
else:
|
||||
raise ValueError("You must provide either server_command or config.")
|
||||
|
||||
|
||||
235
src/mcpo/utils/main.py
Normal file
235
src/mcpo/utils/main.py
Normal file
@@ -0,0 +1,235 @@
|
||||
from typing import Any, Dict, List, Type, Union, ForwardRef
|
||||
from pydantic import create_model, Field
|
||||
from pydantic.fields import FieldInfo
|
||||
from mcp import ClientSession, types
|
||||
from fastapi import HTTPException
|
||||
from mcp.types import CallToolResult, PARSE_ERROR, INVALID_REQUEST, METHOD_NOT_FOUND, INVALID_PARAMS, INTERNAL_ERROR
|
||||
from mcp.shared.exceptions import McpError
|
||||
|
||||
import json
|
||||
|
||||
MCP_ERROR_TO_HTTP_STATUS = {
|
||||
PARSE_ERROR: 400,
|
||||
INVALID_REQUEST: 400,
|
||||
METHOD_NOT_FOUND: 404,
|
||||
INVALID_PARAMS: 422,
|
||||
INTERNAL_ERROR: 500,
|
||||
}
|
||||
|
||||
|
||||
def process_tool_response(result: CallToolResult) -> list:
|
||||
"""Universal response processor for all tool endpoints"""
|
||||
response = []
|
||||
for content in result.content:
|
||||
if isinstance(content, types.TextContent):
|
||||
text = content.text
|
||||
if isinstance(text, str):
|
||||
try:
|
||||
text = json.loads(text)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
response.append(text)
|
||||
elif isinstance(content, types.ImageContent):
|
||||
image_data = f"data:{content.mimeType};base64,{content.data}"
|
||||
response.append(image_data)
|
||||
elif isinstance(content, types.EmbeddedResource):
|
||||
# TODO: Handle embedded resources
|
||||
response.append("Embedded resource not supported yet.")
|
||||
return response
|
||||
|
||||
|
||||
def _process_schema_property(
|
||||
_model_cache: Dict[str, Type],
|
||||
prop_schema: Dict[str, Any],
|
||||
model_name_prefix: str,
|
||||
prop_name: str,
|
||||
is_required: bool,
|
||||
) -> tuple[Union[Type, List, ForwardRef, Any], FieldInfo]:
|
||||
"""
|
||||
Recursively processes a schema property to determine its Python type hint
|
||||
and Pydantic Field definition.
|
||||
|
||||
Returns:
|
||||
A tuple containing (python_type_hint, pydantic_field).
|
||||
The pydantic_field contains default value and description.
|
||||
"""
|
||||
prop_type = prop_schema.get("type")
|
||||
prop_desc = prop_schema.get("description", "")
|
||||
default_value = ... if is_required else prop_schema.get("default", None)
|
||||
pydantic_field = Field(default=default_value, description=prop_desc)
|
||||
|
||||
if prop_type == "object":
|
||||
nested_properties = prop_schema.get("properties", {})
|
||||
nested_required = prop_schema.get("required", [])
|
||||
nested_fields = {}
|
||||
|
||||
nested_model_name = f"{model_name_prefix}_{prop_name}_model".replace(
|
||||
"__", "_"
|
||||
).rstrip("_")
|
||||
|
||||
if nested_model_name in _model_cache:
|
||||
return _model_cache[nested_model_name], pydantic_field
|
||||
|
||||
for name, schema in nested_properties.items():
|
||||
is_nested_required = name in nested_required
|
||||
nested_type_hint, nested_pydantic_field = _process_schema_property(
|
||||
_model_cache, schema, nested_model_name, name, is_nested_required
|
||||
)
|
||||
|
||||
nested_fields[name] = (nested_type_hint, nested_pydantic_field)
|
||||
|
||||
if not nested_fields:
|
||||
return Dict[str, Any], pydantic_field
|
||||
|
||||
NestedModel = create_model(nested_model_name, **nested_fields)
|
||||
_model_cache[nested_model_name] = NestedModel
|
||||
|
||||
return NestedModel, pydantic_field
|
||||
|
||||
elif prop_type == "array":
|
||||
items_schema = prop_schema.get("items")
|
||||
if not items_schema:
|
||||
# Default to list of anything if items schema is missing
|
||||
return List[Any], pydantic_field
|
||||
|
||||
# Recursively determine the type of items in the array
|
||||
item_type_hint, _ = _process_schema_property(
|
||||
_model_cache,
|
||||
items_schema,
|
||||
f"{model_name_prefix}_{prop_name}",
|
||||
"item",
|
||||
False, # Items aren't required at this level
|
||||
)
|
||||
list_type_hint = List[item_type_hint]
|
||||
return list_type_hint, pydantic_field
|
||||
|
||||
elif prop_type == "string":
|
||||
return str, pydantic_field
|
||||
elif prop_type == "integer":
|
||||
return int, pydantic_field
|
||||
elif prop_type == "boolean":
|
||||
return bool, pydantic_field
|
||||
elif prop_type == "number":
|
||||
return float, pydantic_field
|
||||
else:
|
||||
return Any, pydantic_field
|
||||
|
||||
|
||||
def get_model_fields(form_model_name, properties, required_fields):
|
||||
model_fields = {}
|
||||
|
||||
_model_cache: Dict[str, Type] = {}
|
||||
|
||||
for param_name, param_schema in properties.items():
|
||||
is_required = param_name in required_fields
|
||||
python_type_hint, pydantic_field_info = _process_schema_property(
|
||||
_model_cache, param_schema, form_model_name, param_name, is_required
|
||||
)
|
||||
# Use the generated type hint and Field info
|
||||
model_fields[param_name] = (python_type_hint, pydantic_field_info)
|
||||
return model_fields
|
||||
|
||||
|
||||
def get_tool_handler(session, endpoint_name, form_model_name, model_fields):
|
||||
if model_fields:
|
||||
FormModel = create_model(form_model_name, **model_fields)
|
||||
|
||||
def make_endpoint_func(
|
||||
endpoint_name: str, FormModel, session: ClientSession
|
||||
): # Parameterized endpoint
|
||||
async def tool(form_data: FormModel):
|
||||
args = form_data.model_dump(exclude_none=True)
|
||||
print(f"Calling endpoint: {endpoint_name}, with args: {args}")
|
||||
try:
|
||||
result = await session.call_tool(endpoint_name, arguments=args)
|
||||
|
||||
if result.isError:
|
||||
error_message = "Unknown tool execution error"
|
||||
error_data = None # Initialize error_data
|
||||
if result.content:
|
||||
if isinstance(result.content[0], types.TextContent):
|
||||
error_message = result.content[0].text
|
||||
detail = {"message": error_message}
|
||||
if error_data is not None:
|
||||
detail["data"] = error_data
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=detail,
|
||||
)
|
||||
|
||||
response_data = process_tool_response(result)
|
||||
final_response = response_data[0] if len(response_data) == 1 else response_data
|
||||
return final_response
|
||||
|
||||
except McpError as e:
|
||||
print(f"MCP Error calling {endpoint_name}: {e.error}")
|
||||
status_code = MCP_ERROR_TO_HTTP_STATUS.get(e.error.code, 500)
|
||||
raise HTTPException(
|
||||
status_code=status_code,
|
||||
detail=(
|
||||
{"message": e.error.message, "data": e.error.data}
|
||||
if e.error.data is not None
|
||||
else {"message": e.error.message}
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Unexpected error calling {endpoint_name}: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail={"message": "Unexpected error", "error": str(e)},
|
||||
)
|
||||
|
||||
return tool
|
||||
|
||||
tool_handler = make_endpoint_func(endpoint_name, FormModel, session)
|
||||
else:
|
||||
|
||||
def make_endpoint_func_no_args(
|
||||
endpoint_name: str, session: ClientSession
|
||||
): # Parameterless endpoint
|
||||
async def tool(): # No parameters
|
||||
print(f"Calling endpoint: {endpoint_name}, with no args")
|
||||
try:
|
||||
result = await session.call_tool(
|
||||
endpoint_name, arguments={}
|
||||
) # Empty dict
|
||||
|
||||
if result.isError:
|
||||
error_message = "Unknown tool execution error"
|
||||
if result.content:
|
||||
if isinstance(result.content[0], types.TextContent):
|
||||
error_message = result.content[0].text
|
||||
detail = {"message": error_message}
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=detail,
|
||||
)
|
||||
|
||||
response_data = process_tool_response(result)
|
||||
final_response = response_data[0] if len(response_data) == 1 else response_data
|
||||
return final_response
|
||||
|
||||
except McpError as e:
|
||||
print(f"MCP Error calling {endpoint_name}: {e.error}")
|
||||
status_code = MCP_ERROR_TO_HTTP_STATUS.get(e.error.code, 500)
|
||||
# Propagate the error received from MCP as an HTTP exception
|
||||
raise HTTPException(
|
||||
status_code=status_code,
|
||||
detail=(
|
||||
{"message": e.error.message, "data": e.error.data}
|
||||
if e.error.data is not None
|
||||
else {"message": e.error.message}
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Unexpected error calling {endpoint_name}: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail={"message": "Unexpected error", "error": str(e)},
|
||||
)
|
||||
|
||||
return tool
|
||||
|
||||
tool_handler = make_endpoint_func_no_args(endpoint_name, session)
|
||||
|
||||
return tool_handler
|
||||
226
tests/test_main.py
Normal file
226
tests/test_main.py
Normal file
@@ -0,0 +1,226 @@
|
||||
import pytest
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Any, List, Dict
|
||||
|
||||
from src.mcpo.utils.main import _process_schema_property
|
||||
|
||||
|
||||
_model_cache = {}
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clear_model_cache():
|
||||
_model_cache.clear()
|
||||
yield
|
||||
_model_cache.clear()
|
||||
|
||||
|
||||
def test_process_simple_string_required():
|
||||
schema = {"type": "string", "description": "A simple string"}
|
||||
expected_type = str
|
||||
expected_field = Field(default=..., description="A simple string")
|
||||
result_type, result_field = _process_schema_property(
|
||||
_model_cache, schema, "test", "prop", True
|
||||
)
|
||||
assert result_type == expected_type
|
||||
assert result_field.default == expected_field.default
|
||||
assert result_field.description == expected_field.description
|
||||
|
||||
|
||||
def test_process_simple_integer_optional():
|
||||
schema = {"type": "integer", "default": 10}
|
||||
expected_type = int
|
||||
expected_field = Field(default=10, description="")
|
||||
result_type, result_field = _process_schema_property(
|
||||
_model_cache, schema, "test", "prop", False
|
||||
)
|
||||
assert result_type == expected_type
|
||||
assert result_field.default == expected_field.default
|
||||
assert result_field.description == expected_field.description
|
||||
|
||||
|
||||
def test_process_simple_boolean_optional_no_default():
|
||||
schema = {"type": "boolean"}
|
||||
expected_type = bool
|
||||
expected_field = Field(
|
||||
default=None, description=""
|
||||
) # Default is None if not required and no default specified
|
||||
result_type, result_field = _process_schema_property(
|
||||
_model_cache, schema, "test", "prop", False
|
||||
)
|
||||
assert result_type == expected_type
|
||||
assert result_field.default == expected_field.default
|
||||
assert result_field.description == expected_field.description
|
||||
|
||||
|
||||
def test_process_simple_number():
|
||||
schema = {"type": "number"}
|
||||
expected_type = float
|
||||
expected_field = Field(default=..., description="")
|
||||
result_type, result_field = _process_schema_property(
|
||||
_model_cache, schema, "test", "prop", True
|
||||
)
|
||||
assert result_type == expected_type
|
||||
assert result_field.default == expected_field.default
|
||||
|
||||
|
||||
def test_process_unknown_type():
|
||||
schema = {"type": "unknown"}
|
||||
expected_type = Any
|
||||
expected_field = Field(default=..., description="")
|
||||
result_type, result_field = _process_schema_property(
|
||||
_model_cache, schema, "test", "prop", True
|
||||
)
|
||||
assert result_type == expected_type
|
||||
assert result_field.default == expected_field.default
|
||||
|
||||
|
||||
def test_process_array_of_strings():
|
||||
schema = {"type": "array", "items": {"type": "string"}}
|
||||
expected_type = List[str]
|
||||
expected_field = Field(default=..., description="")
|
||||
result_type, result_field = _process_schema_property(
|
||||
_model_cache, schema, "test", "prop", True
|
||||
)
|
||||
assert result_type == expected_type
|
||||
assert result_field.default == expected_field.default
|
||||
|
||||
|
||||
def test_process_array_of_any_missing_items():
|
||||
schema = {"type": "array"} # Missing "items"
|
||||
expected_type = List[Any]
|
||||
expected_field = Field(default=None, description="")
|
||||
result_type, result_field = _process_schema_property(
|
||||
_model_cache, schema, "test", "prop", False
|
||||
)
|
||||
assert result_type == expected_type
|
||||
assert result_field.default == expected_field.default
|
||||
|
||||
|
||||
def test_process_simple_object():
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
|
||||
"required": ["name"],
|
||||
}
|
||||
expected_field = Field(default=..., description="")
|
||||
result_type, result_field = _process_schema_property(
|
||||
_model_cache, schema, "test", "prop", True
|
||||
)
|
||||
|
||||
assert result_field.default == expected_field.default
|
||||
assert result_field.description == expected_field.description
|
||||
assert issubclass(result_type, BaseModel) # Check if it's a Pydantic model
|
||||
|
||||
# Check fields of the generated model
|
||||
model_fields = result_type.model_fields
|
||||
assert "name" in model_fields
|
||||
assert model_fields["name"].annotation is str
|
||||
assert model_fields["name"].is_required()
|
||||
|
||||
assert "age" in model_fields
|
||||
assert model_fields["age"].annotation is int
|
||||
assert not model_fields["age"].is_required()
|
||||
assert model_fields["age"].default is None # Optional without default
|
||||
|
||||
|
||||
def test_process_nested_object():
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"user": {
|
||||
"type": "object",
|
||||
"properties": {"id": {"type": "integer"}},
|
||||
"required": ["id"],
|
||||
}
|
||||
},
|
||||
"required": ["user"],
|
||||
}
|
||||
expected_field = Field(default=..., description="")
|
||||
result_type, result_field = _process_schema_property(
|
||||
_model_cache, schema, "test", "outer_prop", True
|
||||
)
|
||||
|
||||
assert result_field.default == expected_field.default
|
||||
assert issubclass(result_type, BaseModel)
|
||||
|
||||
outer_model_fields = result_type.model_fields
|
||||
assert "user" in outer_model_fields
|
||||
assert outer_model_fields["user"].is_required()
|
||||
|
||||
nested_model_type = outer_model_fields["user"].annotation
|
||||
assert issubclass(nested_model_type, BaseModel)
|
||||
|
||||
nested_model_fields = nested_model_type.model_fields
|
||||
assert "id" in nested_model_fields
|
||||
assert nested_model_fields["id"].annotation is int
|
||||
assert nested_model_fields["id"].is_required()
|
||||
|
||||
|
||||
def test_process_array_of_objects():
|
||||
schema = {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {"item_id": {"type": "string"}},
|
||||
"required": ["item_id"],
|
||||
},
|
||||
}
|
||||
expected_field = Field(default=..., description="")
|
||||
result_type, result_field = _process_schema_property(
|
||||
_model_cache, schema, "test", "prop", True
|
||||
)
|
||||
|
||||
assert result_field.default == expected_field.default
|
||||
assert str(result_type).startswith("typing.List[") # Check it's a List
|
||||
|
||||
# Get the inner type from List[...]
|
||||
item_model_type = result_type.__args__[0]
|
||||
assert issubclass(item_model_type, BaseModel)
|
||||
|
||||
item_model_fields = item_model_type.model_fields
|
||||
assert "item_id" in item_model_fields
|
||||
assert item_model_fields["item_id"].annotation is str
|
||||
assert item_model_fields["item_id"].is_required()
|
||||
|
||||
|
||||
def test_process_empty_object():
|
||||
schema = {"type": "object", "properties": {}}
|
||||
expected_type = Dict[str, Any] # Should default to Dict[str, Any] if no properties
|
||||
expected_field = Field(default=..., description="")
|
||||
result_type, result_field = _process_schema_property(
|
||||
_model_cache, schema, "test", "prop", True
|
||||
)
|
||||
assert result_type == expected_type
|
||||
assert result_field.default == expected_field.default
|
||||
|
||||
|
||||
def test_model_caching():
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {"id": {"type": "integer"}},
|
||||
"required": ["id"],
|
||||
}
|
||||
# First call
|
||||
result_type1, _ = _process_schema_property(
|
||||
_model_cache, schema, "cache_test", "obj1", True
|
||||
)
|
||||
model_name = "cache_test_obj1_model"
|
||||
assert model_name in _model_cache
|
||||
assert _model_cache[model_name] == result_type1
|
||||
|
||||
# Second call with same structure but different prefix/prop name (should generate new)
|
||||
result_type2, _ = _process_schema_property(
|
||||
_model_cache, schema, "cache_test", "obj2", True
|
||||
)
|
||||
model_name2 = "cache_test_obj2_model"
|
||||
assert model_name2 in _model_cache
|
||||
assert _model_cache[model_name2] == result_type2
|
||||
assert result_type1 != result_type2 # Different models
|
||||
|
||||
# Third call identical to the first (should return cached model)
|
||||
result_type3, _ = _process_schema_property(
|
||||
_model_cache, schema, "cache_test", "obj1", True
|
||||
)
|
||||
assert result_type3 == result_type1 # Should be the same cached object
|
||||
assert len(_model_cache) == 2 # Only two unique models created
|
||||
52
uv.lock
generated
52
uv.lock
generated
@@ -261,6 +261,15 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/76/c6/c88e154df9c4e1a2a66ccf0005a88dfb2650c1dffb6f5ce603dfbd452ce3/idna-3.10-py3-none-any.whl", hash = "sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3", size = 70442 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "iniconfig"
|
||||
version = "2.1.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/f2/97/ebf4da567aa6827c909642694d71c9fcf53e5b504f2d96afea02718862f3/iniconfig-2.1.0.tar.gz", hash = "sha256:3abbd2e30b36733fee78f9c7f7308f2d0050e88f0087fd25c2645f63c773e1c7", size = 4793 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/2c/e1/e6716421ea10d38022b952c159d5161ca1193197fb744506875fbb87ea7b/iniconfig-2.1.0-py3-none-any.whl", hash = "sha256:9deba5723312380e77435581c6bf4935c94cbfab9b1ed33ef8d238ea168eb760", size = 6050 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "markdown-it-py"
|
||||
version = "3.0.0"
|
||||
@@ -303,10 +312,16 @@ dependencies = [
|
||||
{ name = "passlib", extra = ["bcrypt"] },
|
||||
{ name = "pydantic" },
|
||||
{ name = "pyjwt", extra = ["crypto"] },
|
||||
{ name = "python-dotenv" },
|
||||
{ name = "typer" },
|
||||
{ name = "uvicorn" },
|
||||
]
|
||||
|
||||
[package.dev-dependencies]
|
||||
dev = [
|
||||
{ name = "pytest" },
|
||||
]
|
||||
|
||||
[package.metadata]
|
||||
requires-dist = [
|
||||
{ name = "click", specifier = ">=8.1.8" },
|
||||
@@ -315,10 +330,14 @@ requires-dist = [
|
||||
{ name = "passlib", extras = ["bcrypt"], specifier = ">=1.7.4" },
|
||||
{ name = "pydantic", specifier = ">=2.11.1" },
|
||||
{ name = "pyjwt", extras = ["crypto"], specifier = ">=2.10.1" },
|
||||
{ name = "python-dotenv", specifier = ">=1.1.0" },
|
||||
{ name = "typer", specifier = ">=0.15.2" },
|
||||
{ name = "uvicorn", specifier = ">=0.34.0" },
|
||||
]
|
||||
|
||||
[package.metadata.requires-dev]
|
||||
dev = [{ name = "pytest", specifier = ">=8.3.5" }]
|
||||
|
||||
[[package]]
|
||||
name = "mdurl"
|
||||
version = "0.1.2"
|
||||
@@ -328,6 +347,15 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/b3/38/89ba8ad64ae25be8de66a6d463314cf1eb366222074cfda9ee839c56a4b4/mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8", size = 9979 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "packaging"
|
||||
version = "24.2"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/d0/63/68dbb6eb2de9cb10ee4c9c14a0148804425e13c4fb20d61cce69f53106da/packaging-24.2.tar.gz", hash = "sha256:c228a6dc5e932d346bc5739379109d49e8853dd8223571c7c5b55260edc0b97f", size = 163950 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/88/ef/eb23f262cca3c0c4eb7ab1933c3b1f03d021f2c48f54763065b6f0e321be/packaging-24.2-py3-none-any.whl", hash = "sha256:09abb1bccd265c01f4a3aa3f7a7db064b36514d2cba19a2f694fe6150451a759", size = 65451 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "passlib"
|
||||
version = "1.7.4"
|
||||
@@ -342,6 +370,15 @@ bcrypt = [
|
||||
{ name = "bcrypt" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pluggy"
|
||||
version = "1.5.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/96/2d/02d4312c973c6050a18b314a5ad0b3210edb65a906f868e31c111dede4a6/pluggy-1.5.0.tar.gz", hash = "sha256:2cffa88e94fdc978c4c574f15f9e59b7f4201d439195c3715ca9e2486f1d0cf1", size = 67955 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/88/5f/e351af9a41f866ac3f1fac4ca0613908d9a41741cfcf2228f4ad853b697d/pluggy-1.5.0-py3-none-any.whl", hash = "sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669", size = 20556 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pycparser"
|
||||
version = "2.22"
|
||||
@@ -467,6 +504,21 @@ crypto = [
|
||||
{ name = "cryptography" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pytest"
|
||||
version = "8.3.5"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "colorama", marker = "sys_platform == 'win32'" },
|
||||
{ name = "iniconfig" },
|
||||
{ name = "packaging" },
|
||||
{ name = "pluggy" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/ae/3c/c9d525a414d506893f0cd8a8d0de7706446213181570cdbd766691164e40/pytest-8.3.5.tar.gz", hash = "sha256:f4efe70cc14e511565ac476b57c279e12a855b11f48f212af1080ef2263d3845", size = 1450891 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/30/3d/64ad57c803f1fa1e963a7946b6e0fea4a70df53c1a7fed304586539c2bac/pytest-8.3.5-py3-none-any.whl", hash = "sha256:c69214aa47deac29fad6c2a4f590b9c4a9fdb16a403176fe154b79c0b4d4d820", size = 343634 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "python-dotenv"
|
||||
version = "1.1.0"
|
||||
|
||||
Reference in New Issue
Block a user