mirror of
				https://github.com/open-webui/mcpo
				synced 2025-06-26 18:26:58 +00:00 
			
		
		
		
	Merge pull request #105 from JinY0ung-Shin/main
feat: support cusotm BaseModel "$def" and Output schema
This commit is contained in:
		
						commit
						1566a34c9c
					
				| @ -4,15 +4,14 @@ from contextlib import AsyncExitStack, asynccontextmanager | ||||
| from typing import Optional | ||||
| 
 | ||||
| import uvicorn | ||||
| from fastapi import FastAPI, Depends | ||||
| from fastapi import Depends, FastAPI | ||||
| from fastapi.middleware.cors import CORSMiddleware | ||||
| from mcp import ClientSession, StdioServerParameters | ||||
| from mcp.client.sse import sse_client | ||||
| from mcp.client.stdio import stdio_client | ||||
| from starlette.routing import Mount | ||||
| 
 | ||||
| 
 | ||||
| from mcp import ClientSession, StdioServerParameters | ||||
| from mcp.client.stdio import stdio_client | ||||
| from mcp.client.sse import sse_client | ||||
| 
 | ||||
| from mcpo.utils.main import get_model_fields, get_tool_handler | ||||
| from mcpo.utils.auth import get_verify_api_key, APIKeyMiddleware | ||||
| 
 | ||||
| @ -37,21 +36,32 @@ async def create_dynamic_endpoints(app: FastAPI, api_dependency=None): | ||||
|     for tool in tools: | ||||
|         endpoint_name = tool.name | ||||
|         endpoint_description = tool.description | ||||
|         schema = tool.inputSchema | ||||
| 
 | ||||
|         required_fields = schema.get("required", []) | ||||
|         properties = schema.get("properties", {}) | ||||
|         inputSchema = tool.inputSchema | ||||
|         outputSchema = getattr(tool, "outputSchema", None) | ||||
| 
 | ||||
|         custom_fileds = inputSchema.get("$defs", {}) | ||||
|         required_fields = inputSchema.get("required", []) | ||||
|         properties = inputSchema.get("properties", {}) | ||||
|         form_model_name = f"{endpoint_name}_form_model" | ||||
| 
 | ||||
|         model_fields = get_model_fields( | ||||
|             form_model_name, | ||||
|             properties, | ||||
|             required_fields, | ||||
|             custom_fileds, | ||||
|         ) | ||||
|         if outputSchema: | ||||
|             output_model_name = f"{endpoint_name}_output_model" | ||||
|             output_model_fields = get_model_fields( | ||||
|                 output_model_name, | ||||
|                 outputSchema.get("properties", {}), | ||||
|                 outputSchema.get("required", []), | ||||
|                 outputSchema.get("$defs", {}), | ||||
|             ) | ||||
|         else: | ||||
|             output_model_fields = None | ||||
| 
 | ||||
|         tool_handler = get_tool_handler( | ||||
|             session, endpoint_name, form_model_name, model_fields | ||||
|             session, endpoint_name, form_model_name, model_fields, output_model_fields | ||||
|         ) | ||||
| 
 | ||||
|         app.post( | ||||
| @ -98,7 +108,10 @@ async def lifespan(app: FastAPI): | ||||
|                     await create_dynamic_endpoints(app, api_dependency=api_dependency) | ||||
|                     yield | ||||
|         if server_type == "sse": | ||||
|             async with sse_client(url=args[0], sse_read_timeout=None) as (reader, writer): | ||||
|             async with sse_client(url=args[0], sse_read_timeout=None) as ( | ||||
|                 reader, | ||||
|                 writer, | ||||
|             ): | ||||
|                 async with ClientSession(reader, writer) as session: | ||||
|                     app.state.session = session | ||||
|                     await create_dynamic_endpoints(app, api_dependency=api_dependency) | ||||
| @ -133,7 +146,6 @@ async def run( | ||||
|     ssl_certfile = kwargs.get("ssl_certfile") | ||||
|     ssl_keyfile = kwargs.get("ssl_keyfile") | ||||
|     path_prefix = kwargs.get("path_prefix") or "/" | ||||
| 
 | ||||
|     main_app = FastAPI( | ||||
|         title=name, | ||||
|         description=description, | ||||
|  | ||||
| @ -1,8 +1,9 @@ | ||||
| from typing import Any, Dict, List, Type, Union, ForwardRef | ||||
| from pydantic import create_model, Field | ||||
| from pydantic.fields import FieldInfo | ||||
| from mcp import ClientSession, types | ||||
| import json | ||||
| from typing import Any, Dict, ForwardRef, List, Optional, Type, Union | ||||
| 
 | ||||
| from fastapi import HTTPException | ||||
| 
 | ||||
| from mcp import ClientSession, types | ||||
| from mcp.types import ( | ||||
|     CallToolResult, | ||||
|     PARSE_ERROR, | ||||
| @ -11,9 +12,11 @@ from mcp.types import ( | ||||
|     INVALID_PARAMS, | ||||
|     INTERNAL_ERROR, | ||||
| ) | ||||
| 
 | ||||
| from mcp.shared.exceptions import McpError | ||||
| 
 | ||||
| import json | ||||
| from pydantic import Field, create_model | ||||
| from pydantic.fields import FieldInfo | ||||
| 
 | ||||
| MCP_ERROR_TO_HTTP_STATUS = { | ||||
|     PARSE_ERROR: 400, | ||||
| @ -51,6 +54,7 @@ def _process_schema_property( | ||||
|     model_name_prefix: str, | ||||
|     prop_name: str, | ||||
|     is_required: bool, | ||||
|     custom_fields: Optional[Dict] = None, | ||||
| ) -> tuple[Union[Type, List, ForwardRef, Any], FieldInfo]: | ||||
|     """ | ||||
|     Recursively processes a schema property to determine its Python type hint | ||||
| @ -60,6 +64,12 @@ def _process_schema_property( | ||||
|         A tuple containing (python_type_hint, pydantic_field). | ||||
|         The pydantic_field contains default value and description. | ||||
|     """ | ||||
|     if "$ref" in prop_schema: | ||||
|         ref = prop_schema["$ref"] | ||||
|         ref = ref.split("/")[-1] | ||||
|         assert ref in custom_fields, "Custom field not found" | ||||
|         prop_schema = custom_fields[ref] | ||||
| 
 | ||||
|     prop_type = prop_schema.get("type") | ||||
|     prop_desc = prop_schema.get("description", "") | ||||
|     default_value = ... if is_required else prop_schema.get("default", None) | ||||
| @ -111,7 +121,12 @@ def _process_schema_property( | ||||
|         for name, schema in nested_properties.items(): | ||||
|             is_nested_required = name in nested_required | ||||
|             nested_type_hint, nested_pydantic_field = _process_schema_property( | ||||
|                 _model_cache, schema, nested_model_name, name, is_nested_required | ||||
|                 _model_cache, | ||||
|                 schema, | ||||
|                 nested_model_name, | ||||
|                 name, | ||||
|                 is_nested_required, | ||||
|                 custom_fields, | ||||
|             ) | ||||
| 
 | ||||
|             nested_fields[name] = (nested_type_hint, nested_pydantic_field) | ||||
| @ -136,7 +151,8 @@ def _process_schema_property( | ||||
|             items_schema, | ||||
|             f"{model_name_prefix}_{prop_name}", | ||||
|             "item", | ||||
|             False,  # Items aren't required at this level | ||||
|             False,  # Items aren't required at this level, | ||||
|             custom_fields, | ||||
|         ) | ||||
|         list_type_hint = List[item_type_hint] | ||||
|         return list_type_hint, pydantic_field | ||||
| @ -155,7 +171,7 @@ def _process_schema_property( | ||||
|         return Any, pydantic_field | ||||
| 
 | ||||
| 
 | ||||
| def get_model_fields(form_model_name, properties, required_fields): | ||||
| def get_model_fields(form_model_name, properties, required_fields, custom_fields=None): | ||||
|     model_fields = {} | ||||
| 
 | ||||
|     _model_cache: Dict[str, Type] = {} | ||||
| @ -163,21 +179,33 @@ def get_model_fields(form_model_name, properties, required_fields): | ||||
|     for param_name, param_schema in properties.items(): | ||||
|         is_required = param_name in required_fields | ||||
|         python_type_hint, pydantic_field_info = _process_schema_property( | ||||
|             _model_cache, param_schema, form_model_name, param_name, is_required | ||||
|             _model_cache, | ||||
|             param_schema, | ||||
|             form_model_name, | ||||
|             param_name, | ||||
|             is_required, | ||||
|             custom_fields, | ||||
|         ) | ||||
|         # Use the generated type hint and Field info | ||||
|         model_fields[param_name] = (python_type_hint, pydantic_field_info) | ||||
|     return model_fields | ||||
| 
 | ||||
| 
 | ||||
| def get_tool_handler(session, endpoint_name, form_model_name, model_fields): | ||||
| def get_tool_handler( | ||||
|     session, endpoint_name, form_model_name, model_fields, output_model_fileds=None | ||||
| ): | ||||
|     if model_fields: | ||||
|         FormModel = create_model(form_model_name, **model_fields) | ||||
|         OutputModel = ( | ||||
|             create_model(f"{endpoint_name}_output_model", **output_model_fileds) | ||||
|             if output_model_fileds | ||||
|             else Any | ||||
|         ) | ||||
| 
 | ||||
|         def make_endpoint_func( | ||||
|             endpoint_name: str, FormModel, session: ClientSession | ||||
|         ):  # Parameterized endpoint | ||||
|             async def tool(form_data: FormModel): | ||||
|             async def tool(form_data: FormModel) -> OutputModel: | ||||
|                 args = form_data.model_dump(exclude_none=True) | ||||
|                 print(f"Calling endpoint: {endpoint_name}, with args: {args}") | ||||
|                 try: | ||||
|  | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user