From 5f84145a2d1240623b5ee9202e9b3a80cbc99140 Mon Sep 17 00:00:00 2001
From: "Timothy J. Baek"
Date: Sun, 22 Sep 2024 02:12:55 +0200
Subject: [PATCH] enh: socket full redis support

---
 backend/open_webui/apps/socket/main.py  | 131 +++++++++++++-----
 backend/open_webui/apps/socket/utils.py |  59 +++++++++++
 2 files changed, 128 insertions(+), 62 deletions(-)
 create mode 100644 backend/open_webui/apps/socket/utils.py

diff --git a/backend/open_webui/apps/socket/main.py b/backend/open_webui/apps/socket/main.py
index c353c8e6f..c0577e648 100644
--- a/backend/open_webui/apps/socket/main.py
+++ b/backend/open_webui/apps/socket/main.py
@@ -1,6 +1,7 @@
 import asyncio
-
 import socketio
+import time
+
 from open_webui.apps.webui.models.users import Users
 from open_webui.env import (
     ENABLE_WEBSOCKET_SUPPORT,
@@ -8,6 +9,7 @@ from open_webui.env import (
     WEBSOCKET_REDIS_URL,
 )
 from open_webui.utils.utils import decode_token
+from open_webui.apps.socket.utils import RedisDict
 
 
 if WEBSOCKET_MANAGER == "redis":
@@ -38,13 +40,72 @@ app = socketio.ASGIApp(sio, socketio_path="/ws/socket.io")
 
 # Dictionary to maintain the user pool
 
-SESSION_POOL = {}
-USER_POOL = {}
-USAGE_POOL = {}
+if WEBSOCKET_MANAGER == "redis":
+    SESSION_POOL = RedisDict("open-webui:session_pool", redis_url=WEBSOCKET_REDIS_URL)
+    USER_POOL = RedisDict("open-webui:user_pool", redis_url=WEBSOCKET_REDIS_URL)
+    USAGE_POOL = RedisDict("open-webui:usage_pool", redis_url=WEBSOCKET_REDIS_URL)
+else:
+    SESSION_POOL = {}
+    USER_POOL = {}
+    USAGE_POOL = {}
+
+
 # Timeout duration in seconds
 TIMEOUT_DURATION = 3
 
 
+async def periodic_usage_pool_cleanup():
+    while True:
+        now = int(time.time())
+        print("Cleaning up usage pool", now)
+        for model_id, connections in list(USAGE_POOL.items()):
+            # Creating a list of sids to remove if they have timed out
+            expired_sids = [
+                sid
+                for sid, details in connections.items()
+                if now - details["updated_at"] > TIMEOUT_DURATION
+            ]
+
+            for sid in expired_sids:
+                del connections[sid]
+
+            if not connections:
+                del USAGE_POOL[model_id]
+            else:
+                USAGE_POOL[model_id] = connections
+
+        # Emit updated usage information after cleaning
+        await sio.emit("usage", {"models": get_models_in_use()})
+
+        await asyncio.sleep(TIMEOUT_DURATION)
+
+
+# Start the cleanup task when your app starts
+asyncio.create_task(periodic_usage_pool_cleanup())
+
+
+def get_models_in_use():
+    # List models that are currently in use
+    models_in_use = list(USAGE_POOL.keys())
+    return models_in_use
+
+
+@sio.on("usage")
+async def usage(sid, data):
+    model_id = data["model"]
+    # Record the timestamp for the last update
+    current_time = int(time.time())
+
+    # Store the new usage data and task
+    USAGE_POOL[model_id] = {
+        **(USAGE_POOL[model_id] if model_id in USAGE_POOL else {}),
+        sid: {"updated_at": current_time},
+    }
+
+    # Broadcast the usage data to all clients
+    await sio.emit("usage", {"models": get_models_in_use()})
+
+
 @sio.event
 async def connect(sid, environ, auth):
     user = None
@@ -62,8 +123,7 @@ async def connect(sid, environ, auth):
                 USER_POOL[user.id] = [sid]
 
         # print(f"user {user.name}({user.id}) connected with session ID {sid}")
-
-        await sio.emit("user-count", {"count": len(set(USER_POOL))})
+        await sio.emit("user-count", {"count": len(USER_POOL.items())})
         await sio.emit("usage", {"models": get_models_in_use()})
 
 
@@ -91,65 +151,12 @@ async def user_join(sid, data):
 
     # print(f"user {user.name}({user.id}) connected with session ID {sid}")
 
-    await sio.emit("user-count", {"count": len(set(USER_POOL))})
+    await sio.emit("user-count", {"count": len(USER_POOL.items())})
 
 
 @sio.on("user-count")
 async def user_count(sid):
-    await sio.emit("user-count", {"count": len(set(USER_POOL))})
-
-
-def get_models_in_use():
-    # Aggregate all models in use
-    models_in_use = []
-    for model_id, data in USAGE_POOL.items():
-        models_in_use.append(model_id)
-
-    return models_in_use
-
-
-@sio.on("usage")
-async def usage(sid, data):
-    model_id = data["model"]
-
-    # Cancel previous callback if there is one
-    if model_id in USAGE_POOL:
-        USAGE_POOL[model_id]["callback"].cancel()
-
-    # Store the new usage data and task
-
-    if model_id in USAGE_POOL:
-        USAGE_POOL[model_id]["sids"].append(sid)
-        USAGE_POOL[model_id]["sids"] = list(set(USAGE_POOL[model_id]["sids"]))
-
-    else:
-        USAGE_POOL[model_id] = {"sids": [sid]}
-
-    # Schedule a task to remove the usage data after TIMEOUT_DURATION
-    USAGE_POOL[model_id]["callback"] = asyncio.create_task(
-        remove_after_timeout(sid, model_id)
-    )
-
-    # Broadcast the usage data to all clients
-    await sio.emit("usage", {"models": get_models_in_use()})
-
-
-async def remove_after_timeout(sid, model_id):
-    try:
-        await asyncio.sleep(TIMEOUT_DURATION)
-        if model_id in USAGE_POOL:
-            # print(USAGE_POOL[model_id]["sids"])
-            USAGE_POOL[model_id]["sids"].remove(sid)
-            USAGE_POOL[model_id]["sids"] = list(set(USAGE_POOL[model_id]["sids"]))
-
-            if len(USAGE_POOL[model_id]["sids"]) == 0:
-                del USAGE_POOL[model_id]
-
-            # Broadcast the usage data to all clients
-            await sio.emit("usage", {"models": get_models_in_use()})
-    except asyncio.CancelledError:
-        # Task was cancelled due to new 'usage' event
-        pass
+    await sio.emit("user-count", {"count": len(USER_POOL.items())})
 
 
 @sio.event
@@ -158,7 +165,7 @@ async def disconnect(sid):
         user_id = SESSION_POOL[sid]
         del SESSION_POOL[sid]
 
-        USER_POOL[user_id].remove(sid)
+        USER_POOL[user_id] = [_sid for _sid in USER_POOL[user_id] if _sid != sid]
 
         if len(USER_POOL[user_id]) == 0:
             del USER_POOL[user_id]
diff --git a/backend/open_webui/apps/socket/utils.py b/backend/open_webui/apps/socket/utils.py
new file mode 100644
index 000000000..1862ff439
--- /dev/null
+++ b/backend/open_webui/apps/socket/utils.py
@@ -0,0 +1,59 @@
+import json
+import redis
+
+
+class RedisDict:
+    def __init__(self, name, redis_url):
+        self.name = name
+        self.redis = redis.Redis.from_url(redis_url, decode_responses=True)
+
+    def __setitem__(self, key, value):
+        serialized_value = json.dumps(value)
+        self.redis.hset(self.name, key, serialized_value)
+
+    def __getitem__(self, key):
+        value = self.redis.hget(self.name, key)
+        if value is None:
+            raise KeyError(key)
+        return json.loads(value)
+
+    def __delitem__(self, key):
+        result = self.redis.hdel(self.name, key)
+        if result == 0:
+            raise KeyError(key)
+
+    def __contains__(self, key):
+        return self.redis.hexists(self.name, key)
+
+    def __len__(self):
+        return self.redis.hlen(self.name)
+
+    def keys(self):
+        return self.redis.hkeys(self.name)
+
+    def values(self):
+        return [json.loads(v) for v in self.redis.hvals(self.name)]
+
+    def items(self):
+        return [(k, json.loads(v)) for k, v in self.redis.hgetall(self.name).items()]
+
+    def get(self, key, default=None):
+        try:
+            return self[key]
+        except KeyError:
+            return default
+
+    def clear(self):
+        self.redis.delete(self.name)
+
+    def update(self, other=None, **kwargs):
+        if other is not None:
+            for k, v in other.items() if hasattr(other, "items") else other:
+                self[k] = v
+        for k, v in kwargs.items():
+            self[k] = v
+
+    def setdefault(self, key, default=None):
+        if key not in self:
+            self[key] = default
+        return self[key]
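
Note: below is a minimal usage sketch of the new RedisDict helper introduced by this patch. It only exercises methods defined in backend/open_webui/apps/socket/utils.py; the redis://localhost:6379/0 URL and the example keys and values are illustrative assumptions, not values taken from the patch.

# Usage sketch for the RedisDict helper added in this patch (utils.py).
# The Redis URL below is an assumed local instance, not part of the patch.
from open_webui.apps.socket.utils import RedisDict

usage_pool = RedisDict("open-webui:usage_pool", redis_url="redis://localhost:6379/0")

# __setitem__ JSON-serializes the value into a field of a single Redis hash (HSET).
usage_pool["example-model"] = {"sid-1": {"updated_at": 1727000000}}

assert "example-model" in usage_pool   # __contains__ -> HEXISTS
print(len(usage_pool))                 # __len__ -> HLEN
print(usage_pool.items())              # [(key, value), ...] via HGETALL + json.loads

# Reads return detached Python objects, so mutations must be written back,
# which is why periodic_usage_pool_cleanup() reassigns USAGE_POOL[model_id].
connections = usage_pool["example-model"]
connections["sid-2"] = {"updated_at": 1727000003}
usage_pool["example-model"] = connections

del usage_pool["example-model"]        # __delitem__ -> HDEL, raises KeyError if absent

Because every operation goes straight to Redis, all workers and the cleanup task observe the same SESSION_POOL, USER_POOL, and USAGE_POOL state, which is what the patch relies on when it swaps the plain dicts for RedisDict instances under WEBSOCKET_MANAGER == "redis".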