Merge pull request #69 from nohaz-h/feature/sse-mcp-server-support

Feat/sse mcp server connection
This commit is contained in:
Tim Jaeryang Baek 2025-04-12 11:57:06 -07:00 committed by GitHub
commit d1b9f7182d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -11,7 +11,7 @@ from starlette.routing import Mount
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client
from mcp.client.sse import sse_client
from mcpo.utils.main import get_model_fields, get_tool_handler
from mcpo.utils.auth import get_verify_api_key
@ -69,9 +69,12 @@ async def lifespan(app: FastAPI):
args = getattr(app.state, "args", [])
env = getattr(app.state, "env", {})
mcptype = "sse" if command == "sse" else "stdio"
args = args if isinstance(args, list) else [args]
api_dependency = getattr(app.state, "api_dependency", None)
if not command:
if (mcptype == "stdio" and not command) or (mcptype == "sse" and not args[0]):
async with AsyncExitStack() as stack:
for route in app.routes:
if isinstance(route, Mount) and isinstance(route.app, FastAPI):
@ -81,17 +84,24 @@ async def lifespan(app: FastAPI):
yield
else:
server_params = StdioServerParameters(
command=command,
args=args,
env={**env},
)
if mcptype == "stdio":
server_params = StdioServerParameters(
command=command,
args=args,
env={**env},
)
async with stdio_client(server_params) as (reader, writer):
async with ClientSession(reader, writer) as session:
app.state.session = session
await create_dynamic_endpoints(app, api_dependency=api_dependency)
yield
async with stdio_client(server_params) as (reader, writer):
async with ClientSession(reader, writer) as session:
app.state.session = session
await create_dynamic_endpoints(app, api_dependency=api_dependency)
yield
if mcptype == "sse":
async with sse_client(url=args[0]) as (reader, writer):
async with ClientSession(reader, writer) as session:
app.state.session = session
await create_dynamic_endpoints(app, api_dependency=api_dependency)
yield
async def run(
@ -166,9 +176,15 @@ async def run(
allow_headers=["*"],
)
sub_app.state.command = server_cfg["command"]
sub_app.state.args = server_cfg.get("args", [])
sub_app.state.env = {**os.environ, **server_cfg.get("env", {})}
if server_cfg.get("command"):
# stdio
sub_app.state.command = server_cfg["command"]
sub_app.state.args = server_cfg.get("args", [])
sub_app.state.env = {**os.environ, **server_cfg.get("env", {})}
if server_cfg.get("url"):
# SSE
sub_app.state.command = "sse"
sub_app.state.args = server_cfg["url"]
sub_app.state.api_dependency = api_dependency