mirror of
				https://github.com/open-webui/mcpo
				synced 2025-06-26 18:26:58 +00:00 
			
		
		
		
	Support Custom basemode and output schema
This commit is contained in:
		
							parent
							
								
									e392df0763
								
							
						
					
					
						commit
						464d886673
					
				| @ -1,12 +1,19 @@ | |||||||
| from typing import Any, Dict, List, Type, Union, ForwardRef |  | ||||||
| from pydantic import create_model, Field |  | ||||||
| from pydantic.fields import FieldInfo |  | ||||||
| from mcp import ClientSession, types |  | ||||||
| from fastapi import HTTPException |  | ||||||
| from mcp.types import CallToolResult, PARSE_ERROR, INVALID_REQUEST, METHOD_NOT_FOUND, INVALID_PARAMS, INTERNAL_ERROR |  | ||||||
| from mcp.shared.exceptions import McpError |  | ||||||
| 
 |  | ||||||
| import json | import json | ||||||
|  | from typing import Any, Dict, ForwardRef, List, Optional, Type, Union | ||||||
|  | 
 | ||||||
|  | from fastapi import HTTPException | ||||||
|  | from mcp import ClientSession, types | ||||||
|  | from mcp.shared.exceptions import McpError | ||||||
|  | from mcp.types import ( | ||||||
|  |     INTERNAL_ERROR, | ||||||
|  |     INVALID_PARAMS, | ||||||
|  |     INVALID_REQUEST, | ||||||
|  |     METHOD_NOT_FOUND, | ||||||
|  |     PARSE_ERROR, | ||||||
|  |     CallToolResult, | ||||||
|  | ) | ||||||
|  | from pydantic import Field, create_model | ||||||
|  | from pydantic.fields import FieldInfo | ||||||
| 
 | 
 | ||||||
| MCP_ERROR_TO_HTTP_STATUS = { | MCP_ERROR_TO_HTTP_STATUS = { | ||||||
|     PARSE_ERROR: 400, |     PARSE_ERROR: 400, | ||||||
| @ -44,6 +51,7 @@ def _process_schema_property( | |||||||
|     model_name_prefix: str, |     model_name_prefix: str, | ||||||
|     prop_name: str, |     prop_name: str, | ||||||
|     is_required: bool, |     is_required: bool, | ||||||
|  |     custom_fields: Optional[Dict] = None, | ||||||
| ) -> tuple[Union[Type, List, ForwardRef, Any], FieldInfo]: | ) -> tuple[Union[Type, List, ForwardRef, Any], FieldInfo]: | ||||||
|     """ |     """ | ||||||
|     Recursively processes a schema property to determine its Python type hint |     Recursively processes a schema property to determine its Python type hint | ||||||
| @ -53,6 +61,12 @@ def _process_schema_property( | |||||||
|         A tuple containing (python_type_hint, pydantic_field). |         A tuple containing (python_type_hint, pydantic_field). | ||||||
|         The pydantic_field contains default value and description. |         The pydantic_field contains default value and description. | ||||||
|     """ |     """ | ||||||
|  |     if "$ref" in prop_schema: | ||||||
|  |         ref = prop_schema["$ref"] | ||||||
|  |         ref = ref.split("/")[-1] | ||||||
|  |         assert ref in custom_fields, "Custom field not found" | ||||||
|  |         prop_schema = custom_fields[ref] | ||||||
|  | 
 | ||||||
|     prop_type = prop_schema.get("type") |     prop_type = prop_schema.get("type") | ||||||
|     prop_desc = prop_schema.get("description", "") |     prop_desc = prop_schema.get("description", "") | ||||||
|     default_value = ... if is_required else prop_schema.get("default", None) |     default_value = ... if is_required else prop_schema.get("default", None) | ||||||
| @ -73,7 +87,12 @@ def _process_schema_property( | |||||||
|         for name, schema in nested_properties.items(): |         for name, schema in nested_properties.items(): | ||||||
|             is_nested_required = name in nested_required |             is_nested_required = name in nested_required | ||||||
|             nested_type_hint, nested_pydantic_field = _process_schema_property( |             nested_type_hint, nested_pydantic_field = _process_schema_property( | ||||||
|                 _model_cache, schema, nested_model_name, name, is_nested_required |                 _model_cache, | ||||||
|  |                 schema, | ||||||
|  |                 nested_model_name, | ||||||
|  |                 name, | ||||||
|  |                 is_nested_required, | ||||||
|  |                 custom_fields, | ||||||
|             ) |             ) | ||||||
| 
 | 
 | ||||||
|             nested_fields[name] = (nested_type_hint, nested_pydantic_field) |             nested_fields[name] = (nested_type_hint, nested_pydantic_field) | ||||||
| @ -98,7 +117,8 @@ def _process_schema_property( | |||||||
|             items_schema, |             items_schema, | ||||||
|             f"{model_name_prefix}_{prop_name}", |             f"{model_name_prefix}_{prop_name}", | ||||||
|             "item", |             "item", | ||||||
|             False,  # Items aren't required at this level |             False,  # Items aren't required at this level, | ||||||
|  |             custom_fields, | ||||||
|         ) |         ) | ||||||
|         list_type_hint = List[item_type_hint] |         list_type_hint = List[item_type_hint] | ||||||
|         return list_type_hint, pydantic_field |         return list_type_hint, pydantic_field | ||||||
| @ -115,7 +135,7 @@ def _process_schema_property( | |||||||
|         return Any, pydantic_field |         return Any, pydantic_field | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def get_model_fields(form_model_name, properties, required_fields): | def get_model_fields(form_model_name, properties, required_fields, custom_fields=None): | ||||||
|     model_fields = {} |     model_fields = {} | ||||||
| 
 | 
 | ||||||
|     _model_cache: Dict[str, Type] = {} |     _model_cache: Dict[str, Type] = {} | ||||||
| @ -123,21 +143,33 @@ def get_model_fields(form_model_name, properties, required_fields): | |||||||
|     for param_name, param_schema in properties.items(): |     for param_name, param_schema in properties.items(): | ||||||
|         is_required = param_name in required_fields |         is_required = param_name in required_fields | ||||||
|         python_type_hint, pydantic_field_info = _process_schema_property( |         python_type_hint, pydantic_field_info = _process_schema_property( | ||||||
|             _model_cache, param_schema, form_model_name, param_name, is_required |             _model_cache, | ||||||
|  |             param_schema, | ||||||
|  |             form_model_name, | ||||||
|  |             param_name, | ||||||
|  |             is_required, | ||||||
|  |             custom_fields, | ||||||
|         ) |         ) | ||||||
|         # Use the generated type hint and Field info |         # Use the generated type hint and Field info | ||||||
|         model_fields[param_name] = (python_type_hint, pydantic_field_info) |         model_fields[param_name] = (python_type_hint, pydantic_field_info) | ||||||
|     return model_fields |     return model_fields | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def get_tool_handler(session, endpoint_name, form_model_name, model_fields): | def get_tool_handler( | ||||||
|  |     session, endpoint_name, form_model_name, model_fields, output_model_fileds=None | ||||||
|  | ): | ||||||
|     if model_fields: |     if model_fields: | ||||||
|         FormModel = create_model(form_model_name, **model_fields) |         FormModel = create_model(form_model_name, **model_fields) | ||||||
|  |         OutputModel = ( | ||||||
|  |             create_model(f"{endpoint_name}_output_model", **output_model_fileds) | ||||||
|  |             if output_model_fileds | ||||||
|  |             else Any | ||||||
|  |         ) | ||||||
| 
 | 
 | ||||||
|         def make_endpoint_func( |         def make_endpoint_func( | ||||||
|             endpoint_name: str, FormModel, session: ClientSession |             endpoint_name: str, FormModel, session: ClientSession | ||||||
|         ):  # Parameterized endpoint |         ):  # Parameterized endpoint | ||||||
|             async def tool(form_data: FormModel): |             async def tool(form_data: FormModel) -> OutputModel: | ||||||
|                 args = form_data.model_dump(exclude_none=True) |                 args = form_data.model_dump(exclude_none=True) | ||||||
|                 print(f"Calling endpoint: {endpoint_name}, with args: {args}") |                 print(f"Calling endpoint: {endpoint_name}, with args: {args}") | ||||||
|                 try: |                 try: | ||||||
| @ -158,7 +190,9 @@ def get_tool_handler(session, endpoint_name, form_model_name, model_fields): | |||||||
|                         ) |                         ) | ||||||
| 
 | 
 | ||||||
|                     response_data = process_tool_response(result) |                     response_data = process_tool_response(result) | ||||||
|                     final_response = response_data[0] if len(response_data) == 1 else response_data |                     final_response = ( | ||||||
|  |                         response_data[0] if len(response_data) == 1 else response_data | ||||||
|  |                     ) | ||||||
|                     return final_response |                     return final_response | ||||||
| 
 | 
 | ||||||
|                 except McpError as e: |                 except McpError as e: | ||||||
| @ -206,7 +240,9 @@ def get_tool_handler(session, endpoint_name, form_model_name, model_fields): | |||||||
|                         ) |                         ) | ||||||
| 
 | 
 | ||||||
|                     response_data = process_tool_response(result) |                     response_data = process_tool_response(result) | ||||||
|                     final_response = response_data[0] if len(response_data) == 1 else response_data |                     final_response = ( | ||||||
|  |                         response_data[0] if len(response_data) == 1 else response_data | ||||||
|  |                     ) | ||||||
|                     return final_response |                     return final_response | ||||||
| 
 | 
 | ||||||
|                 except McpError as e: |                 except McpError as e: | ||||||
|  | |||||||
		Loading…
	
		Reference in New Issue
	
	Block a user