Merge pull request #63 from bmen25124/error_handling

Improved error/result handling
This commit is contained in:
Tim Jaeryang Baek 2025-04-10 11:43:02 -07:00 committed by GitHub
commit 93caf158c5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 133 additions and 29 deletions

View File

@ -1,7 +1,7 @@
import json
import os
from contextlib import AsyncExitStack, asynccontextmanager
from typing import Dict, Any, Optional, List, Type, Union, ForwardRef
from typing import Optional
import uvicorn
from fastapi import FastAPI, Depends
@ -9,7 +9,7 @@ from fastapi.middleware.cors import CORSMiddleware
from starlette.routing import Mount
from mcp import ClientSession, StdioServerParameters, types
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client
@ -58,6 +58,7 @@ async def create_dynamic_endpoints(app: FastAPI, api_dependency=None):
f"/{endpoint_name}",
summary=endpoint_name.replace("_", " ").title(),
description=endpoint_description,
response_model_exclude_none=True,
dependencies=[Depends(api_dependency)] if api_dependency else [],
)(tool_handler)

View File

@ -2,10 +2,20 @@ from typing import Any, Dict, List, Type, Union, ForwardRef
from pydantic import create_model, Field
from pydantic.fields import FieldInfo
from mcp import ClientSession, types
from mcp.types import CallToolResult
from fastapi import HTTPException
from mcp.types import CallToolResult, PARSE_ERROR, INVALID_REQUEST, METHOD_NOT_FOUND, INVALID_PARAMS, INTERNAL_ERROR
from mcp.shared.exceptions import McpError
import json
MCP_ERROR_TO_HTTP_STATUS = {
PARSE_ERROR: 400,
INVALID_REQUEST: 400,
METHOD_NOT_FOUND: 404,
INVALID_PARAMS: 422,
INTERNAL_ERROR: 500,
}
def process_tool_response(result: CallToolResult) -> list:
"""Universal response processor for all tool endpoints"""
@ -127,13 +137,47 @@ def get_tool_handler(session, endpoint_name, form_model_name, model_fields):
def make_endpoint_func(
endpoint_name: str, FormModel, session: ClientSession
): # Parameterized endpoint
async def tool(form_data: FormModel):
args = form_data.model_dump(exclude_none=True)
print(f"Calling endpoint: {endpoint_name}, with args: {args}")
try:
result = await session.call_tool(endpoint_name, arguments=args)
result = await session.call_tool(endpoint_name, arguments=args)
return process_tool_response(result)
if result.isError:
error_message = "Unknown tool execution error"
error_data = None # Initialize error_data
if result.content:
if isinstance(result.content[0], types.TextContent):
error_message = result.content[0].text
detail = {"message": error_message}
if error_data is not None:
detail["data"] = error_data
raise HTTPException(
status_code=500,
detail=detail,
)
response_data = process_tool_response(result)
final_response = response_data[0] if len(response_data) == 1 else response_data
return final_response
except McpError as e:
print(f"MCP Error calling {endpoint_name}: {e.error}")
status_code = MCP_ERROR_TO_HTTP_STATUS.get(e.error.code, 500)
raise HTTPException(
status_code=status_code,
detail=(
{"message": e.error.message, "data": e.error.data}
if e.error.data is not None
else {"message": e.error.message}
),
)
except Exception as e:
print(f"Unexpected error calling {endpoint_name}: {e}")
raise HTTPException(
status_code=500,
detail={"message": "Unexpected error", "error": str(e)},
)
return tool
@ -145,10 +189,44 @@ def get_tool_handler(session, endpoint_name, form_model_name, model_fields):
): # Parameterless endpoint
async def tool(): # No parameters
print(f"Calling endpoint: {endpoint_name}, with no args")
result = await session.call_tool(
endpoint_name, arguments={}
) # Empty dict
return process_tool_response(result) # Same processor
try:
result = await session.call_tool(
endpoint_name, arguments={}
) # Empty dict
if result.isError:
error_message = "Unknown tool execution error"
if result.content:
if isinstance(result.content[0], types.TextContent):
error_message = result.content[0].text
detail = {"message": error_message}
raise HTTPException(
status_code=500,
detail=detail,
)
response_data = process_tool_response(result)
final_response = response_data[0] if len(response_data) == 1 else response_data
return final_response
except McpError as e:
print(f"MCP Error calling {endpoint_name}: {e.error}")
status_code = MCP_ERROR_TO_HTTP_STATUS.get(e.error.code, 500)
# Propagate the error received from MCP as an HTTP exception
raise HTTPException(
status_code=status_code,
detail=(
{"message": e.error.message, "data": e.error.data}
if e.error.data is not None
else {"message": e.error.message}
),
)
except Exception as e:
print(f"Unexpected error calling {endpoint_name}: {e}")
raise HTTPException(
status_code=500,
detail={"message": "Unexpected error", "error": str(e)},
)
return tool

View File

@ -2,7 +2,10 @@ import pytest
from pydantic import BaseModel, Field
from typing import Any, List, Dict
from src.mcpo.main import _process_schema_property, _model_cache
from src.mcpo.utils.main import _process_schema_property
_model_cache = {}
@pytest.fixture(autouse=True)
@ -16,7 +19,9 @@ def test_process_simple_string_required():
schema = {"type": "string", "description": "A simple string"}
expected_type = str
expected_field = Field(default=..., description="A simple string")
result_type, result_field = _process_schema_property(schema, "test", "prop", True)
result_type, result_field = _process_schema_property(
_model_cache, schema, "test", "prop", True
)
assert result_type == expected_type
assert result_field.default == expected_field.default
assert result_field.description == expected_field.description
@ -26,7 +31,9 @@ def test_process_simple_integer_optional():
schema = {"type": "integer", "default": 10}
expected_type = int
expected_field = Field(default=10, description="")
result_type, result_field = _process_schema_property(schema, "test", "prop", False)
result_type, result_field = _process_schema_property(
_model_cache, schema, "test", "prop", False
)
assert result_type == expected_type
assert result_field.default == expected_field.default
assert result_field.description == expected_field.description
@ -38,7 +45,9 @@ def test_process_simple_boolean_optional_no_default():
expected_field = Field(
default=None, description=""
) # Default is None if not required and no default specified
result_type, result_field = _process_schema_property(schema, "test", "prop", False)
result_type, result_field = _process_schema_property(
_model_cache, schema, "test", "prop", False
)
assert result_type == expected_type
assert result_field.default == expected_field.default
assert result_field.description == expected_field.description
@ -48,7 +57,9 @@ def test_process_simple_number():
schema = {"type": "number"}
expected_type = float
expected_field = Field(default=..., description="")
result_type, result_field = _process_schema_property(schema, "test", "prop", True)
result_type, result_field = _process_schema_property(
_model_cache, schema, "test", "prop", True
)
assert result_type == expected_type
assert result_field.default == expected_field.default
@ -57,7 +68,9 @@ def test_process_unknown_type():
schema = {"type": "unknown"}
expected_type = Any
expected_field = Field(default=..., description="")
result_type, result_field = _process_schema_property(schema, "test", "prop", True)
result_type, result_field = _process_schema_property(
_model_cache, schema, "test", "prop", True
)
assert result_type == expected_type
assert result_field.default == expected_field.default
@ -66,7 +79,9 @@ def test_process_array_of_strings():
schema = {"type": "array", "items": {"type": "string"}}
expected_type = List[str]
expected_field = Field(default=..., description="")
result_type, result_field = _process_schema_property(schema, "test", "prop", True)
result_type, result_field = _process_schema_property(
_model_cache, schema, "test", "prop", True
)
assert result_type == expected_type
assert result_field.default == expected_field.default
@ -75,7 +90,9 @@ def test_process_array_of_any_missing_items():
schema = {"type": "array"} # Missing "items"
expected_type = List[Any]
expected_field = Field(default=None, description="")
result_type, result_field = _process_schema_property(schema, "test", "prop", False)
result_type, result_field = _process_schema_property(
_model_cache, schema, "test", "prop", False
)
assert result_type == expected_type
assert result_field.default == expected_field.default
@ -87,12 +104,13 @@ def test_process_simple_object():
"required": ["name"],
}
expected_field = Field(default=..., description="")
result_type, result_field = _process_schema_property(schema, "test", "prop", True)
result_type, result_field = _process_schema_property(
_model_cache, schema, "test", "prop", True
)
assert result_field.default == expected_field.default
assert result_field.description == expected_field.description
assert issubclass(result_type, BaseModel) # Check if it's a Pydantic model
assert "test_prop_model" in _model_cache # Check caching
# Check fields of the generated model
model_fields = result_type.model_fields
@ -120,11 +138,10 @@ def test_process_nested_object():
}
expected_field = Field(default=..., description="")
result_type, result_field = _process_schema_property(
schema, "test", "outer_prop", True
_model_cache, schema, "test", "outer_prop", True
)
assert result_field.default == expected_field.default
assert "test_outer_prop_model" in _model_cache
assert issubclass(result_type, BaseModel)
outer_model_fields = result_type.model_fields
@ -133,7 +150,6 @@ def test_process_nested_object():
nested_model_type = outer_model_fields["user"].annotation
assert issubclass(nested_model_type, BaseModel)
assert "test_outer_prop_model_user_model" in _model_cache
nested_model_fields = nested_model_type.model_fields
assert "id" in nested_model_fields
@ -151,11 +167,12 @@ def test_process_array_of_objects():
},
}
expected_field = Field(default=..., description="")
result_type, result_field = _process_schema_property(schema, "test", "prop", True)
result_type, result_field = _process_schema_property(
_model_cache, schema, "test", "prop", True
)
assert result_field.default == expected_field.default
assert str(result_type).startswith("typing.List[") # Check it's a List
assert "test_prop_item_model" in _model_cache
# Get the inner type from List[...]
item_model_type = result_type.__args__[0]
@ -171,7 +188,9 @@ def test_process_empty_object():
schema = {"type": "object", "properties": {}}
expected_type = Dict[str, Any] # Should default to Dict[str, Any] if no properties
expected_field = Field(default=..., description="")
result_type, result_field = _process_schema_property(schema, "test", "prop", True)
result_type, result_field = _process_schema_property(
_model_cache, schema, "test", "prop", True
)
assert result_type == expected_type
assert result_field.default == expected_field.default
@ -183,19 +202,25 @@ def test_model_caching():
"required": ["id"],
}
# First call
result_type1, _ = _process_schema_property(schema, "cache_test", "obj1", True)
result_type1, _ = _process_schema_property(
_model_cache, schema, "cache_test", "obj1", True
)
model_name = "cache_test_obj1_model"
assert model_name in _model_cache
assert _model_cache[model_name] == result_type1
# Second call with same structure but different prefix/prop name (should generate new)
result_type2, _ = _process_schema_property(schema, "cache_test", "obj2", True)
result_type2, _ = _process_schema_property(
_model_cache, schema, "cache_test", "obj2", True
)
model_name2 = "cache_test_obj2_model"
assert model_name2 in _model_cache
assert _model_cache[model_name2] == result_type2
assert result_type1 != result_type2 # Different models
# Third call identical to the first (should return cached model)
result_type3, _ = _process_schema_property(schema, "cache_test", "obj1", True)
result_type3, _ = _process_schema_property(
_model_cache, schema, "cache_test", "obj1", True
)
assert result_type3 == result_type1 # Should be the same cached object
assert len(_model_cache) == 2 # Only two unique models created