This commit is contained in:
edelauna 2025-06-19 13:08:11 +00:00 committed by GitHub
commit 50076f615c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 85 additions and 11 deletions

View File

@ -1,3 +1,4 @@
import asyncio
import json
import os
import logging
@ -9,7 +10,6 @@ import uvicorn
from fastapi import Depends, FastAPI
from fastapi.middleware.cors import CORSMiddleware
from mcp import ClientSession, StdioServerParameters
from mcp.client.sse import sse_client
from mcp.client.stdio import stdio_client
from mcp.client.streamable_http import streamablehttp_client
from starlette.routing import Mount
@ -19,6 +19,7 @@ logger = logging.getLogger(__name__)
from mcpo.utils.main import get_model_fields, get_tool_handler
from mcpo.utils.auth import get_verify_api_key, APIKeyMiddleware
from mcpo.utils.sse import sse_client_loop
async def create_dynamic_endpoints(app: FastAPI, api_dependency=None):
@ -117,16 +118,20 @@ async def lifespan(app: FastAPI):
yield
if server_type == "sse":
headers = getattr(app.state, "headers", None)
async with sse_client(
url=args[0], sse_read_timeout=None, headers=headers
) as (
reader,
writer,
):
async with ClientSession(reader, writer) as session:
app.state.session = session
await create_dynamic_endpoints(app, api_dependency=api_dependency)
yield
sse_task = asyncio.create_task(sse_client_loop(
url=args[0],
headers=headers,
api_dependency=api_dependency,
create_dynamic_endpoints=create_dynamic_endpoints,
app=app
))
yield
if sse_task:
sse_task.cancel()
try:
await sse_task
except asyncio.CancelledError:
pass
if server_type == "streamablehttp" or server_type == "streamable_http":
headers = getattr(app.state, "headers", None)

View File

@ -0,0 +1,34 @@
import asyncio
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from mcpo.utils.sse import sse_client_loop
from fastapi import FastAPI
@pytest.mark.asyncio
async def test_sse_client_loop():
# Mock dependencies
app = FastAPI()
api_dependency = MagicMock()
create_dynamic_endpoints = AsyncMock()
# Mock the sse_client context manager
mock_reader = AsyncMock()
mock_writer = AsyncMock()
mock_session = AsyncMock()
# Set up the mock reader to return a message once and then raise a CancelledError
mock_reader.receive.side_effect = [MagicMock(), asyncio.CancelledError()]
with patch('mcpo.utils.sse.sse_client') as mock_sse_client, \
patch('mcpo.utils.sse.ClientSession') as mock_client_session:
mock_sse_client.return_value.__aenter__.return_value = (mock_reader, mock_writer)
mock_client_session.return_value.__aenter__.return_value = mock_session
# Run the sse_client_loop
await sse_client_loop("http://test-url", {}, api_dependency, create_dynamic_endpoints, app)
# Verify that the functions were called as expected
create_dynamic_endpoints.assert_awaited_once_with(app, api_dependency=api_dependency)
mock_reader.receive.assert_awaited()
assert mock_reader.receive.await_count == 2 # Once for the message, once for the CancelledError

35
src/mcpo/utils/sse.py Normal file
View File

@ -0,0 +1,35 @@
import asyncio
import logging
import traceback
from mcp import ClientSession
from mcp.client.sse import sse_client
logger = logging.getLogger(__name__)
async def sse_client_loop(url, headers, api_dependency, create_dynamic_endpoints, app):
while True:
try:
async with sse_client(
url=url, sse_read_timeout=None, headers=headers
) as (reader, writer):
async with ClientSession(reader, writer) as session:
app.state.session = session
await create_dynamic_endpoints(app, api_dependency=api_dependency)
while True:
try:
msg = await asyncio.wait_for(reader.receive(), timeout=60)
if isinstance(msg, Exception):
raise msg
except asyncio.TimeoutError:
logger.warning("SSE read timeout, reconnecting...")
break
except asyncio.CancelledError:
logger.info("SSE client connection cancelled, shutting down...")
return
except Exception as e:
if isinstance(e, ExceptionGroup):
root_cause = e.exceptions[0]
logger.error(f"SSE client error: {type(root_cause).__name__}: {root_cause}")
else:
logger.error(f"SSE client error: {e}")
await asyncio.sleep(5) # Wait before reconnecting