mirror of
https://github.com/open-webui/mcpo
synced 2025-06-26 18:26:58 +00:00
Merge 46586eb5c8
into 663f7312bb
This commit is contained in:
commit
e20eed7093
@ -61,6 +61,12 @@ def main(
|
||||
path_prefix: Annotated[
|
||||
Optional[str], typer.Option("--path-prefix", help="URL prefix")
|
||||
] = None,
|
||||
tools_timeout: Annotated[
|
||||
Optional[int], typer.Option("--tools-timeout", help="Timeout for waiting tools")
|
||||
] = 15,
|
||||
tools_interval: Annotated[
|
||||
Optional[int], typer.Option("--tools-interval", help="Polling interval for tools")
|
||||
] = 1,
|
||||
):
|
||||
server_command = None
|
||||
if not config_path:
|
||||
@ -131,6 +137,8 @@ def main(
|
||||
ssl_certfile=ssl_certfile,
|
||||
ssl_keyfile=ssl_keyfile,
|
||||
path_prefix=path_prefix,
|
||||
tools_timeout=tools_timeout,
|
||||
tools_interval=tools_interval,
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -18,11 +18,16 @@ from starlette.routing import Mount
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
from mcpo.utils.main import get_model_fields, get_tool_handler
|
||||
from mcpo.utils.main import get_model_fields, get_tool_handler, wait_list_tools
|
||||
from mcpo.utils.auth import get_verify_api_key, APIKeyMiddleware
|
||||
|
||||
|
||||
async def create_dynamic_endpoints(app: FastAPI, api_dependency=None):
|
||||
async def create_dynamic_endpoints(
|
||||
app: FastAPI,
|
||||
api_dependency=None,
|
||||
tools_timeout: int = 15,
|
||||
tools_interval: int = 1,
|
||||
):
|
||||
session: ClientSession = app.state.session
|
||||
if not session:
|
||||
raise ValueError("Session is not initialized in the app state.")
|
||||
@ -36,7 +41,11 @@ async def create_dynamic_endpoints(app: FastAPI, api_dependency=None):
|
||||
)
|
||||
app.version = server_info.version or app.version
|
||||
|
||||
tools_result = await session.list_tools()
|
||||
try:
|
||||
tools_result = await wait_list_tools(session, timeout=tools_timeout, interval=tools_interval)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to retrieve tools from MCP server: {e}")
|
||||
|
||||
tools = tools_result.tools
|
||||
|
||||
for tool in tools:
|
||||
@ -87,6 +96,8 @@ async def lifespan(app: FastAPI):
|
||||
|
||||
args = args if isinstance(args, list) else [args]
|
||||
api_dependency = getattr(app.state, "api_dependency", None)
|
||||
tools_timeout = getattr(app.state, "tools_timeout", 5)
|
||||
tools_interval = getattr(app.state, "tools_interval", 1)
|
||||
|
||||
if (server_type == "stdio" and not command) or (
|
||||
server_type == "sse" and not args[0]
|
||||
@ -110,7 +121,12 @@ async def lifespan(app: FastAPI):
|
||||
async with stdio_client(server_params) as (reader, writer):
|
||||
async with ClientSession(reader, writer) as session:
|
||||
app.state.session = session
|
||||
await create_dynamic_endpoints(app, api_dependency=api_dependency)
|
||||
await create_dynamic_endpoints(
|
||||
app,
|
||||
api_dependency=api_dependency,
|
||||
tools_timeout=tools_timeout,
|
||||
tools_interval=tools_interval,
|
||||
)
|
||||
yield
|
||||
if server_type == "sse":
|
||||
async with sse_client(url=args[0], sse_read_timeout=None) as (
|
||||
@ -119,7 +135,12 @@ async def lifespan(app: FastAPI):
|
||||
):
|
||||
async with ClientSession(reader, writer) as session:
|
||||
app.state.session = session
|
||||
await create_dynamic_endpoints(app, api_dependency=api_dependency)
|
||||
await create_dynamic_endpoints(
|
||||
app,
|
||||
api_dependency=api_dependency,
|
||||
tools_timeout=tools_timeout,
|
||||
tools_interval=tools_interval,
|
||||
)
|
||||
yield
|
||||
if server_type == "streamablehttp" or server_type == "streamable_http":
|
||||
# Ensure URL has trailing slash to avoid redirects
|
||||
@ -159,6 +180,10 @@ async def run(
|
||||
# MCP Config
|
||||
config_path = kwargs.get("config_path")
|
||||
|
||||
# MCP Tool
|
||||
tools_timeout = kwargs.get("tools_timeout", 15)
|
||||
tools_interval = kwargs.get("tools_interval", 1)
|
||||
|
||||
# mcpo server
|
||||
name = kwargs.get("name") or "MCP OpenAPI Proxy"
|
||||
description = (
|
||||
@ -204,6 +229,8 @@ async def run(
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
main_app.state.tools_timeout = tools_timeout
|
||||
main_app.state.tools_interval = tools_interval
|
||||
|
||||
# Add middleware to protect also documentation and spec
|
||||
if api_key and strict_auth:
|
||||
@ -291,6 +318,8 @@ async def run(
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
sub_app.state.tools_timeout = tools_timeout
|
||||
sub_app.state.tools_interval = tools_interval
|
||||
|
||||
if server_cfg.get("command"):
|
||||
# stdio
|
||||
|
@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Any, Dict, ForwardRef, List, Optional, Type, Union
|
||||
|
||||
@ -27,6 +28,25 @@ MCP_ERROR_TO_HTTP_STATUS = {
|
||||
}
|
||||
|
||||
|
||||
async def wait_list_tools(session, timeout: int = 15, interval: int = 1) -> list:
|
||||
"""
|
||||
Polls the MCP server until the tool list is available or timeout is reached.
|
||||
Raises TimeoutError if tools are not available within the timeout.
|
||||
"""
|
||||
start = asyncio.get_event_loop().time()
|
||||
last_error = None
|
||||
while True:
|
||||
try:
|
||||
tools_result = await session.list_tools()
|
||||
if getattr(tools_result, "tools", None):
|
||||
return tools_result
|
||||
except Exception as e:
|
||||
last_error = e
|
||||
if asyncio.get_event_loop().time() - start > timeout:
|
||||
raise TimeoutError(f"Timed out waiting for MCP server tools: {last_error}")
|
||||
await asyncio.sleep(interval)
|
||||
|
||||
|
||||
def process_tool_response(result: CallToolResult) -> list:
|
||||
"""Universal response processor for all tool endpoints"""
|
||||
response = []
|
||||
|
Loading…
Reference in New Issue
Block a user