Add streamable http support

This commit is contained in:
Taylor Wilsdon 2025-05-11 13:48:22 -04:00
parent e37d0ebd27
commit cd6ea93224

View File

@ -9,6 +9,7 @@ from fastapi.middleware.cors import CORSMiddleware
from mcp import ClientSession, StdioServerParameters
from mcp.client.sse import sse_client
from mcp.client.stdio import stdio_client
from mcp.client.streamable_http import streamablehttp_client # Added import
from starlette.routing import Mount
@ -115,6 +116,18 @@ async def lifespan(app: FastAPI):
app.state.session = session
await create_dynamic_endpoints(app, api_dependency=api_dependency)
yield
if server_type == "streamablehttp":
# Assuming args[0] will be the URL, similar to SSE
# streamablehttp_client also returns a get_session_id_callback, which we are not using for now
async with streamablehttp_client(url=args[0]) as (
reader,
writer,
_, # get_session_id_callback - not used by ClientSession directly
):
async with ClientSession(reader, writer) as session:
app.state.session = session
await create_dynamic_endpoints(app, api_dependency=api_dependency)
yield
async def run(
@ -129,7 +142,7 @@ async def run(
strict_auth = kwargs.get("strict_auth", False)
# MCP Server
server_type = kwargs.get("server_type") # "stdio" or "sse" or "http"
server_type = kwargs.get("server_type") # "stdio", "sse", or "streamablehttp"
server_command = kwargs.get("server_command")
# MCP Config
@ -168,16 +181,18 @@ async def run(
if server_type == "sse":
main_app.state.server_type = "sse"
main_app.state.args = server_command[0]
main_app.state.args = server_command[0] # Expects URL as the first element
main_app.state.api_dependency = api_dependency
elif server_command:
elif server_type == "streamablehttp":
main_app.state.server_type = "streamablehttp"
main_app.state.args = server_command[0] # Expects URL as the first element
main_app.state.api_dependency = api_dependency
elif server_command: # This handles stdio
main_app.state.server_type = "stdio" # Explicitly set type
main_app.state.command = server_command[0]
main_app.state.args = server_command[1:]
main_app.state.env = os.environ.copy()
main_app.state.api_dependency = api_dependency
elif config_path:
with open(config_path, "r") as f:
config_data = json.load(f)
@ -205,13 +220,23 @@ async def run(
if server_cfg.get("command"):
# stdio
# stdio
sub_app.state.server_type = "stdio"
sub_app.state.command = server_cfg["command"]
sub_app.state.args = server_cfg.get("args", [])
sub_app.state.env = {**os.environ, **server_cfg.get("env", {})}
if server_cfg.get("url"):
# SSE
server_config_type = server_cfg.get("type")
if server_config_type == "sse" and server_cfg.get("url"):
sub_app.state.server_type = "sse"
sub_app.state.args = server_cfg["url"]
elif server_config_type == "streamablehttp" and server_cfg.get("url"):
sub_app.state.server_type = "streamablehttp"
sub_app.state.args = server_cfg["url"]
elif not server_config_type and server_cfg.get("url"): # Fallback for old SSE config
sub_app.state.server_type = "sse"
sub_app.state.args = server_cfg["url"]
# Add middleware to protect also documentation and spec
if api_key and strict_auth: