diff --git a/backend/open_webui/apps/socket/main.py b/backend/open_webui/apps/socket/main.py index 5985bc524..d6f1a4999 100644 --- a/backend/open_webui/apps/socket/main.py +++ b/backend/open_webui/apps/socket/main.py @@ -2,9 +2,16 @@ import asyncio import socketio from open_webui.apps.webui.models.users import Users +from open_webui.env import ENABLE_WEBSOCKET_SUPPORT from open_webui.utils.utils import decode_token -sio = socketio.AsyncServer(cors_allowed_origins=[], async_mode="asgi") +sio = socketio.AsyncServer( + cors_allowed_origins=[], + async_mode="asgi", + transports=(["polling", "websocket"] if ENABLE_WEBSOCKET_SUPPORT else ["polling"]), + allow_upgrades=ENABLE_WEBSOCKET_SUPPORT, + always_connect=True, +) app = socketio.ASGIApp(sio, socketio_path="/ws/socket.io") # Dictionary to maintain the user pool diff --git a/backend/open_webui/env.py b/backend/open_webui/env.py index b716769c2..8683bb370 100644 --- a/backend/open_webui/env.py +++ b/backend/open_webui/env.py @@ -273,3 +273,7 @@ WEBUI_SESSION_COOKIE_SECURE = os.environ.get( if WEBUI_AUTH and WEBUI_SECRET_KEY == "": raise ValueError(ERROR_MESSAGES.ENV_VAR_NOT_FOUND) + +ENABLE_WEBSOCKET_SUPPORT = ( + os.environ.get("ENABLE_WEBSOCKET_SUPPORT", "True").lower() == "true" +) diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 8914cb491..a47115977 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -812,6 +812,24 @@ async def update_embedding_function(request: Request, call_next): return response +@app.middleware("http") +async def inspect_websocket(request: Request, call_next): + if ( + "/ws/socket.io" in request.url.path + and request.query_params.get("transport") == "websocket" + ): + upgrade = (request.headers.get("Upgrade") or "").lower() + connection = (request.headers.get("Connection") or "").lower().split(",") + # Check that there's the correct headers for an upgrade, else reject the connection + # This is to work around this upstream issue: https://github.com/miguelgrinberg/python-engineio/issues/367 + if upgrade != "websocket" or "upgrade" not in connection: + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content={"detail": "Invalid WebSocket upgrade request"}, + ) + return await call_next(request) + + app.mount("/ws", socket_app) app.mount("/ollama", ollama_app) diff --git a/src/routes/+layout.svelte b/src/routes/+layout.svelte index 9b209fec0..b71ac0d79 100644 --- a/src/routes/+layout.svelte +++ b/src/routes/+layout.svelte @@ -38,27 +38,20 @@ let loaded = false; const BREAKPOINT = 768; - const setupSocket = (websocket = true) => { + const setupSocket = () => { const _socket = io(`${WEBUI_BASE_URL}` || undefined, { reconnection: true, reconnectionDelay: 1000, reconnectionDelayMax: 5000, randomizationFactor: 0.5, path: '/ws/socket.io', - auth: { token: localStorage.token }, - transports: websocket ? ['websocket'] : ['polling'] + auth: { token: localStorage.token } }); socket.set(_socket); _socket.on('connect_error', (err) => { - if (err.message.includes('websocket')) { - console.log('WebSocket connection failed, falling back to polling'); - _socket.close(); - setupSocket(false); - } else { - console.log('connect_error', err); - } + console.log('connect_error', err); }); _socket.on('connect', () => {