diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index e7f602311..ea6babd59 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -264,6 +264,7 @@ from open_webui.env import ( WEBUI_SESSION_COOKIE_SECURE, WEBUI_AUTH_TRUSTED_EMAIL_HEADER, WEBUI_AUTH_TRUSTED_NAME_HEADER, + ENABLE_WEBSOCKET_SUPPORT, BYPASS_MODEL_ACCESS_CONTROL, RESET_CONFIG_ON_START, OFFLINE_MODE, @@ -932,6 +933,7 @@ async def get_app_config(request: Request): "enable_api_key": app.state.config.ENABLE_API_KEY, "enable_signup": app.state.config.ENABLE_SIGNUP, "enable_login_form": app.state.config.ENABLE_LOGIN_FORM, + "disable_websocket_polling": ENABLE_WEBSOCKET_SUPPORT, **( { "enable_web_search": app.state.config.ENABLE_RAG_WEB_SEARCH, diff --git a/backend/open_webui/socket/main.py b/backend/open_webui/socket/main.py index 8343be666..d043ce066 100644 --- a/backend/open_webui/socket/main.py +++ b/backend/open_webui/socket/main.py @@ -11,7 +11,7 @@ from open_webui.env import ( WEBSOCKET_REDIS_URL, ) from open_webui.utils.auth import decode_token -from open_webui.socket.utils import RedisDict +from open_webui.socket.utils import RedisDict, RedisLock from open_webui.env import ( GLOBAL_LOG_LEVEL, @@ -29,9 +29,7 @@ if WEBSOCKET_MANAGER == "redis": sio = socketio.AsyncServer( cors_allowed_origins=[], async_mode="asgi", - transports=( - ["polling", "websocket"] if ENABLE_WEBSOCKET_SUPPORT else ["polling"] - ), + transports=(["websocket"] if ENABLE_WEBSOCKET_SUPPORT else ["polling"]), allow_upgrades=ENABLE_WEBSOCKET_SUPPORT, always_connect=True, client_manager=mgr, @@ -40,54 +38,78 @@ else: sio = socketio.AsyncServer( cors_allowed_origins=[], async_mode="asgi", - transports=( - ["polling", "websocket"] if ENABLE_WEBSOCKET_SUPPORT else ["polling"] - ), + transports=(["websocket"] if ENABLE_WEBSOCKET_SUPPORT else ["polling"]), allow_upgrades=ENABLE_WEBSOCKET_SUPPORT, always_connect=True, ) +# Timeout duration in seconds +TIMEOUT_DURATION = 3 + # Dictionary to maintain the user pool +run_cleanup = True if WEBSOCKET_MANAGER == "redis": + log.debug("Using Redis to manage websockets.") SESSION_POOL = RedisDict("open-webui:session_pool", redis_url=WEBSOCKET_REDIS_URL) USER_POOL = RedisDict("open-webui:user_pool", redis_url=WEBSOCKET_REDIS_URL) USAGE_POOL = RedisDict("open-webui:usage_pool", redis_url=WEBSOCKET_REDIS_URL) + + clean_up_lock = RedisLock( + redis_url=WEBSOCKET_REDIS_URL, + lock_name="usage_cleanup_lock", + timeout_secs=TIMEOUT_DURATION * 2, + ) + run_cleanup = clean_up_lock.aquire_lock() + renew_func = clean_up_lock.renew_lock + release_func = clean_up_lock.release_lock else: SESSION_POOL = {} USER_POOL = {} USAGE_POOL = {} - - -# Timeout duration in seconds -TIMEOUT_DURATION = 3 + release_func = renew_func = lambda: True async def periodic_usage_pool_cleanup(): - while True: - now = int(time.time()) - for model_id, connections in list(USAGE_POOL.items()): - # Creating a list of sids to remove if they have timed out - expired_sids = [ - sid - for sid, details in connections.items() - if now - details["updated_at"] > TIMEOUT_DURATION - ] + if not run_cleanup: + log.debug("Usage pool cleanup lock already exists. Not running it.") + return + log.debug("Running periodic_usage_pool_cleanup") + try: + while True: + if not renew_func(): + log.error(f"Unable to renew cleanup lock. Exiting usage pool cleanup.") + raise Exception("Unable to renew usage pool cleanup lock.") - for sid in expired_sids: - del connections[sid] + now = int(time.time()) + send_usage = False + for model_id, connections in list(USAGE_POOL.items()): + # Creating a list of sids to remove if they have timed out + expired_sids = [ + sid + for sid, details in connections.items() + if now - details["updated_at"] > TIMEOUT_DURATION + ] - if not connections: - log.debug(f"Cleaning up model {model_id} from usage pool") - del USAGE_POOL[model_id] - else: - USAGE_POOL[model_id] = connections + for sid in expired_sids: + del connections[sid] - # Emit updated usage information after cleaning - await sio.emit("usage", {"models": get_models_in_use()}) + if not connections: + log.debug(f"Cleaning up model {model_id} from usage pool") + del USAGE_POOL[model_id] + else: + USAGE_POOL[model_id] = connections - await asyncio.sleep(TIMEOUT_DURATION) + send_usage = True + + if send_usage: + # Emit updated usage information after cleaning + await sio.emit("usage", {"models": get_models_in_use()}) + + await asyncio.sleep(TIMEOUT_DURATION) + finally: + release_func() app = socketio.ASGIApp( diff --git a/backend/open_webui/socket/utils.py b/backend/open_webui/socket/utils.py index 1862ff439..d3de87b05 100644 --- a/backend/open_webui/socket/utils.py +++ b/backend/open_webui/socket/utils.py @@ -1,5 +1,33 @@ import json import redis +import uuid + + +class RedisLock: + def __init__(self, redis_url, lock_name, timeout_secs): + self.lock_name = lock_name + self.lock_id = str(uuid.uuid4()) + self.timeout_secs = timeout_secs + self.lock_obtained = False + self.redis = redis.Redis.from_url(redis_url, decode_responses=True) + + def aquire_lock(self): + # nx=True will only set this key if it _hasn't_ already been set + self.lock_obtained = self.redis.set( + self.lock_name, self.lock_id, nx=True, ex=self.timeout_secs + ) + return self.lock_obtained + + def renew_lock(self): + # xx=True will only set this key if it _has_ already been set + return self.redis.set( + self.lock_name, self.lock_id, xx=True, ex=self.timeout_secs + ) + + def release_lock(self): + lock_value = self.redis.get(self.lock_name) + if lock_value and lock_value.decode("utf-8") == self.lock_id: + self.redis.delete(self.lock_name) class RedisDict: diff --git a/src/routes/+layout.svelte b/src/routes/+layout.svelte index d1c30a96b..71ee2dc8b 100644 --- a/src/routes/+layout.svelte +++ b/src/routes/+layout.svelte @@ -38,13 +38,15 @@ let loaded = false; const BREAKPOINT = 768; - const setupSocket = () => { + const setupSocket = (disableWebSocketPolling) => { + console.log('Disabled websocket polling', disableWebSocketPolling); const _socket = io(`${WEBUI_BASE_URL}` || undefined, { reconnection: true, reconnectionDelay: 1000, reconnectionDelayMax: 5000, randomizationFactor: 0.5, path: '/ws/socket.io', + transports: disableWebSocketPolling ? ['websocket'] : ['polling', 'websocket'], auth: { token: localStorage.token } }); @@ -125,8 +127,9 @@ await config.set(backendConfig); await WEBUI_NAME.set(backendConfig.name); + const disableWebSocketPolling = backendConfig.features.disable_websocket_polling === true; if ($config) { - setupSocket(); + setupSocket(disableWebSocketPolling); if (localStorage.token) { // Get Session User Info