diff --git a/backend/open_webui/apps/socket/main.py b/backend/open_webui/apps/socket/main.py index e41ef8412..8355cbcd8 100644 --- a/backend/open_webui/apps/socket/main.py +++ b/backend/open_webui/apps/socket/main.py @@ -2,16 +2,29 @@ import asyncio import socketio from open_webui.apps.webui.models.users import Users -from open_webui.env import ENABLE_WEBSOCKET_SUPPORT +from open_webui.env import ( + ENABLE_WEBSOCKET_SUPPORT, + WEBSOCKET_MANAGER, + WEBSOCKET_REDIS_URL, +) from open_webui.utils.utils import decode_token -sio = socketio.AsyncServer( - cors_allowed_origins=[], - async_mode="asgi", - transports=(["polling", "websocket"] if ENABLE_WEBSOCKET_SUPPORT else ["polling"]), - allow_upgrades=ENABLE_WEBSOCKET_SUPPORT, - always_connect=True, -) + +if WEBSOCKET_MANAGER == "redis": + mgr = socketio.AsyncRedisManager(WEBSOCKET_REDIS_URL) + sio = socketio.AsyncServer(client_manager=mgr) +else: + sio = socketio.AsyncServer( + cors_allowed_origins=[], + async_mode="asgi", + transports=( + ["polling", "websocket"] if ENABLE_WEBSOCKET_SUPPORT else ["polling"] + ), + allow_upgrades=ENABLE_WEBSOCKET_SUPPORT, + always_connect=True, + ) + + app = socketio.ASGIApp(sio, socketio_path="/ws/socket.io") # Dictionary to maintain the user pool diff --git a/backend/open_webui/env.py b/backend/open_webui/env.py index 89422e57b..504eeea54 100644 --- a/backend/open_webui/env.py +++ b/backend/open_webui/env.py @@ -302,3 +302,7 @@ if WEBUI_AUTH and WEBUI_SECRET_KEY == "": ENABLE_WEBSOCKET_SUPPORT = ( os.environ.get("ENABLE_WEBSOCKET_SUPPORT", "True").lower() == "true" ) + +WEBSOCKET_MANAGER = os.environ.get("WEBSOCKET_MANAGER", "") + +WEBSOCKET_REDIS_URL = os.environ.get("WEBSOCKET_REDIS_URL", "redis://localhost:6379/0")