From 8f51681801cc2a29efabdc86520c2219ef203b52 Mon Sep 17 00:00:00 2001 From: Jason Kidd Date: Fri, 6 Dec 2024 13:03:01 -0800 Subject: [PATCH 1/2] feat: Make ENABLE_WEBSOCKET_SUPPORT disable polling entirely to allow multiple replicas without sticky sessions. See https://socket.io/docs/v4/using-multiple-nodes/ for details why this was done. Also create a redis key to track which replica is running the cleanup job --- backend/open_webui/main.py | 2 + backend/open_webui/socket/main.py | 82 +++++++++++++++++++----------- backend/open_webui/socket/utils.py | 28 ++++++++++ src/routes/+layout.svelte | 7 ++- 4 files changed, 87 insertions(+), 32 deletions(-) diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index e7f602311..ea6babd59 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -264,6 +264,7 @@ from open_webui.env import ( WEBUI_SESSION_COOKIE_SECURE, WEBUI_AUTH_TRUSTED_EMAIL_HEADER, WEBUI_AUTH_TRUSTED_NAME_HEADER, + ENABLE_WEBSOCKET_SUPPORT, BYPASS_MODEL_ACCESS_CONTROL, RESET_CONFIG_ON_START, OFFLINE_MODE, @@ -932,6 +933,7 @@ async def get_app_config(request: Request): "enable_api_key": app.state.config.ENABLE_API_KEY, "enable_signup": app.state.config.ENABLE_SIGNUP, "enable_login_form": app.state.config.ENABLE_LOGIN_FORM, + "disable_websocket_polling": ENABLE_WEBSOCKET_SUPPORT, **( { "enable_web_search": app.state.config.ENABLE_RAG_WEB_SEARCH, diff --git a/backend/open_webui/socket/main.py b/backend/open_webui/socket/main.py index 8343be666..d043ce066 100644 --- a/backend/open_webui/socket/main.py +++ b/backend/open_webui/socket/main.py @@ -11,7 +11,7 @@ from open_webui.env import ( WEBSOCKET_REDIS_URL, ) from open_webui.utils.auth import decode_token -from open_webui.socket.utils import RedisDict +from open_webui.socket.utils import RedisDict, RedisLock from open_webui.env import ( GLOBAL_LOG_LEVEL, @@ -29,9 +29,7 @@ if WEBSOCKET_MANAGER == "redis": sio = socketio.AsyncServer( cors_allowed_origins=[], async_mode="asgi", 
- transports=( - ["polling", "websocket"] if ENABLE_WEBSOCKET_SUPPORT else ["polling"] - ), + transports=(["websocket"] if ENABLE_WEBSOCKET_SUPPORT else ["polling"]), allow_upgrades=ENABLE_WEBSOCKET_SUPPORT, always_connect=True, client_manager=mgr, @@ -40,54 +38,78 @@ else: sio = socketio.AsyncServer( cors_allowed_origins=[], async_mode="asgi", - transports=( - ["polling", "websocket"] if ENABLE_WEBSOCKET_SUPPORT else ["polling"] - ), + transports=(["websocket"] if ENABLE_WEBSOCKET_SUPPORT else ["polling"]), allow_upgrades=ENABLE_WEBSOCKET_SUPPORT, always_connect=True, ) +# Timeout duration in seconds +TIMEOUT_DURATION = 3 + # Dictionary to maintain the user pool +run_cleanup = True if WEBSOCKET_MANAGER == "redis": + log.debug("Using Redis to manage websockets.") SESSION_POOL = RedisDict("open-webui:session_pool", redis_url=WEBSOCKET_REDIS_URL) USER_POOL = RedisDict("open-webui:user_pool", redis_url=WEBSOCKET_REDIS_URL) USAGE_POOL = RedisDict("open-webui:usage_pool", redis_url=WEBSOCKET_REDIS_URL) + + clean_up_lock = RedisLock( + redis_url=WEBSOCKET_REDIS_URL, + lock_name="usage_cleanup_lock", + timeout_secs=TIMEOUT_DURATION * 2, + ) + run_cleanup = clean_up_lock.aquire_lock() + renew_func = clean_up_lock.renew_lock + release_func = clean_up_lock.release_lock else: SESSION_POOL = {} USER_POOL = {} USAGE_POOL = {} - - -# Timeout duration in seconds -TIMEOUT_DURATION = 3 + release_func = renew_func = lambda: True async def periodic_usage_pool_cleanup(): - while True: - now = int(time.time()) - for model_id, connections in list(USAGE_POOL.items()): - # Creating a list of sids to remove if they have timed out - expired_sids = [ - sid - for sid, details in connections.items() - if now - details["updated_at"] > TIMEOUT_DURATION - ] + if not run_cleanup: + log.debug("Usage pool cleanup lock already exists. Not running it.") + return + log.debug("Running periodic_usage_pool_cleanup") + try: + while True: + if not renew_func(): + log.error(f"Unable to renew cleanup lock. 
Exiting usage pool cleanup.") + raise Exception("Unable to renew usage pool cleanup lock.") - for sid in expired_sids: - del connections[sid] + now = int(time.time()) + send_usage = False + for model_id, connections in list(USAGE_POOL.items()): + # Creating a list of sids to remove if they have timed out + expired_sids = [ + sid + for sid, details in connections.items() + if now - details["updated_at"] > TIMEOUT_DURATION + ] - if not connections: - log.debug(f"Cleaning up model {model_id} from usage pool") - del USAGE_POOL[model_id] - else: - USAGE_POOL[model_id] = connections + for sid in expired_sids: + del connections[sid] - # Emit updated usage information after cleaning - await sio.emit("usage", {"models": get_models_in_use()}) + if not connections: + log.debug(f"Cleaning up model {model_id} from usage pool") + del USAGE_POOL[model_id] + else: + USAGE_POOL[model_id] = connections - await asyncio.sleep(TIMEOUT_DURATION) + send_usage = True + + if send_usage: + # Emit updated usage information after cleaning + await sio.emit("usage", {"models": get_models_in_use()}) + + await asyncio.sleep(TIMEOUT_DURATION) + finally: + release_func() app = socketio.ASGIApp( diff --git a/backend/open_webui/socket/utils.py b/backend/open_webui/socket/utils.py index 1862ff439..d3de87b05 100644 --- a/backend/open_webui/socket/utils.py +++ b/backend/open_webui/socket/utils.py @@ -1,5 +1,33 @@ import json import redis +import uuid + + +class RedisLock: + def __init__(self, redis_url, lock_name, timeout_secs): + self.lock_name = lock_name + self.lock_id = str(uuid.uuid4()) + self.timeout_secs = timeout_secs + self.lock_obtained = False + self.redis = redis.Redis.from_url(redis_url, decode_responses=True) + + def aquire_lock(self): + # nx=True will only set this key if it _hasn't_ already been set + self.lock_obtained = self.redis.set( + self.lock_name, self.lock_id, nx=True, ex=self.timeout_secs + ) + return self.lock_obtained + + def renew_lock(self): + # xx=True will only set this 
key if it _has_ already been set + return self.redis.set( + self.lock_name, self.lock_id, xx=True, ex=self.timeout_secs + ) + + def release_lock(self): + lock_value = self.redis.get(self.lock_name) + if lock_value and lock_value == self.lock_id: + self.redis.delete(self.lock_name) class RedisDict: diff --git a/src/routes/+layout.svelte b/src/routes/+layout.svelte index d1c30a96b..71ee2dc8b 100644 --- a/src/routes/+layout.svelte +++ b/src/routes/+layout.svelte @@ -38,13 +38,15 @@ let loaded = false; const BREAKPOINT = 768; - const setupSocket = () => { + const setupSocket = (disableWebSocketPolling) => { + console.log('Disabled websocket polling', disableWebSocketPolling); const _socket = io(`${WEBUI_BASE_URL}` || undefined, { reconnection: true, reconnectionDelay: 1000, reconnectionDelayMax: 5000, randomizationFactor: 0.5, path: '/ws/socket.io', + transports: disableWebSocketPolling ? ['websocket'] : ['polling', 'websocket'], auth: { token: localStorage.token } }); @@ -125,8 +127,9 @@ await config.set(backendConfig); await WEBUI_NAME.set(backendConfig.name); + const disableWebSocketPolling = backendConfig.features.disable_websocket_polling === true; if ($config) { - setupSocket(); + setupSocket(disableWebSocketPolling); if (localStorage.token) { // Get Session User Info From e4573d0b6c023cedc6173cd96c9fc7b5ab75d520 Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Wed, 18 Dec 2024 18:32:19 -0800 Subject: [PATCH 2/2] refac --- backend/open_webui/main.py | 2 +- src/routes/+layout.svelte | 8 +++----- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index ea6babd59..e3e3cde4d 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -933,7 +933,7 @@ async def get_app_config(request: Request): "enable_api_key": app.state.config.ENABLE_API_KEY, "enable_signup": app.state.config.ENABLE_SIGNUP, "enable_login_form": app.state.config.ENABLE_LOGIN_FORM, - 
"disable_websocket_polling": ENABLE_WEBSOCKET_SUPPORT, + "enable_websocket": ENABLE_WEBSOCKET_SUPPORT, **( { "enable_web_search": app.state.config.ENABLE_RAG_WEB_SEARCH, diff --git a/src/routes/+layout.svelte b/src/routes/+layout.svelte index 71ee2dc8b..d92e8c2ab 100644 --- a/src/routes/+layout.svelte +++ b/src/routes/+layout.svelte @@ -38,15 +38,14 @@ let loaded = false; const BREAKPOINT = 768; - const setupSocket = (disableWebSocketPolling) => { - console.log('Disabled websocket polling', disableWebSocketPolling); + const setupSocket = (enableWebsocket) => { const _socket = io(`${WEBUI_BASE_URL}` || undefined, { reconnection: true, reconnectionDelay: 1000, reconnectionDelayMax: 5000, randomizationFactor: 0.5, path: '/ws/socket.io', - transports: disableWebSocketPolling ? ['websocket'] : ['polling', 'websocket'], + transports: enableWebsocket ? ['websocket'] : ['polling', 'websocket'], auth: { token: localStorage.token } }); @@ -127,9 +126,8 @@ await config.set(backendConfig); await WEBUI_NAME.set(backendConfig.name); - const disableWebSocketPolling = backendConfig.features.disable_websocket_polling === true; if ($config) { - setupSocket(disableWebSocketPolling); + setupSocket($config.features?.enable_websocket ?? true); if (localStorage.token) { // Get Session User Info