From 698976add087cc6b4976d6333efa9d2383c39f22 Mon Sep 17 00:00:00 2001 From: Jun Siang Cheah Date: Sun, 8 Sep 2024 12:00:36 +0100 Subject: [PATCH 1/3] feat: add ENABLE_WEBSOCKET_SUPPORT to force socket.io to ignore websocket upgrades --- backend/open_webui/apps/socket/main.py | 10 +++++++++- backend/open_webui/config.py | 6 ++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/backend/open_webui/apps/socket/main.py b/backend/open_webui/apps/socket/main.py index 5985bc524..777d877bf 100644 --- a/backend/open_webui/apps/socket/main.py +++ b/backend/open_webui/apps/socket/main.py @@ -2,9 +2,17 @@ import asyncio import socketio from open_webui.apps.webui.models.users import Users +from open_webui.config import ENABLE_WEBSOCKET_SUPPORT from open_webui.utils.utils import decode_token -sio = socketio.AsyncServer(cors_allowed_origins=[], async_mode="asgi") +sio = socketio.AsyncServer( + cors_allowed_origins=[], + async_mode="asgi", + transports=( + ["polling", "websocket"] if ENABLE_WEBSOCKET_SUPPORT.value else ["polling"] + ), + allow_upgrades=ENABLE_WEBSOCKET_SUPPORT.value, +) app = socketio.ASGIApp(sio, socketio_path="/ws/socket.io") # Dictionary to maintain the user pool diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py index 5ccb40d47..ac34001d6 100644 --- a/backend/open_webui/config.py +++ b/backend/open_webui/config.py @@ -810,6 +810,12 @@ ENABLE_MESSAGE_RATING = PersistentConfig( os.environ.get("ENABLE_MESSAGE_RATING", "True").lower() == "true", ) +ENABLE_WEBSOCKET_SUPPORT = PersistentConfig( + "ENABLE_WEBSOCKET_SUPPORT", + "ui.enable_websocket_support", + os.environ.get("ENABLE_WEBSOCKET_SUPPORT", "True").lower() == "true", +) + def validate_cors_origins(origins): for origin in origins: From 827c41925137570599e876435ca2f65b3a6007b2 Mon Sep 17 00:00:00 2001 From: Jun Siang Cheah Date: Mon, 9 Sep 2024 23:17:17 +0100 Subject: [PATCH 2/3] feat: add ENABLE_WEBSOCKET_SUPPORT to force socket.io to ignore websocket upgrades --- backend/open_webui/apps/socket/main.py | 9 ++++----- backend/open_webui/config.py | 6 ------ backend/open_webui/env.py | 4 ++++ 3 files changed, 8 insertions(+), 11 deletions(-) diff --git a/backend/open_webui/apps/socket/main.py b/backend/open_webui/apps/socket/main.py index 777d877bf..d6f1a4999 100644 --- a/backend/open_webui/apps/socket/main.py +++ b/backend/open_webui/apps/socket/main.py @@ -2,16 +2,15 @@ import asyncio import socketio from open_webui.apps.webui.models.users import Users -from open_webui.config import ENABLE_WEBSOCKET_SUPPORT +from open_webui.env import ENABLE_WEBSOCKET_SUPPORT from open_webui.utils.utils import decode_token sio = socketio.AsyncServer( cors_allowed_origins=[], async_mode="asgi", - transports=( - ["polling", "websocket"] if ENABLE_WEBSOCKET_SUPPORT.value else ["polling"] - ), - allow_upgrades=ENABLE_WEBSOCKET_SUPPORT.value, + transports=(["polling", "websocket"] if ENABLE_WEBSOCKET_SUPPORT else ["polling"]), + allow_upgrades=ENABLE_WEBSOCKET_SUPPORT, + always_connect=True, ) app = socketio.ASGIApp(sio, socketio_path="/ws/socket.io") diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py index ac34001d6..5ccb40d47 100644 --- a/backend/open_webui/config.py +++ b/backend/open_webui/config.py @@ -810,12 +810,6 @@ ENABLE_MESSAGE_RATING = PersistentConfig( os.environ.get("ENABLE_MESSAGE_RATING", "True").lower() == "true", ) -ENABLE_WEBSOCKET_SUPPORT = PersistentConfig( - "ENABLE_WEBSOCKET_SUPPORT", - "ui.enable_websocket_support", - os.environ.get("ENABLE_WEBSOCKET_SUPPORT", "True").lower() == "true", -) - def validate_cors_origins(origins): for origin in origins: diff --git a/backend/open_webui/env.py b/backend/open_webui/env.py index b716769c2..8683bb370 100644 --- a/backend/open_webui/env.py +++ b/backend/open_webui/env.py @@ -273,3 +273,7 @@ WEBUI_SESSION_COOKIE_SECURE = os.environ.get( if WEBUI_AUTH and WEBUI_SECRET_KEY == "": raise ValueError(ERROR_MESSAGES.ENV_VAR_NOT_FOUND) + +ENABLE_WEBSOCKET_SUPPORT = ( + os.environ.get("ENABLE_WEBSOCKET_SUPPORT", "True").lower() == "true" +) From 9401f6c8217821489d7950d0c27e3a94d46591e3 Mon Sep 17 00:00:00 2001 From: Jun Siang Cheah Date: Sun, 8 Sep 2024 11:54:56 +0100 Subject: [PATCH 3/3] fix: workaround socketio upstream bug when websockets are not available --- backend/open_webui/main.py | 18 ++++++++++++++++++ src/routes/+layout.svelte | 13 +++---------- 2 files changed, 21 insertions(+), 10 deletions(-) diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 9e6aa6fab..73f51c7d7 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -812,6 +812,24 @@ async def update_embedding_function(request: Request, call_next): return response +@app.middleware("http") +async def inspect_websocket(request: Request, call_next): + if ( + "/ws/socket.io" in request.url.path + and request.query_params.get("transport") == "websocket" + ): + upgrade = (request.headers.get("Upgrade") or "").lower() + connection = (request.headers.get("Connection") or "").lower().split(",") + # Check that there's the correct headers for an upgrade, else reject the connection + # This is to work around this upstream issue: https://github.com/miguelgrinberg/python-engineio/issues/367 + if upgrade != "websocket" or "upgrade" not in connection: + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content={"detail": "Invalid WebSocket upgrade request"}, + ) + return await call_next(request) + + app.mount("/ws", socket_app) app.mount("/ollama", ollama_app) diff --git a/src/routes/+layout.svelte b/src/routes/+layout.svelte index 9b209fec0..b71ac0d79 100644 --- a/src/routes/+layout.svelte +++ b/src/routes/+layout.svelte @@ -38,27 +38,20 @@ let loaded = false; const BREAKPOINT = 768; - const setupSocket = (websocket = true) => { + const setupSocket = () => { const _socket = io(`${WEBUI_BASE_URL}` || undefined, { reconnection: true, reconnectionDelay: 1000, reconnectionDelayMax: 5000, randomizationFactor: 0.5, path: '/ws/socket.io', - auth: { token: localStorage.token }, - transports: websocket ? ['websocket'] : ['polling'] + auth: { token: localStorage.token } }); socket.set(_socket); _socket.on('connect_error', (err) => { - if (err.message.includes('websocket')) { - console.log('WebSocket connection failed, falling back to polling'); - _socket.close(); - setupSocket(false); - } else { - console.log('connect_error', err); - } + console.log('connect_error', err); }); _socket.on('connect', () => {