mirror of
https://github.com/open-webui/mcpo
synced 2025-06-26 18:26:58 +00:00
Improved error handling, fixed couple of errors
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
import json
|
||||
import os
|
||||
from contextlib import AsyncExitStack, asynccontextmanager
|
||||
from typing import Dict, Any, Optional, List, Type, Union, ForwardRef
|
||||
from typing import Optional
|
||||
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, Depends
|
||||
@@ -9,11 +9,11 @@ from fastapi.middleware.cors import CORSMiddleware
|
||||
from starlette.routing import Mount
|
||||
|
||||
|
||||
from mcp import ClientSession, StdioServerParameters, types
|
||||
from mcp import ClientSession, StdioServerParameters
|
||||
from mcp.client.stdio import stdio_client
|
||||
|
||||
|
||||
from mcpo.utils.main import get_model_fields, get_tool_handler
|
||||
from mcpo.utils.main import get_model_fields, get_tool_handler, ToolResponse
|
||||
from mcpo.utils.auth import get_verify_api_key
|
||||
|
||||
|
||||
@@ -58,6 +58,8 @@ async def create_dynamic_endpoints(app: FastAPI, api_dependency=None):
|
||||
f"/{endpoint_name}",
|
||||
summary=endpoint_name.replace("_", " ").title(),
|
||||
description=endpoint_description,
|
||||
response_model=ToolResponse,
|
||||
response_model_exclude_none=True,
|
||||
dependencies=[Depends(api_dependency)] if api_dependency else [],
|
||||
)(tool_handler)
|
||||
|
||||
|
||||
@@ -1,12 +1,18 @@
|
||||
from typing import Any, Dict, List, Type, Union, ForwardRef
|
||||
from pydantic import create_model, Field
|
||||
from typing import Any, Dict, List, Type, Union, ForwardRef, Optional
|
||||
from pydantic import BaseModel, create_model, Field
|
||||
from pydantic.fields import FieldInfo
|
||||
from mcp import ClientSession, types
|
||||
from mcp.types import CallToolResult
|
||||
from mcp.shared.exceptions import McpError
|
||||
|
||||
import json
|
||||
|
||||
|
||||
class ToolResponse(BaseModel):
|
||||
response: Optional[Any] = None
|
||||
errorMessage: Optional[str] = None
|
||||
errorData: Optional[Any] = None
|
||||
|
||||
def process_tool_response(result: CallToolResult) -> list:
|
||||
"""Universal response processor for all tool endpoints"""
|
||||
response = []
|
||||
@@ -63,7 +69,7 @@ def _process_schema_property(
|
||||
for name, schema in nested_properties.items():
|
||||
is_nested_required = name in nested_required
|
||||
nested_type_hint, nested_pydantic_field = _process_schema_property(
|
||||
schema, nested_model_name, name, is_nested_required
|
||||
_model_cache, schema, nested_model_name, name, is_nested_required
|
||||
)
|
||||
|
||||
nested_fields[name] = (nested_type_hint, nested_pydantic_field)
|
||||
@@ -84,6 +90,7 @@ def _process_schema_property(
|
||||
|
||||
# Recursively determine the type of items in the array
|
||||
item_type_hint, _ = _process_schema_property(
|
||||
_model_cache,
|
||||
items_schema,
|
||||
f"{model_name_prefix}_{prop_name}",
|
||||
"item",
|
||||
@@ -126,13 +133,31 @@ def get_tool_handler(session, endpoint_name, form_model_name, model_fields):
|
||||
def make_endpoint_func(
|
||||
endpoint_name: str, FormModel, session: ClientSession
|
||||
): # Parameterized endpoint
|
||||
|
||||
async def tool(form_data: FormModel):
|
||||
async def tool(form_data: FormModel) -> ToolResponse:
|
||||
args = form_data.model_dump(exclude_none=True)
|
||||
print(f"Calling endpoint: {endpoint_name}, with args: {args}")
|
||||
try:
|
||||
result = await session.call_tool(endpoint_name, arguments=args)
|
||||
|
||||
result = await session.call_tool(endpoint_name, arguments=args)
|
||||
return process_tool_response(result)
|
||||
if result.isError:
|
||||
errorMessage = "Unknown tool execution error"
|
||||
if result.content and isinstance(result.content[0], types.TextContent):
|
||||
errorMessage = result.content[0].text
|
||||
return ToolResponse(errorMessage=errorMessage)
|
||||
|
||||
response_data = process_tool_response(result)
|
||||
final_response = response_data[0] if len(response_data) == 1 else response_data
|
||||
return ToolResponse(response=final_response)
|
||||
|
||||
except McpError as e:
|
||||
print(f"MCP Error calling {endpoint_name}: {e.error}")
|
||||
return ToolResponse(
|
||||
errorMessage=e.error.message,
|
||||
errorData=e.error.data,
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Unexpected error calling {endpoint_name}: {e}")
|
||||
return ToolResponse(errorMessage=f"An unexpected internal error occurred: {e}")
|
||||
|
||||
return tool
|
||||
|
||||
@@ -142,12 +167,33 @@ def get_tool_handler(session, endpoint_name, form_model_name, model_fields):
|
||||
def make_endpoint_func_no_args(
|
||||
endpoint_name: str, session: ClientSession
|
||||
): # Parameterless endpoint
|
||||
async def tool(): # No parameters
|
||||
async def tool() -> ToolResponse: # No parameters
|
||||
print(f"Calling endpoint: {endpoint_name}, with no args")
|
||||
result = await session.call_tool(
|
||||
endpoint_name, arguments={}
|
||||
) # Empty dict
|
||||
return process_tool_response(result) # Same processor
|
||||
try:
|
||||
result = await session.call_tool(
|
||||
endpoint_name, arguments={}
|
||||
) # Empty dict
|
||||
|
||||
if result.isError:
|
||||
error_message = "Unknown tool execution error"
|
||||
if result.content and isinstance(result.content[0], types.TextContent):
|
||||
error_message = result.content[0].text
|
||||
return ToolResponse(errorMessage=error_message)
|
||||
|
||||
response_data = process_tool_response(result)
|
||||
final_response = response_data[0] if len(response_data) == 1 else response_data
|
||||
return ToolResponse(response=final_response)
|
||||
|
||||
except McpError as e:
|
||||
print(f"MCP Error calling {endpoint_name}: {e.error}")
|
||||
# Propagate the error received from MCP
|
||||
return ToolResponse(
|
||||
errorMessage=e.error.message,
|
||||
errorData=e.error.data,
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Unexpected error calling {endpoint_name}: {e}")
|
||||
return ToolResponse(errorMessage=f"An unexpected internal error occurred: {e}")
|
||||
|
||||
return tool
|
||||
|
||||
|
||||
@@ -2,7 +2,10 @@ import pytest
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Any, List, Dict
|
||||
|
||||
from src.mcpo.main import _process_schema_property, _model_cache
|
||||
from src.mcpo.utils.main import _process_schema_property
|
||||
|
||||
|
||||
_model_cache = {}
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
@@ -16,7 +19,9 @@ def test_process_simple_string_required():
|
||||
schema = {"type": "string", "description": "A simple string"}
|
||||
expected_type = str
|
||||
expected_field = Field(default=..., description="A simple string")
|
||||
result_type, result_field = _process_schema_property(schema, "test", "prop", True)
|
||||
result_type, result_field = _process_schema_property(
|
||||
_model_cache, schema, "test", "prop", True
|
||||
)
|
||||
assert result_type == expected_type
|
||||
assert result_field.default == expected_field.default
|
||||
assert result_field.description == expected_field.description
|
||||
@@ -26,7 +31,9 @@ def test_process_simple_integer_optional():
|
||||
schema = {"type": "integer", "default": 10}
|
||||
expected_type = int
|
||||
expected_field = Field(default=10, description="")
|
||||
result_type, result_field = _process_schema_property(schema, "test", "prop", False)
|
||||
result_type, result_field = _process_schema_property(
|
||||
_model_cache, schema, "test", "prop", False
|
||||
)
|
||||
assert result_type == expected_type
|
||||
assert result_field.default == expected_field.default
|
||||
assert result_field.description == expected_field.description
|
||||
@@ -38,7 +45,9 @@ def test_process_simple_boolean_optional_no_default():
|
||||
expected_field = Field(
|
||||
default=None, description=""
|
||||
) # Default is None if not required and no default specified
|
||||
result_type, result_field = _process_schema_property(schema, "test", "prop", False)
|
||||
result_type, result_field = _process_schema_property(
|
||||
_model_cache, schema, "test", "prop", False
|
||||
)
|
||||
assert result_type == expected_type
|
||||
assert result_field.default == expected_field.default
|
||||
assert result_field.description == expected_field.description
|
||||
@@ -48,7 +57,9 @@ def test_process_simple_number():
|
||||
schema = {"type": "number"}
|
||||
expected_type = float
|
||||
expected_field = Field(default=..., description="")
|
||||
result_type, result_field = _process_schema_property(schema, "test", "prop", True)
|
||||
result_type, result_field = _process_schema_property(
|
||||
_model_cache, schema, "test", "prop", True
|
||||
)
|
||||
assert result_type == expected_type
|
||||
assert result_field.default == expected_field.default
|
||||
|
||||
@@ -57,7 +68,9 @@ def test_process_unknown_type():
|
||||
schema = {"type": "unknown"}
|
||||
expected_type = Any
|
||||
expected_field = Field(default=..., description="")
|
||||
result_type, result_field = _process_schema_property(schema, "test", "prop", True)
|
||||
result_type, result_field = _process_schema_property(
|
||||
_model_cache, schema, "test", "prop", True
|
||||
)
|
||||
assert result_type == expected_type
|
||||
assert result_field.default == expected_field.default
|
||||
|
||||
@@ -66,7 +79,9 @@ def test_process_array_of_strings():
|
||||
schema = {"type": "array", "items": {"type": "string"}}
|
||||
expected_type = List[str]
|
||||
expected_field = Field(default=..., description="")
|
||||
result_type, result_field = _process_schema_property(schema, "test", "prop", True)
|
||||
result_type, result_field = _process_schema_property(
|
||||
_model_cache, schema, "test", "prop", True
|
||||
)
|
||||
assert result_type == expected_type
|
||||
assert result_field.default == expected_field.default
|
||||
|
||||
@@ -75,7 +90,9 @@ def test_process_array_of_any_missing_items():
|
||||
schema = {"type": "array"} # Missing "items"
|
||||
expected_type = List[Any]
|
||||
expected_field = Field(default=None, description="")
|
||||
result_type, result_field = _process_schema_property(schema, "test", "prop", False)
|
||||
result_type, result_field = _process_schema_property(
|
||||
_model_cache, schema, "test", "prop", False
|
||||
)
|
||||
assert result_type == expected_type
|
||||
assert result_field.default == expected_field.default
|
||||
|
||||
@@ -87,12 +104,13 @@ def test_process_simple_object():
|
||||
"required": ["name"],
|
||||
}
|
||||
expected_field = Field(default=..., description="")
|
||||
result_type, result_field = _process_schema_property(schema, "test", "prop", True)
|
||||
result_type, result_field = _process_schema_property(
|
||||
_model_cache, schema, "test", "prop", True
|
||||
)
|
||||
|
||||
assert result_field.default == expected_field.default
|
||||
assert result_field.description == expected_field.description
|
||||
assert issubclass(result_type, BaseModel) # Check if it's a Pydantic model
|
||||
assert "test_prop_model" in _model_cache # Check caching
|
||||
|
||||
# Check fields of the generated model
|
||||
model_fields = result_type.model_fields
|
||||
@@ -120,11 +138,10 @@ def test_process_nested_object():
|
||||
}
|
||||
expected_field = Field(default=..., description="")
|
||||
result_type, result_field = _process_schema_property(
|
||||
schema, "test", "outer_prop", True
|
||||
_model_cache, schema, "test", "outer_prop", True
|
||||
)
|
||||
|
||||
assert result_field.default == expected_field.default
|
||||
assert "test_outer_prop_model" in _model_cache
|
||||
assert issubclass(result_type, BaseModel)
|
||||
|
||||
outer_model_fields = result_type.model_fields
|
||||
@@ -133,7 +150,6 @@ def test_process_nested_object():
|
||||
|
||||
nested_model_type = outer_model_fields["user"].annotation
|
||||
assert issubclass(nested_model_type, BaseModel)
|
||||
assert "test_outer_prop_model_user_model" in _model_cache
|
||||
|
||||
nested_model_fields = nested_model_type.model_fields
|
||||
assert "id" in nested_model_fields
|
||||
@@ -151,11 +167,12 @@ def test_process_array_of_objects():
|
||||
},
|
||||
}
|
||||
expected_field = Field(default=..., description="")
|
||||
result_type, result_field = _process_schema_property(schema, "test", "prop", True)
|
||||
result_type, result_field = _process_schema_property(
|
||||
_model_cache, schema, "test", "prop", True
|
||||
)
|
||||
|
||||
assert result_field.default == expected_field.default
|
||||
assert str(result_type).startswith("typing.List[") # Check it's a List
|
||||
assert "test_prop_item_model" in _model_cache
|
||||
|
||||
# Get the inner type from List[...]
|
||||
item_model_type = result_type.__args__[0]
|
||||
@@ -171,7 +188,9 @@ def test_process_empty_object():
|
||||
schema = {"type": "object", "properties": {}}
|
||||
expected_type = Dict[str, Any] # Should default to Dict[str, Any] if no properties
|
||||
expected_field = Field(default=..., description="")
|
||||
result_type, result_field = _process_schema_property(schema, "test", "prop", True)
|
||||
result_type, result_field = _process_schema_property(
|
||||
_model_cache, schema, "test", "prop", True
|
||||
)
|
||||
assert result_type == expected_type
|
||||
assert result_field.default == expected_field.default
|
||||
|
||||
@@ -183,19 +202,25 @@ def test_model_caching():
|
||||
"required": ["id"],
|
||||
}
|
||||
# First call
|
||||
result_type1, _ = _process_schema_property(schema, "cache_test", "obj1", True)
|
||||
result_type1, _ = _process_schema_property(
|
||||
_model_cache, schema, "cache_test", "obj1", True
|
||||
)
|
||||
model_name = "cache_test_obj1_model"
|
||||
assert model_name in _model_cache
|
||||
assert _model_cache[model_name] == result_type1
|
||||
|
||||
# Second call with same structure but different prefix/prop name (should generate new)
|
||||
result_type2, _ = _process_schema_property(schema, "cache_test", "obj2", True)
|
||||
result_type2, _ = _process_schema_property(
|
||||
_model_cache, schema, "cache_test", "obj2", True
|
||||
)
|
||||
model_name2 = "cache_test_obj2_model"
|
||||
assert model_name2 in _model_cache
|
||||
assert _model_cache[model_name2] == result_type2
|
||||
assert result_type1 != result_type2 # Different models
|
||||
|
||||
# Third call identical to the first (should return cached model)
|
||||
result_type3, _ = _process_schema_property(schema, "cache_test", "obj1", True)
|
||||
result_type3, _ = _process_schema_property(
|
||||
_model_cache, schema, "cache_test", "obj1", True
|
||||
)
|
||||
assert result_type3 == result_type1 # Should be the same cached object
|
||||
assert len(_model_cache) == 2 # Only two unique models created
|
||||
|
||||
Reference in New Issue
Block a user