feat: Make ENABLE_WEBSOCKET_SUPPORT disable polling entirely to allow multiple replicas without sticky sessions.

See https://socket.io/docs/v4/using-multiple-nodes/ for details why this was done.

Also create a redis key to track which replica is running the cleanup job
This commit is contained in:
Jason Kidd 2024-12-06 13:03:01 -08:00
parent a38934bd23
commit 8f51681801
No known key found for this signature in database
GPG Key ID: 72BF942827539044
4 changed files with 87 additions and 32 deletions

View File

@ -264,6 +264,7 @@ from open_webui.env import (
WEBUI_SESSION_COOKIE_SECURE,
WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
WEBUI_AUTH_TRUSTED_NAME_HEADER,
ENABLE_WEBSOCKET_SUPPORT,
BYPASS_MODEL_ACCESS_CONTROL,
RESET_CONFIG_ON_START,
OFFLINE_MODE,
@ -932,6 +933,7 @@ async def get_app_config(request: Request):
"enable_api_key": app.state.config.ENABLE_API_KEY,
"enable_signup": app.state.config.ENABLE_SIGNUP,
"enable_login_form": app.state.config.ENABLE_LOGIN_FORM,
"disable_websocket_polling": ENABLE_WEBSOCKET_SUPPORT,
**(
{
"enable_web_search": app.state.config.ENABLE_RAG_WEB_SEARCH,

View File

@ -11,7 +11,7 @@ from open_webui.env import (
WEBSOCKET_REDIS_URL,
)
from open_webui.utils.auth import decode_token
from open_webui.socket.utils import RedisDict
from open_webui.socket.utils import RedisDict, RedisLock
from open_webui.env import (
GLOBAL_LOG_LEVEL,
@ -29,9 +29,7 @@ if WEBSOCKET_MANAGER == "redis":
sio = socketio.AsyncServer(
cors_allowed_origins=[],
async_mode="asgi",
transports=(
["polling", "websocket"] if ENABLE_WEBSOCKET_SUPPORT else ["polling"]
),
transports=(["websocket"] if ENABLE_WEBSOCKET_SUPPORT else ["polling"]),
allow_upgrades=ENABLE_WEBSOCKET_SUPPORT,
always_connect=True,
client_manager=mgr,
@ -40,54 +38,78 @@ else:
sio = socketio.AsyncServer(
cors_allowed_origins=[],
async_mode="asgi",
transports=(
["polling", "websocket"] if ENABLE_WEBSOCKET_SUPPORT else ["polling"]
),
transports=(["websocket"] if ENABLE_WEBSOCKET_SUPPORT else ["polling"]),
allow_upgrades=ENABLE_WEBSOCKET_SUPPORT,
always_connect=True,
)
# Timeout duration in seconds
TIMEOUT_DURATION = 3
# Dictionary to maintain the user pool
run_cleanup = True
if WEBSOCKET_MANAGER == "redis":
log.debug("Using Redis to manage websockets.")
SESSION_POOL = RedisDict("open-webui:session_pool", redis_url=WEBSOCKET_REDIS_URL)
USER_POOL = RedisDict("open-webui:user_pool", redis_url=WEBSOCKET_REDIS_URL)
USAGE_POOL = RedisDict("open-webui:usage_pool", redis_url=WEBSOCKET_REDIS_URL)
clean_up_lock = RedisLock(
redis_url=WEBSOCKET_REDIS_URL,
lock_name="usage_cleanup_lock",
timeout_secs=TIMEOUT_DURATION * 2,
)
run_cleanup = clean_up_lock.aquire_lock()
renew_func = clean_up_lock.renew_lock
release_func = clean_up_lock.release_lock
else:
SESSION_POOL = {}
USER_POOL = {}
USAGE_POOL = {}
# Timeout duration in seconds
TIMEOUT_DURATION = 3
release_func = renew_func = lambda: True
async def periodic_usage_pool_cleanup():
while True:
now = int(time.time())
for model_id, connections in list(USAGE_POOL.items()):
# Creating a list of sids to remove if they have timed out
expired_sids = [
sid
for sid, details in connections.items()
if now - details["updated_at"] > TIMEOUT_DURATION
]
if not run_cleanup:
log.debug("Usage pool cleanup lock already exists. Not running it.")
return
log.debug("Running periodic_usage_pool_cleanup")
try:
while True:
if not renew_func():
log.error(f"Unable to renew cleanup lock. Exiting usage pool cleanup.")
raise Exception("Unable to renew usage pool cleanup lock.")
for sid in expired_sids:
del connections[sid]
now = int(time.time())
send_usage = False
for model_id, connections in list(USAGE_POOL.items()):
# Creating a list of sids to remove if they have timed out
expired_sids = [
sid
for sid, details in connections.items()
if now - details["updated_at"] > TIMEOUT_DURATION
]
if not connections:
log.debug(f"Cleaning up model {model_id} from usage pool")
del USAGE_POOL[model_id]
else:
USAGE_POOL[model_id] = connections
for sid in expired_sids:
del connections[sid]
# Emit updated usage information after cleaning
await sio.emit("usage", {"models": get_models_in_use()})
if not connections:
log.debug(f"Cleaning up model {model_id} from usage pool")
del USAGE_POOL[model_id]
else:
USAGE_POOL[model_id] = connections
await asyncio.sleep(TIMEOUT_DURATION)
send_usage = True
if send_usage:
# Emit updated usage information after cleaning
await sio.emit("usage", {"models": get_models_in_use()})
await asyncio.sleep(TIMEOUT_DURATION)
finally:
release_func()
app = socketio.ASGIApp(

View File

@ -1,5 +1,33 @@
import json
import redis
import uuid
class RedisLock:
def __init__(self, redis_url, lock_name, timeout_secs):
self.lock_name = lock_name
self.lock_id = str(uuid.uuid4())
self.timeout_secs = timeout_secs
self.lock_obtained = False
self.redis = redis.Redis.from_url(redis_url, decode_responses=True)
def aquire_lock(self):
# nx=True will only set this key if it _hasn't_ already been set
self.lock_obtained = self.redis.set(
self.lock_name, self.lock_id, nx=True, ex=self.timeout_secs
)
return self.lock_obtained
def renew_lock(self):
# xx=True will only set this key if it _has_ already been set
return self.redis.set(
self.lock_name, self.lock_id, xx=True, ex=self.timeout_secs
)
def release_lock(self):
lock_value = self.redis.get(self.lock_name)
if lock_value and lock_value.decode("utf-8") == self.lock_id:
self.redis.delete(self.lock_name)
class RedisDict:

View File

@ -38,13 +38,15 @@
let loaded = false;
const BREAKPOINT = 768;
const setupSocket = () => {
const setupSocket = (disableWebSocketPolling) => {
console.log('Disabled websocket polling', disableWebSocketPolling);
const _socket = io(`${WEBUI_BASE_URL}` || undefined, {
reconnection: true,
reconnectionDelay: 1000,
reconnectionDelayMax: 5000,
randomizationFactor: 0.5,
path: '/ws/socket.io',
transports: disableWebSocketPolling ? ['websocket'] : ['polling', 'websocket'],
auth: { token: localStorage.token }
});
@ -125,8 +127,9 @@
await config.set(backendConfig);
await WEBUI_NAME.set(backendConfig.name);
const disableWebSocketPolling = backendConfig.features.disable_websocket_polling === true;
if ($config) {
setupSocket();
setupSocket(disableWebSocketPolling);
if (localStorage.token) {
// Get Session User Info