From 698976add087cc6b4976d6333efa9d2383c39f22 Mon Sep 17 00:00:00 2001 From: Jun Siang Cheah Date: Sun, 8 Sep 2024 12:00:36 +0100 Subject: [PATCH] feat: add ENABLE_WEBSOCKET_SUPPORT to force socket.io to ignore websocket upgrades --- backend/open_webui/apps/socket/main.py | 10 +++++++++- backend/open_webui/config.py | 6 ++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/backend/open_webui/apps/socket/main.py b/backend/open_webui/apps/socket/main.py index 5985bc524..777d877bf 100644 --- a/backend/open_webui/apps/socket/main.py +++ b/backend/open_webui/apps/socket/main.py @@ -2,9 +2,17 @@ import asyncio import socketio from open_webui.apps.webui.models.users import Users +from open_webui.config import ENABLE_WEBSOCKET_SUPPORT from open_webui.utils.utils import decode_token -sio = socketio.AsyncServer(cors_allowed_origins=[], async_mode="asgi") +sio = socketio.AsyncServer( + cors_allowed_origins=[], + async_mode="asgi", + transports=( + ["polling", "websocket"] if ENABLE_WEBSOCKET_SUPPORT.value else ["polling"] + ), + allow_upgrades=ENABLE_WEBSOCKET_SUPPORT.value, +) app = socketio.ASGIApp(sio, socketio_path="/ws/socket.io") # Dictionary to maintain the user pool diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py index 5ccb40d47..ac34001d6 100644 --- a/backend/open_webui/config.py +++ b/backend/open_webui/config.py @@ -810,6 +810,12 @@ ENABLE_MESSAGE_RATING = PersistentConfig( os.environ.get("ENABLE_MESSAGE_RATING", "True").lower() == "true", ) +ENABLE_WEBSOCKET_SUPPORT = PersistentConfig( + "ENABLE_WEBSOCKET_SUPPORT", + "ui.enable_websocket_support", + os.environ.get("ENABLE_WEBSOCKET_SUPPORT", "True").lower() == "true", +) + def validate_cors_origins(origins): for origin in origins: