This commit is contained in:
JunXiang 2025-06-06 09:46:06 -07:00 committed by GitHub
commit e20eed7093
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 62 additions and 5 deletions

View File

@ -61,6 +61,12 @@ def main(
path_prefix: Annotated[
Optional[str], typer.Option("--path-prefix", help="URL prefix")
] = None,
tools_timeout: Annotated[
Optional[int], typer.Option("--tools-timeout", help="Timeout for waiting tools")
] = 15,
tools_interval: Annotated[
Optional[int], typer.Option("--tools-interval", help="Polling interval for tools")
] = 1,
):
server_command = None
if not config_path:
@ -131,6 +137,8 @@ def main(
ssl_certfile=ssl_certfile,
ssl_keyfile=ssl_keyfile,
path_prefix=path_prefix,
tools_timeout=tools_timeout,
tools_interval=tools_interval,
)
)

View File

@ -18,11 +18,16 @@ from starlette.routing import Mount
logger = logging.getLogger(__name__)
from mcpo.utils.main import get_model_fields, get_tool_handler
from mcpo.utils.main import get_model_fields, get_tool_handler, wait_list_tools
from mcpo.utils.auth import get_verify_api_key, APIKeyMiddleware
async def create_dynamic_endpoints(app: FastAPI, api_dependency=None):
async def create_dynamic_endpoints(
app: FastAPI,
api_dependency=None,
tools_timeout: int = 15,
tools_interval: int = 1,
):
session: ClientSession = app.state.session
if not session:
raise ValueError("Session is not initialized in the app state.")
@ -36,7 +41,11 @@ async def create_dynamic_endpoints(app: FastAPI, api_dependency=None):
)
app.version = server_info.version or app.version
tools_result = await session.list_tools()
try:
tools_result = await wait_list_tools(session, timeout=tools_timeout, interval=tools_interval)
except Exception as e:
raise RuntimeError(f"Failed to retrieve tools from MCP server: {e}")
tools = tools_result.tools
for tool in tools:
@ -87,6 +96,8 @@ async def lifespan(app: FastAPI):
args = args if isinstance(args, list) else [args]
api_dependency = getattr(app.state, "api_dependency", None)
tools_timeout = getattr(app.state, "tools_timeout", 5)
tools_interval = getattr(app.state, "tools_interval", 1)
if (server_type == "stdio" and not command) or (
server_type == "sse" and not args[0]
@ -110,7 +121,12 @@ async def lifespan(app: FastAPI):
async with stdio_client(server_params) as (reader, writer):
async with ClientSession(reader, writer) as session:
app.state.session = session
await create_dynamic_endpoints(app, api_dependency=api_dependency)
await create_dynamic_endpoints(
app,
api_dependency=api_dependency,
tools_timeout=tools_timeout,
tools_interval=tools_interval,
)
yield
if server_type == "sse":
async with sse_client(url=args[0], sse_read_timeout=None) as (
@ -119,7 +135,12 @@ async def lifespan(app: FastAPI):
):
async with ClientSession(reader, writer) as session:
app.state.session = session
await create_dynamic_endpoints(app, api_dependency=api_dependency)
await create_dynamic_endpoints(
app,
api_dependency=api_dependency,
tools_timeout=tools_timeout,
tools_interval=tools_interval,
)
yield
if server_type == "streamablehttp" or server_type == "streamable_http":
# Ensure URL has trailing slash to avoid redirects
@ -159,6 +180,10 @@ async def run(
# MCP Config
config_path = kwargs.get("config_path")
# MCP Tool
tools_timeout = kwargs.get("tools_timeout", 15)
tools_interval = kwargs.get("tools_interval", 1)
# mcpo server
name = kwargs.get("name") or "MCP OpenAPI Proxy"
description = (
@ -204,6 +229,8 @@ async def run(
allow_methods=["*"],
allow_headers=["*"],
)
main_app.state.tools_timeout = tools_timeout
main_app.state.tools_interval = tools_interval
# Add middleware to protect also documentation and spec
if api_key and strict_auth:
@ -291,6 +318,8 @@ async def run(
allow_methods=["*"],
allow_headers=["*"],
)
sub_app.state.tools_timeout = tools_timeout
sub_app.state.tools_interval = tools_interval
if server_cfg.get("command"):
# stdio

View File

@ -1,3 +1,4 @@
import asyncio
import json
from typing import Any, Dict, ForwardRef, List, Optional, Type, Union
@ -27,6 +28,25 @@ MCP_ERROR_TO_HTTP_STATUS = {
}
async def wait_list_tools(session, timeout: int = 15, interval: int = 1) -> list:
"""
Polls the MCP server until the tool list is available or timeout is reached.
Raises TimeoutError if tools are not available within the timeout.
"""
start = asyncio.get_event_loop().time()
last_error = None
while True:
try:
tools_result = await session.list_tools()
if getattr(tools_result, "tools", None):
return tools_result
except Exception as e:
last_error = e
if asyncio.get_event_loop().time() - start > timeout:
raise TimeoutError(f"Timed out waiting for MCP server tools: {last_error}")
await asyncio.sleep(interval)
def process_tool_response(result: CallToolResult) -> list:
"""Universal response processor for all tool endpoints"""
response = []