enh: socket full redis support

This commit is contained in:
Timothy J. Baek 2024-09-22 02:12:55 +02:00
parent 47a9395a22
commit 5f84145a2d
2 changed files with 128 additions and 62 deletions

View File

@ -1,6 +1,7 @@
import asyncio
import socketio
import time
from open_webui.apps.webui.models.users import Users
from open_webui.env import (
ENABLE_WEBSOCKET_SUPPORT,
@ -8,6 +9,7 @@ from open_webui.env import (
WEBSOCKET_REDIS_URL,
)
from open_webui.utils.utils import decode_token
from open_webui.apps.socket.utils import RedisDict
if WEBSOCKET_MANAGER == "redis":
@ -38,13 +40,72 @@ app = socketio.ASGIApp(sio, socketio_path="/ws/socket.io")
# Dictionary to maintain the user pool
SESSION_POOL = {}
USER_POOL = {}
USAGE_POOL = {}
# Pick the backing store for the three pools: plain in-process dicts unless
# the redis websocket manager is selected, in which case each pool is a
# RedisDict stored under its own "open-webui:*" hash name.
if WEBSOCKET_MANAGER != "redis":
    SESSION_POOL = {}
    USER_POOL = {}
    USAGE_POOL = {}
else:
    SESSION_POOL = RedisDict(
        "open-webui:session_pool",
        redis_url=WEBSOCKET_REDIS_URL,
    )
    USER_POOL = RedisDict(
        "open-webui:user_pool",
        redis_url=WEBSOCKET_REDIS_URL,
    )
    USAGE_POOL = RedisDict(
        "open-webui:usage_pool",
        redis_url=WEBSOCKET_REDIS_URL,
    )
# Timeout duration in seconds
TIMEOUT_DURATION = 3
async def periodic_usage_pool_cleanup():
    """Evict stale usage entries every ``TIMEOUT_DURATION`` seconds, forever.

    A connection is considered stale when its ``updated_at`` timestamp is
    more than ``TIMEOUT_DURATION`` seconds old. Stale connections are removed
    from their model's entry in ``USAGE_POOL``; a model left without any
    connection is dropped from the pool. After a pass that saw at least one
    pool entry, the current list of models in use is emitted on the "usage"
    event.

    This coroutine never returns; it is meant to run as a background task.
    """
    while True:
        now = int(time.time())
        print("Cleaning up usage pool", now)

        pool_had_entries = False

        # Iterate over a snapshot so entries can be deleted during the walk.
        for model_id, connections in list(USAGE_POOL.items()):
            pool_had_entries = True

            # Creating a list of sids to remove if they have timed out
            expired_sids = [
                sid
                for sid, details in connections.items()
                if now - details["updated_at"] > TIMEOUT_DURATION
            ]

            for sid in expired_sids:
                del connections[sid]

            if not connections:
                try:
                    del USAGE_POOL[model_id]
                except KeyError:
                    # The entry vanished after the snapshot was taken (for
                    # example another process sharing a Redis-backed pool
                    # removed it first). It is gone either way, and letting
                    # the error escape would end this task permanently.
                    pass
            elif expired_sids:
                # Write back only when something actually changed, so an
                # untouched entry is never overwritten with the (possibly
                # outdated) snapshot taken at the start of this pass.
                USAGE_POOL[model_id] = connections

        if pool_had_entries:
            # Emit updated usage information after cleaning
            await sio.emit("usage", {"models": get_models_in_use()})

        await asyncio.sleep(TIMEOUT_DURATION)
# Start the cleanup task when your app starts.
# The task is bound to a module-level name on purpose: the event loop keeps
# only weak references to tasks, so a task nobody else references may be
# garbage-collected before it is done.
_usage_pool_cleanup_task = asyncio.create_task(periodic_usage_pool_cleanup())
def get_models_in_use():
    """Return a list with the id of every model that has an entry in USAGE_POOL."""
    return list(USAGE_POOL.keys())
@sio.on("usage")
async def usage(sid, data):
    """Record that session ``sid`` is using a model and announce the models in use.

    Args:
        sid: Session id of the client that sent the event.
        data: Event payload; ``data["model"]`` is the id of the model in use.

    Raises:
        KeyError: If ``data`` has no ``"model"`` key.
    """
    model_id = data["model"]

    # Record the timestamp for the last update
    current_time = int(time.time())

    # Fetch the existing connections with one lookup. A separate membership
    # test followed by indexing costs two lookups and raises KeyError if the
    # entry is removed between them.
    connections = USAGE_POOL.get(model_id, {})

    # Store the new usage data; assigning a fresh dict (instead of mutating
    # in place) is what persists the change when the pool is not a plain dict.
    USAGE_POOL[model_id] = {
        **connections,
        sid: {"updated_at": current_time},
    }

    # Broadcast the usage data to all clients
    await sio.emit("usage", {"models": get_models_in_use()})
@sio.event
async def connect(sid, environ, auth):
user = None
@ -62,8 +123,7 @@ async def connect(sid, environ, auth):
USER_POOL[user.id] = [sid]
# print(f"user {user.name}({user.id}) connected with session ID {sid}")
await sio.emit("user-count", {"count": len(set(USER_POOL))})
await sio.emit("user-count", {"count": len(USER_POOL.items())})
await sio.emit("usage", {"models": get_models_in_use()})
@ -91,65 +151,12 @@ async def user_join(sid, data):
# print(f"user {user.name}({user.id}) connected with session ID {sid}")
await sio.emit("user-count", {"count": len(set(USER_POOL))})
await sio.emit("user-count", {"count": len(USER_POOL.items())})
@sio.on("user-count")
async def user_count(sid):
    """Emit the number of entries in USER_POOL on the "user-count" event.

    Args:
        sid: Session id of the client that asked for the count (unused).
    """
    # len() is supported by both pool implementations (dict and RedisDict).
    # Building a set from the pool requires iterating it, which RedisDict
    # does not support, and would only copy the keys to count them anyway.
    await sio.emit("user-count", {"count": len(USER_POOL)})
def get_models_in_use():
    """Collect the id of every model that has an entry in USAGE_POOL."""
    # Aggregate all models in use
    return [model_id for model_id, _ in USAGE_POOL.items()]
@sio.on("usage")
async def usage(sid, data):
    """Register ``sid`` as a user of ``data["model"]`` and re-arm its removal timer.

    Any removal task already stored for the model is cancelled and replaced
    by a new one, then the list of models in use is emitted on "usage".
    """
    model_id = data["model"]

    if model_id in USAGE_POOL:
        # Known model: stop the pending removal and add this session,
        # keeping the sid list free of duplicates.
        USAGE_POOL[model_id]["callback"].cancel()
        USAGE_POOL[model_id]["sids"].append(sid)
        USAGE_POOL[model_id]["sids"] = list(set(USAGE_POOL[model_id]["sids"]))
    else:
        # First session seen for this model.
        USAGE_POOL[model_id] = {"sids": [sid]}

    # Schedule the removal of this usage record after TIMEOUT_DURATION
    removal_task = asyncio.create_task(remove_after_timeout(sid, model_id))
    USAGE_POOL[model_id]["callback"] = removal_task

    # Broadcast the usage data to all clients
    await sio.emit("usage", {"models": get_models_in_use()})
async def remove_after_timeout(sid, model_id):
    """Drop ``sid`` from a model's usage entry after ``TIMEOUT_DURATION`` seconds.

    Args:
        sid: Session id to remove from the model's ``"sids"`` list.
        model_id: Key of the entry in ``USAGE_POOL`` to update.

    Cancellation while sleeping is swallowed, so a cancelled task ends
    without touching the pool.
    """
    # NOTE(review): the nesting below was reconstructed from a rendering that
    # lost indentation; confirm in particular that the final emit belongs
    # inside the ``if model_id in USAGE_POOL`` branch.
    try:
        await asyncio.sleep(TIMEOUT_DURATION)
        if model_id in USAGE_POOL:
            # print(USAGE_POOL[model_id]["sids"])
            USAGE_POOL[model_id]["sids"].remove(sid)
            # De-duplicate whatever remains in the list.
            USAGE_POOL[model_id]["sids"] = list(set(USAGE_POOL[model_id]["sids"]))
            # A model with no sessions left is no longer in use.
            if len(USAGE_POOL[model_id]["sids"]) == 0:
                del USAGE_POOL[model_id]
            # Broadcast the usage data to all clients
            await sio.emit("usage", {"models": get_models_in_use()})
    except asyncio.CancelledError:
        # Task was cancelled due to new 'usage' event
        pass
await sio.emit("user-count", {"count": len(USER_POOL.items())})
@sio.event
@ -158,7 +165,7 @@ async def disconnect(sid):
user_id = SESSION_POOL[sid]
del SESSION_POOL[sid]
USER_POOL[user_id].remove(sid)
USER_POOL[user_id] = [_sid for _sid in USER_POOL[user_id] if _sid != sid]
if len(USER_POOL[user_id]) == 0:
del USER_POOL[user_id]

View File

@ -0,0 +1,59 @@
import json
import redis
class RedisDict:
    """Dict-like wrapper around a single Redis hash.

    Each entry is one field of the hash called ``name``. Values are
    JSON-encoded on write and JSON-decoded on read, so:

    * only JSON-serialisable values can be stored;
    * a value read from the mapping is a decoded copy -- mutating it does not
      change what is stored until it is assigned back with ``d[key] = value``;
    * keys are returned as strings (the client is created with
      ``decode_responses=True``).
    """

    def __init__(self, name, redis_url):
        """Bind the mapping to the hash ``name`` on the server at ``redis_url``."""
        self.name = name
        self.redis = redis.Redis.from_url(redis_url, decode_responses=True)

    def __repr__(self):
        return f"{type(self).__name__}({self.name!r})"

    def __setitem__(self, key, value):
        serialized_value = json.dumps(value)
        self.redis.hset(self.name, key, serialized_value)

    def __getitem__(self, key):
        value = self.redis.hget(self.name, key)
        if value is None:
            raise KeyError(key)
        return json.loads(value)

    def __delitem__(self, key):
        # HDEL reports how many fields it removed; 0 means the key was absent.
        result = self.redis.hdel(self.name, key)
        if result == 0:
            raise KeyError(key)

    def __contains__(self, key):
        return self.redis.hexists(self.name, key)

    def __len__(self):
        return self.redis.hlen(self.name)

    def __iter__(self):
        # Without an explicit __iter__, Python falls back to calling
        # __getitem__(0), __getitem__(1), ... which raises KeyError here, so
        # ``for k in d``, ``list(d)`` and ``set(d)`` would all fail.
        return iter(self.keys())

    def keys(self):
        """Return a list of all keys currently in the hash."""
        return self.redis.hkeys(self.name)

    def values(self):
        """Return a list of all decoded values currently in the hash."""
        return [json.loads(v) for v in self.redis.hvals(self.name)]

    def items(self):
        """Return a list of ``(key, decoded value)`` pairs for the whole hash."""
        return [(k, json.loads(v)) for k, v in self.redis.hgetall(self.name).items()]

    def get(self, key, default=None):
        """Return the value stored under ``key``, or ``default`` if it is absent."""
        try:
            return self[key]
        except KeyError:
            return default

    def clear(self):
        """Remove every entry by deleting the underlying hash."""
        self.redis.delete(self.name)

    def update(self, other=None, **kwargs):
        """Store every pair from ``other`` (a mapping or an iterable of pairs) and ``kwargs``."""
        if other is not None:
            pairs = other.items() if hasattr(other, "items") else other
            for k, v in pairs:
                self[k] = v
        for k, v in kwargs.items():
            self[k] = v

    def setdefault(self, key, default=None):
        """Store ``default`` under ``key`` if the key is absent; return the stored value."""
        # HSETNX writes only when the field does not exist yet, so the
        # existence check and the write happen in one server-side command
        # instead of a separate check followed by a write.
        self.redis.hsetnx(self.name, key, json.dumps(default))
        return self[key]