mirror of
https://github.com/open-webui/open-webui
synced 2024-12-28 23:02:25 +00:00
feat: Make ENABLE_WEBSOCKET_SUPPORT disable polling entirely to allow multiple replicas without sticky sessions.
See https://socket.io/docs/v4/using-multiple-nodes/ for details why this was done. Also create a redis key to track which replica is running the cleanup job
This commit is contained in:
parent
a38934bd23
commit
8f51681801
@ -264,6 +264,7 @@ from open_webui.env import (
|
|||||||
WEBUI_SESSION_COOKIE_SECURE,
|
WEBUI_SESSION_COOKIE_SECURE,
|
||||||
WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
|
WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
|
||||||
WEBUI_AUTH_TRUSTED_NAME_HEADER,
|
WEBUI_AUTH_TRUSTED_NAME_HEADER,
|
||||||
|
ENABLE_WEBSOCKET_SUPPORT,
|
||||||
BYPASS_MODEL_ACCESS_CONTROL,
|
BYPASS_MODEL_ACCESS_CONTROL,
|
||||||
RESET_CONFIG_ON_START,
|
RESET_CONFIG_ON_START,
|
||||||
OFFLINE_MODE,
|
OFFLINE_MODE,
|
||||||
@ -932,6 +933,7 @@ async def get_app_config(request: Request):
|
|||||||
"enable_api_key": app.state.config.ENABLE_API_KEY,
|
"enable_api_key": app.state.config.ENABLE_API_KEY,
|
||||||
"enable_signup": app.state.config.ENABLE_SIGNUP,
|
"enable_signup": app.state.config.ENABLE_SIGNUP,
|
||||||
"enable_login_form": app.state.config.ENABLE_LOGIN_FORM,
|
"enable_login_form": app.state.config.ENABLE_LOGIN_FORM,
|
||||||
|
"disable_websocket_polling": ENABLE_WEBSOCKET_SUPPORT,
|
||||||
**(
|
**(
|
||||||
{
|
{
|
||||||
"enable_web_search": app.state.config.ENABLE_RAG_WEB_SEARCH,
|
"enable_web_search": app.state.config.ENABLE_RAG_WEB_SEARCH,
|
||||||
|
@ -11,7 +11,7 @@ from open_webui.env import (
|
|||||||
WEBSOCKET_REDIS_URL,
|
WEBSOCKET_REDIS_URL,
|
||||||
)
|
)
|
||||||
from open_webui.utils.auth import decode_token
|
from open_webui.utils.auth import decode_token
|
||||||
from open_webui.socket.utils import RedisDict
|
from open_webui.socket.utils import RedisDict, RedisLock
|
||||||
|
|
||||||
from open_webui.env import (
|
from open_webui.env import (
|
||||||
GLOBAL_LOG_LEVEL,
|
GLOBAL_LOG_LEVEL,
|
||||||
@ -29,9 +29,7 @@ if WEBSOCKET_MANAGER == "redis":
|
|||||||
sio = socketio.AsyncServer(
|
sio = socketio.AsyncServer(
|
||||||
cors_allowed_origins=[],
|
cors_allowed_origins=[],
|
||||||
async_mode="asgi",
|
async_mode="asgi",
|
||||||
transports=(
|
transports=(["websocket"] if ENABLE_WEBSOCKET_SUPPORT else ["polling"]),
|
||||||
["polling", "websocket"] if ENABLE_WEBSOCKET_SUPPORT else ["polling"]
|
|
||||||
),
|
|
||||||
allow_upgrades=ENABLE_WEBSOCKET_SUPPORT,
|
allow_upgrades=ENABLE_WEBSOCKET_SUPPORT,
|
||||||
always_connect=True,
|
always_connect=True,
|
||||||
client_manager=mgr,
|
client_manager=mgr,
|
||||||
@ -40,54 +38,78 @@ else:
|
|||||||
sio = socketio.AsyncServer(
|
sio = socketio.AsyncServer(
|
||||||
cors_allowed_origins=[],
|
cors_allowed_origins=[],
|
||||||
async_mode="asgi",
|
async_mode="asgi",
|
||||||
transports=(
|
transports=(["websocket"] if ENABLE_WEBSOCKET_SUPPORT else ["polling"]),
|
||||||
["polling", "websocket"] if ENABLE_WEBSOCKET_SUPPORT else ["polling"]
|
|
||||||
),
|
|
||||||
allow_upgrades=ENABLE_WEBSOCKET_SUPPORT,
|
allow_upgrades=ENABLE_WEBSOCKET_SUPPORT,
|
||||||
always_connect=True,
|
always_connect=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Timeout duration in seconds
|
||||||
|
TIMEOUT_DURATION = 3
|
||||||
|
|
||||||
# Dictionary to maintain the user pool
|
# Dictionary to maintain the user pool
|
||||||
|
|
||||||
|
run_cleanup = True
|
||||||
if WEBSOCKET_MANAGER == "redis":
|
if WEBSOCKET_MANAGER == "redis":
|
||||||
|
log.debug("Using Redis to manage websockets.")
|
||||||
SESSION_POOL = RedisDict("open-webui:session_pool", redis_url=WEBSOCKET_REDIS_URL)
|
SESSION_POOL = RedisDict("open-webui:session_pool", redis_url=WEBSOCKET_REDIS_URL)
|
||||||
USER_POOL = RedisDict("open-webui:user_pool", redis_url=WEBSOCKET_REDIS_URL)
|
USER_POOL = RedisDict("open-webui:user_pool", redis_url=WEBSOCKET_REDIS_URL)
|
||||||
USAGE_POOL = RedisDict("open-webui:usage_pool", redis_url=WEBSOCKET_REDIS_URL)
|
USAGE_POOL = RedisDict("open-webui:usage_pool", redis_url=WEBSOCKET_REDIS_URL)
|
||||||
|
|
||||||
|
clean_up_lock = RedisLock(
|
||||||
|
redis_url=WEBSOCKET_REDIS_URL,
|
||||||
|
lock_name="usage_cleanup_lock",
|
||||||
|
timeout_secs=TIMEOUT_DURATION * 2,
|
||||||
|
)
|
||||||
|
run_cleanup = clean_up_lock.aquire_lock()
|
||||||
|
renew_func = clean_up_lock.renew_lock
|
||||||
|
release_func = clean_up_lock.release_lock
|
||||||
else:
|
else:
|
||||||
SESSION_POOL = {}
|
SESSION_POOL = {}
|
||||||
USER_POOL = {}
|
USER_POOL = {}
|
||||||
USAGE_POOL = {}
|
USAGE_POOL = {}
|
||||||
|
release_func = renew_func = lambda: True
|
||||||
|
|
||||||
# Timeout duration in seconds
|
|
||||||
TIMEOUT_DURATION = 3
|
|
||||||
|
|
||||||
|
|
||||||
async def periodic_usage_pool_cleanup():
|
async def periodic_usage_pool_cleanup():
|
||||||
while True:
|
if not run_cleanup:
|
||||||
now = int(time.time())
|
log.debug("Usage pool cleanup lock already exists. Not running it.")
|
||||||
for model_id, connections in list(USAGE_POOL.items()):
|
return
|
||||||
# Creating a list of sids to remove if they have timed out
|
log.debug("Running periodic_usage_pool_cleanup")
|
||||||
expired_sids = [
|
try:
|
||||||
sid
|
while True:
|
||||||
for sid, details in connections.items()
|
if not renew_func():
|
||||||
if now - details["updated_at"] > TIMEOUT_DURATION
|
log.error(f"Unable to renew cleanup lock. Exiting usage pool cleanup.")
|
||||||
]
|
raise Exception("Unable to renew usage pool cleanup lock.")
|
||||||
|
|
||||||
for sid in expired_sids:
|
now = int(time.time())
|
||||||
del connections[sid]
|
send_usage = False
|
||||||
|
for model_id, connections in list(USAGE_POOL.items()):
|
||||||
|
# Creating a list of sids to remove if they have timed out
|
||||||
|
expired_sids = [
|
||||||
|
sid
|
||||||
|
for sid, details in connections.items()
|
||||||
|
if now - details["updated_at"] > TIMEOUT_DURATION
|
||||||
|
]
|
||||||
|
|
||||||
if not connections:
|
for sid in expired_sids:
|
||||||
log.debug(f"Cleaning up model {model_id} from usage pool")
|
del connections[sid]
|
||||||
del USAGE_POOL[model_id]
|
|
||||||
else:
|
|
||||||
USAGE_POOL[model_id] = connections
|
|
||||||
|
|
||||||
# Emit updated usage information after cleaning
|
if not connections:
|
||||||
await sio.emit("usage", {"models": get_models_in_use()})
|
log.debug(f"Cleaning up model {model_id} from usage pool")
|
||||||
|
del USAGE_POOL[model_id]
|
||||||
|
else:
|
||||||
|
USAGE_POOL[model_id] = connections
|
||||||
|
|
||||||
await asyncio.sleep(TIMEOUT_DURATION)
|
send_usage = True
|
||||||
|
|
||||||
|
if send_usage:
|
||||||
|
# Emit updated usage information after cleaning
|
||||||
|
await sio.emit("usage", {"models": get_models_in_use()})
|
||||||
|
|
||||||
|
await asyncio.sleep(TIMEOUT_DURATION)
|
||||||
|
finally:
|
||||||
|
release_func()
|
||||||
|
|
||||||
|
|
||||||
app = socketio.ASGIApp(
|
app = socketio.ASGIApp(
|
||||||
|
@ -1,5 +1,33 @@
|
|||||||
import json
|
import json
|
||||||
import redis
|
import redis
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
|
||||||
|
class RedisLock:
|
||||||
|
def __init__(self, redis_url, lock_name, timeout_secs):
|
||||||
|
self.lock_name = lock_name
|
||||||
|
self.lock_id = str(uuid.uuid4())
|
||||||
|
self.timeout_secs = timeout_secs
|
||||||
|
self.lock_obtained = False
|
||||||
|
self.redis = redis.Redis.from_url(redis_url, decode_responses=True)
|
||||||
|
|
||||||
|
def aquire_lock(self):
|
||||||
|
# nx=True will only set this key if it _hasn't_ already been set
|
||||||
|
self.lock_obtained = self.redis.set(
|
||||||
|
self.lock_name, self.lock_id, nx=True, ex=self.timeout_secs
|
||||||
|
)
|
||||||
|
return self.lock_obtained
|
||||||
|
|
||||||
|
def renew_lock(self):
|
||||||
|
# xx=True will only set this key if it _has_ already been set
|
||||||
|
return self.redis.set(
|
||||||
|
self.lock_name, self.lock_id, xx=True, ex=self.timeout_secs
|
||||||
|
)
|
||||||
|
|
||||||
|
def release_lock(self):
|
||||||
|
lock_value = self.redis.get(self.lock_name)
|
||||||
|
if lock_value and lock_value.decode("utf-8") == self.lock_id:
|
||||||
|
self.redis.delete(self.lock_name)
|
||||||
|
|
||||||
|
|
||||||
class RedisDict:
|
class RedisDict:
|
||||||
|
@ -38,13 +38,15 @@
|
|||||||
let loaded = false;
|
let loaded = false;
|
||||||
const BREAKPOINT = 768;
|
const BREAKPOINT = 768;
|
||||||
|
|
||||||
const setupSocket = () => {
|
const setupSocket = (disableWebSocketPolling) => {
|
||||||
|
console.log('Disabled websocket polling', disableWebSocketPolling);
|
||||||
const _socket = io(`${WEBUI_BASE_URL}` || undefined, {
|
const _socket = io(`${WEBUI_BASE_URL}` || undefined, {
|
||||||
reconnection: true,
|
reconnection: true,
|
||||||
reconnectionDelay: 1000,
|
reconnectionDelay: 1000,
|
||||||
reconnectionDelayMax: 5000,
|
reconnectionDelayMax: 5000,
|
||||||
randomizationFactor: 0.5,
|
randomizationFactor: 0.5,
|
||||||
path: '/ws/socket.io',
|
path: '/ws/socket.io',
|
||||||
|
transports: disableWebSocketPolling ? ['websocket'] : ['polling', 'websocket'],
|
||||||
auth: { token: localStorage.token }
|
auth: { token: localStorage.token }
|
||||||
});
|
});
|
||||||
|
|
||||||
@ -125,8 +127,9 @@
|
|||||||
await config.set(backendConfig);
|
await config.set(backendConfig);
|
||||||
await WEBUI_NAME.set(backendConfig.name);
|
await WEBUI_NAME.set(backendConfig.name);
|
||||||
|
|
||||||
|
const disableWebSocketPolling = backendConfig.features.disable_websocket_polling === true;
|
||||||
if ($config) {
|
if ($config) {
|
||||||
setupSocket();
|
setupSocket(disableWebSocketPolling);
|
||||||
|
|
||||||
if (localStorage.token) {
|
if (localStorage.token) {
|
||||||
// Get Session User Info
|
// Get Session User Info
|
||||||
|
Loading…
Reference in New Issue
Block a user