mirror of
				https://github.com/open-webui/mcpo
				synced 2025-06-26 18:26:58 +00:00 
			
		
		
		
	Improved error handling, fixed couple of errors
This commit is contained in:
		
							parent
							
								
									2c2c46eb30
								
							
						
					
					
						commit
						d46803d427
					
				| @ -1,7 +1,7 @@ | ||||
| import json | ||||
| import os | ||||
| from contextlib import AsyncExitStack, asynccontextmanager | ||||
| from typing import Dict, Any, Optional, List, Type, Union, ForwardRef | ||||
| from typing import Optional | ||||
| 
 | ||||
| import uvicorn | ||||
| from fastapi import FastAPI, Depends | ||||
| @ -9,11 +9,11 @@ from fastapi.middleware.cors import CORSMiddleware | ||||
| from starlette.routing import Mount | ||||
| 
 | ||||
| 
 | ||||
| from mcp import ClientSession, StdioServerParameters, types | ||||
| from mcp import ClientSession, StdioServerParameters | ||||
| from mcp.client.stdio import stdio_client | ||||
| 
 | ||||
| 
 | ||||
| from mcpo.utils.main import get_model_fields, get_tool_handler | ||||
| from mcpo.utils.main import get_model_fields, get_tool_handler, ToolResponse | ||||
| from mcpo.utils.auth import get_verify_api_key | ||||
| 
 | ||||
| 
 | ||||
| @ -58,6 +58,8 @@ async def create_dynamic_endpoints(app: FastAPI, api_dependency=None): | ||||
|             f"/{endpoint_name}", | ||||
|             summary=endpoint_name.replace("_", " ").title(), | ||||
|             description=endpoint_description, | ||||
|             response_model=ToolResponse, | ||||
|             response_model_exclude_none=True, | ||||
|             dependencies=[Depends(api_dependency)] if api_dependency else [], | ||||
|         )(tool_handler) | ||||
| 
 | ||||
|  | ||||
| @ -1,12 +1,18 @@ | ||||
| from typing import Any, Dict, List, Type, Union, ForwardRef | ||||
| from pydantic import create_model, Field | ||||
| from typing import Any, Dict, List, Type, Union, ForwardRef, Optional | ||||
| from pydantic import BaseModel, create_model, Field | ||||
| from pydantic.fields import FieldInfo | ||||
| from mcp import ClientSession, types | ||||
| from mcp.types import CallToolResult | ||||
| from mcp.shared.exceptions import McpError | ||||
| 
 | ||||
| import json | ||||
| 
 | ||||
| 
 | ||||
| class ToolResponse(BaseModel): | ||||
|     response: Optional[Any] = None | ||||
|     errorMessage: Optional[str] = None | ||||
|     errorData: Optional[Any] = None | ||||
| 
 | ||||
| def process_tool_response(result: CallToolResult) -> list: | ||||
|     """Universal response processor for all tool endpoints""" | ||||
|     response = [] | ||||
| @ -63,7 +69,7 @@ def _process_schema_property( | ||||
|         for name, schema in nested_properties.items(): | ||||
|             is_nested_required = name in nested_required | ||||
|             nested_type_hint, nested_pydantic_field = _process_schema_property( | ||||
|                 schema, nested_model_name, name, is_nested_required | ||||
|                 _model_cache, schema, nested_model_name, name, is_nested_required | ||||
|             ) | ||||
| 
 | ||||
|             nested_fields[name] = (nested_type_hint, nested_pydantic_field) | ||||
| @ -84,6 +90,7 @@ def _process_schema_property( | ||||
| 
 | ||||
|         # Recursively determine the type of items in the array | ||||
|         item_type_hint, _ = _process_schema_property( | ||||
|             _model_cache, | ||||
|             items_schema, | ||||
|             f"{model_name_prefix}_{prop_name}", | ||||
|             "item", | ||||
| @ -126,13 +133,31 @@ def get_tool_handler(session, endpoint_name, form_model_name, model_fields): | ||||
|         def make_endpoint_func( | ||||
|             endpoint_name: str, FormModel, session: ClientSession | ||||
|         ):  # Parameterized endpoint | ||||
| 
 | ||||
|             async def tool(form_data: FormModel): | ||||
|             async def tool(form_data: FormModel) -> ToolResponse: | ||||
|                 args = form_data.model_dump(exclude_none=True) | ||||
|                 print(f"Calling endpoint: {endpoint_name}, with args: {args}") | ||||
|                 try: | ||||
|                     result = await session.call_tool(endpoint_name, arguments=args) | ||||
| 
 | ||||
|                 result = await session.call_tool(endpoint_name, arguments=args) | ||||
|                 return process_tool_response(result) | ||||
|                     if result.isError: | ||||
|                         errorMessage = "Unknown tool execution error" | ||||
|                         if result.content and isinstance(result.content[0], types.TextContent): | ||||
|                             errorMessage = result.content[0].text | ||||
|                         return ToolResponse(errorMessage=errorMessage) | ||||
| 
 | ||||
|                     response_data = process_tool_response(result) | ||||
|                     final_response = response_data[0] if len(response_data) == 1 else response_data | ||||
|                     return ToolResponse(response=final_response) | ||||
| 
 | ||||
|                 except McpError as e: | ||||
|                     print(f"MCP Error calling {endpoint_name}: {e.error}") | ||||
|                     return ToolResponse( | ||||
|                         errorMessage=e.error.message, | ||||
|                         errorData=e.error.data, | ||||
|                     ) | ||||
|                 except Exception as e: | ||||
|                     print(f"Unexpected error calling {endpoint_name}: {e}") | ||||
|                     return ToolResponse(errorMessage=f"An unexpected internal error occurred: {e}") | ||||
| 
 | ||||
|             return tool | ||||
| 
 | ||||
| @ -142,12 +167,33 @@ def get_tool_handler(session, endpoint_name, form_model_name, model_fields): | ||||
|         def make_endpoint_func_no_args( | ||||
|             endpoint_name: str, session: ClientSession | ||||
|         ):  # Parameterless endpoint | ||||
|             async def tool():  # No parameters | ||||
|             async def tool() -> ToolResponse:  # No parameters | ||||
|                 print(f"Calling endpoint: {endpoint_name}, with no args") | ||||
|                 result = await session.call_tool( | ||||
|                     endpoint_name, arguments={} | ||||
|                 )  # Empty dict | ||||
|                 return process_tool_response(result)  # Same processor | ||||
|                 try: | ||||
|                     result = await session.call_tool( | ||||
|                         endpoint_name, arguments={} | ||||
|                     )  # Empty dict | ||||
| 
 | ||||
|                     if result.isError: | ||||
|                         error_message = "Unknown tool execution error" | ||||
|                         if result.content and isinstance(result.content[0], types.TextContent): | ||||
|                             error_message = result.content[0].text | ||||
|                         return ToolResponse(errorMessage=error_message) | ||||
| 
 | ||||
|                     response_data = process_tool_response(result) | ||||
|                     final_response = response_data[0] if len(response_data) == 1 else response_data | ||||
|                     return ToolResponse(response=final_response) | ||||
| 
 | ||||
|                 except McpError as e: | ||||
|                     print(f"MCP Error calling {endpoint_name}: {e.error}") | ||||
|                     # Propagate the error received from MCP | ||||
|                     return ToolResponse( | ||||
|                         errorMessage=e.error.message, | ||||
|                         errorData=e.error.data, | ||||
|                     ) | ||||
|                 except Exception as e: | ||||
|                     print(f"Unexpected error calling {endpoint_name}: {e}") | ||||
|                     return ToolResponse(errorMessage=f"An unexpected internal error occurred: {e}") | ||||
| 
 | ||||
|             return tool | ||||
| 
 | ||||
|  | ||||
| @ -2,7 +2,10 @@ import pytest | ||||
| from pydantic import BaseModel, Field | ||||
| from typing import Any, List, Dict | ||||
| 
 | ||||
| from src.mcpo.main import _process_schema_property, _model_cache | ||||
| from src.mcpo.utils.main import _process_schema_property | ||||
| 
 | ||||
| 
 | ||||
| _model_cache = {} | ||||
| 
 | ||||
| 
 | ||||
| @pytest.fixture(autouse=True) | ||||
| @ -16,7 +19,9 @@ def test_process_simple_string_required(): | ||||
|     schema = {"type": "string", "description": "A simple string"} | ||||
|     expected_type = str | ||||
|     expected_field = Field(default=..., description="A simple string") | ||||
|     result_type, result_field = _process_schema_property(schema, "test", "prop", True) | ||||
|     result_type, result_field = _process_schema_property( | ||||
|         _model_cache, schema, "test", "prop", True | ||||
|     ) | ||||
|     assert result_type == expected_type | ||||
|     assert result_field.default == expected_field.default | ||||
|     assert result_field.description == expected_field.description | ||||
| @ -26,7 +31,9 @@ def test_process_simple_integer_optional(): | ||||
|     schema = {"type": "integer", "default": 10} | ||||
|     expected_type = int | ||||
|     expected_field = Field(default=10, description="") | ||||
|     result_type, result_field = _process_schema_property(schema, "test", "prop", False) | ||||
|     result_type, result_field = _process_schema_property( | ||||
|         _model_cache, schema, "test", "prop", False | ||||
|     ) | ||||
|     assert result_type == expected_type | ||||
|     assert result_field.default == expected_field.default | ||||
|     assert result_field.description == expected_field.description | ||||
| @ -38,7 +45,9 @@ def test_process_simple_boolean_optional_no_default(): | ||||
|     expected_field = Field( | ||||
|         default=None, description="" | ||||
|     )  # Default is None if not required and no default specified | ||||
|     result_type, result_field = _process_schema_property(schema, "test", "prop", False) | ||||
|     result_type, result_field = _process_schema_property( | ||||
|         _model_cache, schema, "test", "prop", False | ||||
|     ) | ||||
|     assert result_type == expected_type | ||||
|     assert result_field.default == expected_field.default | ||||
|     assert result_field.description == expected_field.description | ||||
| @ -48,7 +57,9 @@ def test_process_simple_number(): | ||||
|     schema = {"type": "number"} | ||||
|     expected_type = float | ||||
|     expected_field = Field(default=..., description="") | ||||
|     result_type, result_field = _process_schema_property(schema, "test", "prop", True) | ||||
|     result_type, result_field = _process_schema_property( | ||||
|         _model_cache, schema, "test", "prop", True | ||||
|     ) | ||||
|     assert result_type == expected_type | ||||
|     assert result_field.default == expected_field.default | ||||
| 
 | ||||
| @ -57,7 +68,9 @@ def test_process_unknown_type(): | ||||
|     schema = {"type": "unknown"} | ||||
|     expected_type = Any | ||||
|     expected_field = Field(default=..., description="") | ||||
|     result_type, result_field = _process_schema_property(schema, "test", "prop", True) | ||||
|     result_type, result_field = _process_schema_property( | ||||
|         _model_cache, schema, "test", "prop", True | ||||
|     ) | ||||
|     assert result_type == expected_type | ||||
|     assert result_field.default == expected_field.default | ||||
| 
 | ||||
| @ -66,7 +79,9 @@ def test_process_array_of_strings(): | ||||
|     schema = {"type": "array", "items": {"type": "string"}} | ||||
|     expected_type = List[str] | ||||
|     expected_field = Field(default=..., description="") | ||||
|     result_type, result_field = _process_schema_property(schema, "test", "prop", True) | ||||
|     result_type, result_field = _process_schema_property( | ||||
|         _model_cache, schema, "test", "prop", True | ||||
|     ) | ||||
|     assert result_type == expected_type | ||||
|     assert result_field.default == expected_field.default | ||||
| 
 | ||||
| @ -75,7 +90,9 @@ def test_process_array_of_any_missing_items(): | ||||
|     schema = {"type": "array"}  # Missing "items" | ||||
|     expected_type = List[Any] | ||||
|     expected_field = Field(default=None, description="") | ||||
|     result_type, result_field = _process_schema_property(schema, "test", "prop", False) | ||||
|     result_type, result_field = _process_schema_property( | ||||
|         _model_cache, schema, "test", "prop", False | ||||
|     ) | ||||
|     assert result_type == expected_type | ||||
|     assert result_field.default == expected_field.default | ||||
| 
 | ||||
| @ -87,12 +104,13 @@ def test_process_simple_object(): | ||||
|         "required": ["name"], | ||||
|     } | ||||
|     expected_field = Field(default=..., description="") | ||||
|     result_type, result_field = _process_schema_property(schema, "test", "prop", True) | ||||
|     result_type, result_field = _process_schema_property( | ||||
|         _model_cache, schema, "test", "prop", True | ||||
|     ) | ||||
| 
 | ||||
|     assert result_field.default == expected_field.default | ||||
|     assert result_field.description == expected_field.description | ||||
|     assert issubclass(result_type, BaseModel)  # Check if it's a Pydantic model | ||||
|     assert "test_prop_model" in _model_cache  # Check caching | ||||
| 
 | ||||
|     # Check fields of the generated model | ||||
|     model_fields = result_type.model_fields | ||||
| @ -120,11 +138,10 @@ def test_process_nested_object(): | ||||
|     } | ||||
|     expected_field = Field(default=..., description="") | ||||
|     result_type, result_field = _process_schema_property( | ||||
|         schema, "test", "outer_prop", True | ||||
|         _model_cache, schema, "test", "outer_prop", True | ||||
|     ) | ||||
| 
 | ||||
|     assert result_field.default == expected_field.default | ||||
|     assert "test_outer_prop_model" in _model_cache | ||||
|     assert issubclass(result_type, BaseModel) | ||||
| 
 | ||||
|     outer_model_fields = result_type.model_fields | ||||
| @ -133,7 +150,6 @@ def test_process_nested_object(): | ||||
| 
 | ||||
|     nested_model_type = outer_model_fields["user"].annotation | ||||
|     assert issubclass(nested_model_type, BaseModel) | ||||
|     assert "test_outer_prop_model_user_model" in _model_cache | ||||
| 
 | ||||
|     nested_model_fields = nested_model_type.model_fields | ||||
|     assert "id" in nested_model_fields | ||||
| @ -151,11 +167,12 @@ def test_process_array_of_objects(): | ||||
|         }, | ||||
|     } | ||||
|     expected_field = Field(default=..., description="") | ||||
|     result_type, result_field = _process_schema_property(schema, "test", "prop", True) | ||||
|     result_type, result_field = _process_schema_property( | ||||
|         _model_cache, schema, "test", "prop", True | ||||
|     ) | ||||
| 
 | ||||
|     assert result_field.default == expected_field.default | ||||
|     assert str(result_type).startswith("typing.List[")  # Check it's a List | ||||
|     assert "test_prop_item_model" in _model_cache | ||||
| 
 | ||||
|     # Get the inner type from List[...] | ||||
|     item_model_type = result_type.__args__[0] | ||||
| @ -171,7 +188,9 @@ def test_process_empty_object(): | ||||
|     schema = {"type": "object", "properties": {}} | ||||
|     expected_type = Dict[str, Any]  # Should default to Dict[str, Any] if no properties | ||||
|     expected_field = Field(default=..., description="") | ||||
|     result_type, result_field = _process_schema_property(schema, "test", "prop", True) | ||||
|     result_type, result_field = _process_schema_property( | ||||
|         _model_cache, schema, "test", "prop", True | ||||
|     ) | ||||
|     assert result_type == expected_type | ||||
|     assert result_field.default == expected_field.default | ||||
| 
 | ||||
| @ -183,19 +202,25 @@ def test_model_caching(): | ||||
|         "required": ["id"], | ||||
|     } | ||||
|     # First call | ||||
|     result_type1, _ = _process_schema_property(schema, "cache_test", "obj1", True) | ||||
|     result_type1, _ = _process_schema_property( | ||||
|         _model_cache, schema, "cache_test", "obj1", True | ||||
|     ) | ||||
|     model_name = "cache_test_obj1_model" | ||||
|     assert model_name in _model_cache | ||||
|     assert _model_cache[model_name] == result_type1 | ||||
| 
 | ||||
|     # Second call with same structure but different prefix/prop name (should generate new) | ||||
|     result_type2, _ = _process_schema_property(schema, "cache_test", "obj2", True) | ||||
|     result_type2, _ = _process_schema_property( | ||||
|         _model_cache, schema, "cache_test", "obj2", True | ||||
|     ) | ||||
|     model_name2 = "cache_test_obj2_model" | ||||
|     assert model_name2 in _model_cache | ||||
|     assert _model_cache[model_name2] == result_type2 | ||||
|     assert result_type1 != result_type2  # Different models | ||||
| 
 | ||||
|     # Third call identical to the first (should return cached model) | ||||
|     result_type3, _ = _process_schema_property(schema, "cache_test", "obj1", True) | ||||
|     result_type3, _ = _process_schema_property( | ||||
|         _model_cache, schema, "cache_test", "obj1", True | ||||
|     ) | ||||
|     assert result_type3 == result_type1  # Should be the same cached object | ||||
|     assert len(_model_cache) == 2  # Only two unique models created | ||||
|  | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user