mirror of
https://github.com/open-webui/mcpo
synced 2025-06-26 18:26:58 +00:00
Merge 9cdeeb91e7
into 105063963d
This commit is contained in:
commit
50076f615c
@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import logging
|
||||
@ -9,7 +10,6 @@ import uvicorn
|
||||
from fastapi import Depends, FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from mcp import ClientSession, StdioServerParameters
|
||||
from mcp.client.sse import sse_client
|
||||
from mcp.client.stdio import stdio_client
|
||||
from mcp.client.streamable_http import streamablehttp_client
|
||||
from starlette.routing import Mount
|
||||
@ -19,6 +19,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
from mcpo.utils.main import get_model_fields, get_tool_handler
|
||||
from mcpo.utils.auth import get_verify_api_key, APIKeyMiddleware
|
||||
from mcpo.utils.sse import sse_client_loop
|
||||
|
||||
|
||||
async def create_dynamic_endpoints(app: FastAPI, api_dependency=None):
|
||||
@ -117,16 +118,20 @@ async def lifespan(app: FastAPI):
|
||||
yield
|
||||
if server_type == "sse":
|
||||
headers = getattr(app.state, "headers", None)
|
||||
async with sse_client(
|
||||
url=args[0], sse_read_timeout=None, headers=headers
|
||||
) as (
|
||||
reader,
|
||||
writer,
|
||||
):
|
||||
async with ClientSession(reader, writer) as session:
|
||||
app.state.session = session
|
||||
await create_dynamic_endpoints(app, api_dependency=api_dependency)
|
||||
yield
|
||||
sse_task = asyncio.create_task(sse_client_loop(
|
||||
url=args[0],
|
||||
headers=headers,
|
||||
api_dependency=api_dependency,
|
||||
create_dynamic_endpoints=create_dynamic_endpoints,
|
||||
app=app
|
||||
))
|
||||
yield
|
||||
if sse_task:
|
||||
sse_task.cancel()
|
||||
try:
|
||||
await sse_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
if server_type == "streamablehttp" or server_type == "streamable_http":
|
||||
headers = getattr(app.state, "headers", None)
|
||||
|
||||
|
34
src/mcpo/tests/test_sse.py
Normal file
34
src/mcpo/tests/test_sse.py
Normal file
@ -0,0 +1,34 @@
|
||||
import asyncio
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from mcpo.utils.sse import sse_client_loop
|
||||
from fastapi import FastAPI
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sse_client_loop():
|
||||
# Mock dependencies
|
||||
app = FastAPI()
|
||||
api_dependency = MagicMock()
|
||||
create_dynamic_endpoints = AsyncMock()
|
||||
|
||||
# Mock the sse_client context manager
|
||||
mock_reader = AsyncMock()
|
||||
mock_writer = AsyncMock()
|
||||
mock_session = AsyncMock()
|
||||
|
||||
# Set up the mock reader to return a message once and then raise a CancelledError
|
||||
mock_reader.receive.side_effect = [MagicMock(), asyncio.CancelledError()]
|
||||
|
||||
with patch('mcpo.utils.sse.sse_client') as mock_sse_client, \
|
||||
patch('mcpo.utils.sse.ClientSession') as mock_client_session:
|
||||
|
||||
mock_sse_client.return_value.__aenter__.return_value = (mock_reader, mock_writer)
|
||||
mock_client_session.return_value.__aenter__.return_value = mock_session
|
||||
|
||||
# Run the sse_client_loop
|
||||
await sse_client_loop("http://test-url", {}, api_dependency, create_dynamic_endpoints, app)
|
||||
|
||||
# Verify that the functions were called as expected
|
||||
create_dynamic_endpoints.assert_awaited_once_with(app, api_dependency=api_dependency)
|
||||
mock_reader.receive.assert_awaited()
|
||||
assert mock_reader.receive.await_count == 2 # Once for the message, once for the CancelledError
|
35
src/mcpo/utils/sse.py
Normal file
35
src/mcpo/utils/sse.py
Normal file
@ -0,0 +1,35 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import traceback
|
||||
from mcp import ClientSession
|
||||
from mcp.client.sse import sse_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
async def sse_client_loop(url, headers, api_dependency, create_dynamic_endpoints, app):
|
||||
while True:
|
||||
try:
|
||||
async with sse_client(
|
||||
url=url, sse_read_timeout=None, headers=headers
|
||||
) as (reader, writer):
|
||||
async with ClientSession(reader, writer) as session:
|
||||
app.state.session = session
|
||||
await create_dynamic_endpoints(app, api_dependency=api_dependency)
|
||||
while True:
|
||||
try:
|
||||
msg = await asyncio.wait_for(reader.receive(), timeout=60)
|
||||
if isinstance(msg, Exception):
|
||||
raise msg
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("SSE read timeout, reconnecting...")
|
||||
break
|
||||
except asyncio.CancelledError:
|
||||
logger.info("SSE client connection cancelled, shutting down...")
|
||||
return
|
||||
except Exception as e:
|
||||
if isinstance(e, ExceptionGroup):
|
||||
root_cause = e.exceptions[0]
|
||||
logger.error(f"SSE client error: {type(root_cause).__name__}: {root_cause}")
|
||||
else:
|
||||
logger.error(f"SSE client error: {e}")
|
||||
await asyncio.sleep(5) # Wait before reconnecting
|
Loading…
Reference in New Issue
Block a user