mirror of
https://github.com/open-webui/open-webui
synced 2024-12-28 06:42:47 +00:00
feat: Make ENABLE_WEBSOCKET_SUPPORT disable polling entirely to allow multiple replicas without sticky sessions.
See https://socket.io/docs/v4/using-multiple-nodes/ for details why this was done. Also create a redis key to track which replica is running the cleanup job
This commit is contained in:
parent
a38934bd23
commit
8f51681801
@ -264,6 +264,7 @@ from open_webui.env import (
|
||||
WEBUI_SESSION_COOKIE_SECURE,
|
||||
WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
|
||||
WEBUI_AUTH_TRUSTED_NAME_HEADER,
|
||||
ENABLE_WEBSOCKET_SUPPORT,
|
||||
BYPASS_MODEL_ACCESS_CONTROL,
|
||||
RESET_CONFIG_ON_START,
|
||||
OFFLINE_MODE,
|
||||
@ -932,6 +933,7 @@ async def get_app_config(request: Request):
|
||||
"enable_api_key": app.state.config.ENABLE_API_KEY,
|
||||
"enable_signup": app.state.config.ENABLE_SIGNUP,
|
||||
"enable_login_form": app.state.config.ENABLE_LOGIN_FORM,
|
||||
"disable_websocket_polling": ENABLE_WEBSOCKET_SUPPORT,
|
||||
**(
|
||||
{
|
||||
"enable_web_search": app.state.config.ENABLE_RAG_WEB_SEARCH,
|
||||
|
@ -11,7 +11,7 @@ from open_webui.env import (
|
||||
WEBSOCKET_REDIS_URL,
|
||||
)
|
||||
from open_webui.utils.auth import decode_token
|
||||
from open_webui.socket.utils import RedisDict
|
||||
from open_webui.socket.utils import RedisDict, RedisLock
|
||||
|
||||
from open_webui.env import (
|
||||
GLOBAL_LOG_LEVEL,
|
||||
@ -29,9 +29,7 @@ if WEBSOCKET_MANAGER == "redis":
|
||||
sio = socketio.AsyncServer(
|
||||
cors_allowed_origins=[],
|
||||
async_mode="asgi",
|
||||
transports=(
|
||||
["polling", "websocket"] if ENABLE_WEBSOCKET_SUPPORT else ["polling"]
|
||||
),
|
||||
transports=(["websocket"] if ENABLE_WEBSOCKET_SUPPORT else ["polling"]),
|
||||
allow_upgrades=ENABLE_WEBSOCKET_SUPPORT,
|
||||
always_connect=True,
|
||||
client_manager=mgr,
|
||||
@ -40,54 +38,78 @@ else:
|
||||
sio = socketio.AsyncServer(
|
||||
cors_allowed_origins=[],
|
||||
async_mode="asgi",
|
||||
transports=(
|
||||
["polling", "websocket"] if ENABLE_WEBSOCKET_SUPPORT else ["polling"]
|
||||
),
|
||||
transports=(["websocket"] if ENABLE_WEBSOCKET_SUPPORT else ["polling"]),
|
||||
allow_upgrades=ENABLE_WEBSOCKET_SUPPORT,
|
||||
always_connect=True,
|
||||
)
|
||||
|
||||
|
||||
# Timeout duration in seconds
|
||||
TIMEOUT_DURATION = 3
|
||||
|
||||
# Dictionary to maintain the user pool
|
||||
|
||||
run_cleanup = True
|
||||
if WEBSOCKET_MANAGER == "redis":
|
||||
log.debug("Using Redis to manage websockets.")
|
||||
SESSION_POOL = RedisDict("open-webui:session_pool", redis_url=WEBSOCKET_REDIS_URL)
|
||||
USER_POOL = RedisDict("open-webui:user_pool", redis_url=WEBSOCKET_REDIS_URL)
|
||||
USAGE_POOL = RedisDict("open-webui:usage_pool", redis_url=WEBSOCKET_REDIS_URL)
|
||||
|
||||
clean_up_lock = RedisLock(
|
||||
redis_url=WEBSOCKET_REDIS_URL,
|
||||
lock_name="usage_cleanup_lock",
|
||||
timeout_secs=TIMEOUT_DURATION * 2,
|
||||
)
|
||||
run_cleanup = clean_up_lock.aquire_lock()
|
||||
renew_func = clean_up_lock.renew_lock
|
||||
release_func = clean_up_lock.release_lock
|
||||
else:
|
||||
SESSION_POOL = {}
|
||||
USER_POOL = {}
|
||||
USAGE_POOL = {}
|
||||
|
||||
|
||||
# Timeout duration in seconds
|
||||
TIMEOUT_DURATION = 3
|
||||
release_func = renew_func = lambda: True
|
||||
|
||||
|
||||
async def periodic_usage_pool_cleanup():
|
||||
while True:
|
||||
now = int(time.time())
|
||||
for model_id, connections in list(USAGE_POOL.items()):
|
||||
# Creating a list of sids to remove if they have timed out
|
||||
expired_sids = [
|
||||
sid
|
||||
for sid, details in connections.items()
|
||||
if now - details["updated_at"] > TIMEOUT_DURATION
|
||||
]
|
||||
if not run_cleanup:
|
||||
log.debug("Usage pool cleanup lock already exists. Not running it.")
|
||||
return
|
||||
log.debug("Running periodic_usage_pool_cleanup")
|
||||
try:
|
||||
while True:
|
||||
if not renew_func():
|
||||
log.error(f"Unable to renew cleanup lock. Exiting usage pool cleanup.")
|
||||
raise Exception("Unable to renew usage pool cleanup lock.")
|
||||
|
||||
for sid in expired_sids:
|
||||
del connections[sid]
|
||||
now = int(time.time())
|
||||
send_usage = False
|
||||
for model_id, connections in list(USAGE_POOL.items()):
|
||||
# Creating a list of sids to remove if they have timed out
|
||||
expired_sids = [
|
||||
sid
|
||||
for sid, details in connections.items()
|
||||
if now - details["updated_at"] > TIMEOUT_DURATION
|
||||
]
|
||||
|
||||
if not connections:
|
||||
log.debug(f"Cleaning up model {model_id} from usage pool")
|
||||
del USAGE_POOL[model_id]
|
||||
else:
|
||||
USAGE_POOL[model_id] = connections
|
||||
for sid in expired_sids:
|
||||
del connections[sid]
|
||||
|
||||
# Emit updated usage information after cleaning
|
||||
await sio.emit("usage", {"models": get_models_in_use()})
|
||||
if not connections:
|
||||
log.debug(f"Cleaning up model {model_id} from usage pool")
|
||||
del USAGE_POOL[model_id]
|
||||
else:
|
||||
USAGE_POOL[model_id] = connections
|
||||
|
||||
await asyncio.sleep(TIMEOUT_DURATION)
|
||||
send_usage = True
|
||||
|
||||
if send_usage:
|
||||
# Emit updated usage information after cleaning
|
||||
await sio.emit("usage", {"models": get_models_in_use()})
|
||||
|
||||
await asyncio.sleep(TIMEOUT_DURATION)
|
||||
finally:
|
||||
release_func()
|
||||
|
||||
|
||||
app = socketio.ASGIApp(
|
||||
|
@ -1,5 +1,33 @@
|
||||
import json
|
||||
import redis
|
||||
import uuid
|
||||
|
||||
|
||||
class RedisLock:
|
||||
def __init__(self, redis_url, lock_name, timeout_secs):
|
||||
self.lock_name = lock_name
|
||||
self.lock_id = str(uuid.uuid4())
|
||||
self.timeout_secs = timeout_secs
|
||||
self.lock_obtained = False
|
||||
self.redis = redis.Redis.from_url(redis_url, decode_responses=True)
|
||||
|
||||
def aquire_lock(self):
|
||||
# nx=True will only set this key if it _hasn't_ already been set
|
||||
self.lock_obtained = self.redis.set(
|
||||
self.lock_name, self.lock_id, nx=True, ex=self.timeout_secs
|
||||
)
|
||||
return self.lock_obtained
|
||||
|
||||
def renew_lock(self):
|
||||
# xx=True will only set this key if it _has_ already been set
|
||||
return self.redis.set(
|
||||
self.lock_name, self.lock_id, xx=True, ex=self.timeout_secs
|
||||
)
|
||||
|
||||
def release_lock(self):
|
||||
lock_value = self.redis.get(self.lock_name)
|
||||
if lock_value and lock_value.decode("utf-8") == self.lock_id:
|
||||
self.redis.delete(self.lock_name)
|
||||
|
||||
|
||||
class RedisDict:
|
||||
|
@ -38,13 +38,15 @@
|
||||
let loaded = false;
|
||||
const BREAKPOINT = 768;
|
||||
|
||||
const setupSocket = () => {
|
||||
const setupSocket = (disableWebSocketPolling) => {
|
||||
console.log('Disabled websocket polling', disableWebSocketPolling);
|
||||
const _socket = io(`${WEBUI_BASE_URL}` || undefined, {
|
||||
reconnection: true,
|
||||
reconnectionDelay: 1000,
|
||||
reconnectionDelayMax: 5000,
|
||||
randomizationFactor: 0.5,
|
||||
path: '/ws/socket.io',
|
||||
transports: disableWebSocketPolling ? ['websocket'] : ['polling', 'websocket'],
|
||||
auth: { token: localStorage.token }
|
||||
});
|
||||
|
||||
@ -125,8 +127,9 @@
|
||||
await config.set(backendConfig);
|
||||
await WEBUI_NAME.set(backendConfig.name);
|
||||
|
||||
const disableWebSocketPolling = backendConfig.features.disable_websocket_polling === true;
|
||||
if ($config) {
|
||||
setupSocket();
|
||||
setupSocket(disableWebSocketPolling);
|
||||
|
||||
if (localStorage.token) {
|
||||
// Get Session User Info
|
||||
|
Loading…
Reference in New Issue
Block a user