mirror of
https://github.com/open-webui/mcpo
synced 2025-06-26 18:26:58 +00:00
Merge pull request #105 from JinY0ung-Shin/main
feat: support cusotm BaseModel "$def" and Output schema
This commit is contained in:
@@ -4,15 +4,14 @@ from contextlib import AsyncExitStack, asynccontextmanager
|
||||
from typing import Optional
|
||||
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, Depends
|
||||
from fastapi import Depends, FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from mcp import ClientSession, StdioServerParameters
|
||||
from mcp.client.sse import sse_client
|
||||
from mcp.client.stdio import stdio_client
|
||||
from starlette.routing import Mount
|
||||
|
||||
|
||||
from mcp import ClientSession, StdioServerParameters
|
||||
from mcp.client.stdio import stdio_client
|
||||
from mcp.client.sse import sse_client
|
||||
|
||||
from mcpo.utils.main import get_model_fields, get_tool_handler
|
||||
from mcpo.utils.auth import get_verify_api_key, APIKeyMiddleware
|
||||
|
||||
@@ -37,21 +36,32 @@ async def create_dynamic_endpoints(app: FastAPI, api_dependency=None):
|
||||
for tool in tools:
|
||||
endpoint_name = tool.name
|
||||
endpoint_description = tool.description
|
||||
schema = tool.inputSchema
|
||||
|
||||
required_fields = schema.get("required", [])
|
||||
properties = schema.get("properties", {})
|
||||
inputSchema = tool.inputSchema
|
||||
outputSchema = getattr(tool, "outputSchema", None)
|
||||
|
||||
custom_fileds = inputSchema.get("$defs", {})
|
||||
required_fields = inputSchema.get("required", [])
|
||||
properties = inputSchema.get("properties", {})
|
||||
form_model_name = f"{endpoint_name}_form_model"
|
||||
|
||||
model_fields = get_model_fields(
|
||||
form_model_name,
|
||||
properties,
|
||||
required_fields,
|
||||
custom_fileds,
|
||||
)
|
||||
if outputSchema:
|
||||
output_model_name = f"{endpoint_name}_output_model"
|
||||
output_model_fields = get_model_fields(
|
||||
output_model_name,
|
||||
outputSchema.get("properties", {}),
|
||||
outputSchema.get("required", []),
|
||||
outputSchema.get("$defs", {}),
|
||||
)
|
||||
else:
|
||||
output_model_fields = None
|
||||
|
||||
tool_handler = get_tool_handler(
|
||||
session, endpoint_name, form_model_name, model_fields
|
||||
session, endpoint_name, form_model_name, model_fields, output_model_fields
|
||||
)
|
||||
|
||||
app.post(
|
||||
@@ -98,7 +108,10 @@ async def lifespan(app: FastAPI):
|
||||
await create_dynamic_endpoints(app, api_dependency=api_dependency)
|
||||
yield
|
||||
if server_type == "sse":
|
||||
async with sse_client(url=args[0], sse_read_timeout=None) as (reader, writer):
|
||||
async with sse_client(url=args[0], sse_read_timeout=None) as (
|
||||
reader,
|
||||
writer,
|
||||
):
|
||||
async with ClientSession(reader, writer) as session:
|
||||
app.state.session = session
|
||||
await create_dynamic_endpoints(app, api_dependency=api_dependency)
|
||||
@@ -133,7 +146,6 @@ async def run(
|
||||
ssl_certfile = kwargs.get("ssl_certfile")
|
||||
ssl_keyfile = kwargs.get("ssl_keyfile")
|
||||
path_prefix = kwargs.get("path_prefix") or "/"
|
||||
|
||||
main_app = FastAPI(
|
||||
title=name,
|
||||
description=description,
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
from typing import Any, Dict, List, Type, Union, ForwardRef
|
||||
from pydantic import create_model, Field
|
||||
from pydantic.fields import FieldInfo
|
||||
from mcp import ClientSession, types
|
||||
import json
|
||||
from typing import Any, Dict, ForwardRef, List, Optional, Type, Union
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
from mcp import ClientSession, types
|
||||
from mcp.types import (
|
||||
CallToolResult,
|
||||
PARSE_ERROR,
|
||||
@@ -11,9 +12,11 @@ from mcp.types import (
|
||||
INVALID_PARAMS,
|
||||
INTERNAL_ERROR,
|
||||
)
|
||||
|
||||
from mcp.shared.exceptions import McpError
|
||||
|
||||
import json
|
||||
from pydantic import Field, create_model
|
||||
from pydantic.fields import FieldInfo
|
||||
|
||||
MCP_ERROR_TO_HTTP_STATUS = {
|
||||
PARSE_ERROR: 400,
|
||||
@@ -51,6 +54,7 @@ def _process_schema_property(
|
||||
model_name_prefix: str,
|
||||
prop_name: str,
|
||||
is_required: bool,
|
||||
custom_fields: Optional[Dict] = None,
|
||||
) -> tuple[Union[Type, List, ForwardRef, Any], FieldInfo]:
|
||||
"""
|
||||
Recursively processes a schema property to determine its Python type hint
|
||||
@@ -60,6 +64,12 @@ def _process_schema_property(
|
||||
A tuple containing (python_type_hint, pydantic_field).
|
||||
The pydantic_field contains default value and description.
|
||||
"""
|
||||
if "$ref" in prop_schema:
|
||||
ref = prop_schema["$ref"]
|
||||
ref = ref.split("/")[-1]
|
||||
assert ref in custom_fields, "Custom field not found"
|
||||
prop_schema = custom_fields[ref]
|
||||
|
||||
prop_type = prop_schema.get("type")
|
||||
prop_desc = prop_schema.get("description", "")
|
||||
default_value = ... if is_required else prop_schema.get("default", None)
|
||||
@@ -111,7 +121,12 @@ def _process_schema_property(
|
||||
for name, schema in nested_properties.items():
|
||||
is_nested_required = name in nested_required
|
||||
nested_type_hint, nested_pydantic_field = _process_schema_property(
|
||||
_model_cache, schema, nested_model_name, name, is_nested_required
|
||||
_model_cache,
|
||||
schema,
|
||||
nested_model_name,
|
||||
name,
|
||||
is_nested_required,
|
||||
custom_fields,
|
||||
)
|
||||
|
||||
nested_fields[name] = (nested_type_hint, nested_pydantic_field)
|
||||
@@ -136,7 +151,8 @@ def _process_schema_property(
|
||||
items_schema,
|
||||
f"{model_name_prefix}_{prop_name}",
|
||||
"item",
|
||||
False, # Items aren't required at this level
|
||||
False, # Items aren't required at this level,
|
||||
custom_fields,
|
||||
)
|
||||
list_type_hint = List[item_type_hint]
|
||||
return list_type_hint, pydantic_field
|
||||
@@ -155,7 +171,7 @@ def _process_schema_property(
|
||||
return Any, pydantic_field
|
||||
|
||||
|
||||
def get_model_fields(form_model_name, properties, required_fields):
|
||||
def get_model_fields(form_model_name, properties, required_fields, custom_fields=None):
|
||||
model_fields = {}
|
||||
|
||||
_model_cache: Dict[str, Type] = {}
|
||||
@@ -163,21 +179,33 @@ def get_model_fields(form_model_name, properties, required_fields):
|
||||
for param_name, param_schema in properties.items():
|
||||
is_required = param_name in required_fields
|
||||
python_type_hint, pydantic_field_info = _process_schema_property(
|
||||
_model_cache, param_schema, form_model_name, param_name, is_required
|
||||
_model_cache,
|
||||
param_schema,
|
||||
form_model_name,
|
||||
param_name,
|
||||
is_required,
|
||||
custom_fields,
|
||||
)
|
||||
# Use the generated type hint and Field info
|
||||
model_fields[param_name] = (python_type_hint, pydantic_field_info)
|
||||
return model_fields
|
||||
|
||||
|
||||
def get_tool_handler(session, endpoint_name, form_model_name, model_fields):
|
||||
def get_tool_handler(
|
||||
session, endpoint_name, form_model_name, model_fields, output_model_fileds=None
|
||||
):
|
||||
if model_fields:
|
||||
FormModel = create_model(form_model_name, **model_fields)
|
||||
OutputModel = (
|
||||
create_model(f"{endpoint_name}_output_model", **output_model_fileds)
|
||||
if output_model_fileds
|
||||
else Any
|
||||
)
|
||||
|
||||
def make_endpoint_func(
|
||||
endpoint_name: str, FormModel, session: ClientSession
|
||||
): # Parameterized endpoint
|
||||
async def tool(form_data: FormModel):
|
||||
async def tool(form_data: FormModel) -> OutputModel:
|
||||
args = form_data.model_dump(exclude_none=True)
|
||||
print(f"Calling endpoint: {endpoint_name}, with args: {args}")
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user