feat: Make ENABLE_WEBSOCKET_SUPPORT disable polling entirely to allow multiple replicas without sticky sessions.

See https://socket.io/docs/v4/using-multiple-nodes/ for details why this was done.

Also create a redis key to track which replica is running the cleanup job
This commit is contained in:
Jason Kidd 2024-12-06 13:03:01 -08:00
parent a38934bd23
commit 8f51681801
No known key found for this signature in database
GPG Key ID: 72BF942827539044
4 changed files with 87 additions and 32 deletions

View File

@ -264,6 +264,7 @@ from open_webui.env import (
WEBUI_SESSION_COOKIE_SECURE, WEBUI_SESSION_COOKIE_SECURE,
WEBUI_AUTH_TRUSTED_EMAIL_HEADER, WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
WEBUI_AUTH_TRUSTED_NAME_HEADER, WEBUI_AUTH_TRUSTED_NAME_HEADER,
ENABLE_WEBSOCKET_SUPPORT,
BYPASS_MODEL_ACCESS_CONTROL, BYPASS_MODEL_ACCESS_CONTROL,
RESET_CONFIG_ON_START, RESET_CONFIG_ON_START,
OFFLINE_MODE, OFFLINE_MODE,
@ -932,6 +933,7 @@ async def get_app_config(request: Request):
"enable_api_key": app.state.config.ENABLE_API_KEY, "enable_api_key": app.state.config.ENABLE_API_KEY,
"enable_signup": app.state.config.ENABLE_SIGNUP, "enable_signup": app.state.config.ENABLE_SIGNUP,
"enable_login_form": app.state.config.ENABLE_LOGIN_FORM, "enable_login_form": app.state.config.ENABLE_LOGIN_FORM,
"disable_websocket_polling": ENABLE_WEBSOCKET_SUPPORT,
**( **(
{ {
"enable_web_search": app.state.config.ENABLE_RAG_WEB_SEARCH, "enable_web_search": app.state.config.ENABLE_RAG_WEB_SEARCH,

View File

@ -11,7 +11,7 @@ from open_webui.env import (
WEBSOCKET_REDIS_URL, WEBSOCKET_REDIS_URL,
) )
from open_webui.utils.auth import decode_token from open_webui.utils.auth import decode_token
from open_webui.socket.utils import RedisDict from open_webui.socket.utils import RedisDict, RedisLock
from open_webui.env import ( from open_webui.env import (
GLOBAL_LOG_LEVEL, GLOBAL_LOG_LEVEL,
@ -29,9 +29,7 @@ if WEBSOCKET_MANAGER == "redis":
sio = socketio.AsyncServer( sio = socketio.AsyncServer(
cors_allowed_origins=[], cors_allowed_origins=[],
async_mode="asgi", async_mode="asgi",
transports=( transports=(["websocket"] if ENABLE_WEBSOCKET_SUPPORT else ["polling"]),
["polling", "websocket"] if ENABLE_WEBSOCKET_SUPPORT else ["polling"]
),
allow_upgrades=ENABLE_WEBSOCKET_SUPPORT, allow_upgrades=ENABLE_WEBSOCKET_SUPPORT,
always_connect=True, always_connect=True,
client_manager=mgr, client_manager=mgr,
@ -40,33 +38,52 @@ else:
sio = socketio.AsyncServer( sio = socketio.AsyncServer(
cors_allowed_origins=[], cors_allowed_origins=[],
async_mode="asgi", async_mode="asgi",
transports=( transports=(["websocket"] if ENABLE_WEBSOCKET_SUPPORT else ["polling"]),
["polling", "websocket"] if ENABLE_WEBSOCKET_SUPPORT else ["polling"]
),
allow_upgrades=ENABLE_WEBSOCKET_SUPPORT, allow_upgrades=ENABLE_WEBSOCKET_SUPPORT,
always_connect=True, always_connect=True,
) )
# Timeout duration in seconds
TIMEOUT_DURATION = 3
# Dictionary to maintain the user pool # Dictionary to maintain the user pool
run_cleanup = True
if WEBSOCKET_MANAGER == "redis": if WEBSOCKET_MANAGER == "redis":
log.debug("Using Redis to manage websockets.")
SESSION_POOL = RedisDict("open-webui:session_pool", redis_url=WEBSOCKET_REDIS_URL) SESSION_POOL = RedisDict("open-webui:session_pool", redis_url=WEBSOCKET_REDIS_URL)
USER_POOL = RedisDict("open-webui:user_pool", redis_url=WEBSOCKET_REDIS_URL) USER_POOL = RedisDict("open-webui:user_pool", redis_url=WEBSOCKET_REDIS_URL)
USAGE_POOL = RedisDict("open-webui:usage_pool", redis_url=WEBSOCKET_REDIS_URL) USAGE_POOL = RedisDict("open-webui:usage_pool", redis_url=WEBSOCKET_REDIS_URL)
clean_up_lock = RedisLock(
redis_url=WEBSOCKET_REDIS_URL,
lock_name="usage_cleanup_lock",
timeout_secs=TIMEOUT_DURATION * 2,
)
run_cleanup = clean_up_lock.aquire_lock()
renew_func = clean_up_lock.renew_lock
release_func = clean_up_lock.release_lock
else: else:
SESSION_POOL = {} SESSION_POOL = {}
USER_POOL = {} USER_POOL = {}
USAGE_POOL = {} USAGE_POOL = {}
release_func = renew_func = lambda: True
# Timeout duration in seconds
TIMEOUT_DURATION = 3
async def periodic_usage_pool_cleanup(): async def periodic_usage_pool_cleanup():
if not run_cleanup:
log.debug("Usage pool cleanup lock already exists. Not running it.")
return
log.debug("Running periodic_usage_pool_cleanup")
try:
while True: while True:
if not renew_func():
log.error(f"Unable to renew cleanup lock. Exiting usage pool cleanup.")
raise Exception("Unable to renew usage pool cleanup lock.")
now = int(time.time()) now = int(time.time())
send_usage = False
for model_id, connections in list(USAGE_POOL.items()): for model_id, connections in list(USAGE_POOL.items()):
# Creating a list of sids to remove if they have timed out # Creating a list of sids to remove if they have timed out
expired_sids = [ expired_sids = [
@ -84,10 +101,15 @@ async def periodic_usage_pool_cleanup():
else: else:
USAGE_POOL[model_id] = connections USAGE_POOL[model_id] = connections
send_usage = True
if send_usage:
# Emit updated usage information after cleaning # Emit updated usage information after cleaning
await sio.emit("usage", {"models": get_models_in_use()}) await sio.emit("usage", {"models": get_models_in_use()})
await asyncio.sleep(TIMEOUT_DURATION) await asyncio.sleep(TIMEOUT_DURATION)
finally:
release_func()
app = socketio.ASGIApp( app = socketio.ASGIApp(

View File

@ -1,5 +1,33 @@
import json import json
import redis import redis
import uuid
class RedisLock:
def __init__(self, redis_url, lock_name, timeout_secs):
self.lock_name = lock_name
self.lock_id = str(uuid.uuid4())
self.timeout_secs = timeout_secs
self.lock_obtained = False
self.redis = redis.Redis.from_url(redis_url, decode_responses=True)
def aquire_lock(self):
# nx=True will only set this key if it _hasn't_ already been set
self.lock_obtained = self.redis.set(
self.lock_name, self.lock_id, nx=True, ex=self.timeout_secs
)
return self.lock_obtained
def renew_lock(self):
# xx=True will only set this key if it _has_ already been set
return self.redis.set(
self.lock_name, self.lock_id, xx=True, ex=self.timeout_secs
)
def release_lock(self):
lock_value = self.redis.get(self.lock_name)
if lock_value and lock_value.decode("utf-8") == self.lock_id:
self.redis.delete(self.lock_name)
class RedisDict: class RedisDict:

View File

@ -38,13 +38,15 @@
let loaded = false; let loaded = false;
const BREAKPOINT = 768; const BREAKPOINT = 768;
const setupSocket = () => { const setupSocket = (disableWebSocketPolling) => {
console.log('Disabled websocket polling', disableWebSocketPolling);
const _socket = io(`${WEBUI_BASE_URL}` || undefined, { const _socket = io(`${WEBUI_BASE_URL}` || undefined, {
reconnection: true, reconnection: true,
reconnectionDelay: 1000, reconnectionDelay: 1000,
reconnectionDelayMax: 5000, reconnectionDelayMax: 5000,
randomizationFactor: 0.5, randomizationFactor: 0.5,
path: '/ws/socket.io', path: '/ws/socket.io',
transports: disableWebSocketPolling ? ['websocket'] : ['polling', 'websocket'],
auth: { token: localStorage.token } auth: { token: localStorage.token }
}); });
@ -125,8 +127,9 @@
await config.set(backendConfig); await config.set(backendConfig);
await WEBUI_NAME.set(backendConfig.name); await WEBUI_NAME.set(backendConfig.name);
const disableWebSocketPolling = backendConfig.features.disable_websocket_polling === true;
if ($config) { if ($config) {
setupSocket(); setupSocket(disableWebSocketPolling);
if (localStorage.token) { if (localStorage.token) {
// Get Session User Info // Get Session User Info