mirror of
https://github.com/open-webui/open-webui
synced 2025-03-31 15:52:03 +00:00
enh: socket full redis support
This commit is contained in:
parent
47a9395a22
commit
5f84145a2d
@ -1,6 +1,7 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
import socketio
|
import socketio
|
||||||
|
import time
|
||||||
|
|
||||||
from open_webui.apps.webui.models.users import Users
|
from open_webui.apps.webui.models.users import Users
|
||||||
from open_webui.env import (
|
from open_webui.env import (
|
||||||
ENABLE_WEBSOCKET_SUPPORT,
|
ENABLE_WEBSOCKET_SUPPORT,
|
||||||
@ -8,6 +9,7 @@ from open_webui.env import (
|
|||||||
WEBSOCKET_REDIS_URL,
|
WEBSOCKET_REDIS_URL,
|
||||||
)
|
)
|
||||||
from open_webui.utils.utils import decode_token
|
from open_webui.utils.utils import decode_token
|
||||||
|
from open_webui.apps.socket.utils import RedisDict
|
||||||
|
|
||||||
|
|
||||||
if WEBSOCKET_MANAGER == "redis":
|
if WEBSOCKET_MANAGER == "redis":
|
||||||
@ -38,13 +40,72 @@ app = socketio.ASGIApp(sio, socketio_path="/ws/socket.io")
|
|||||||
|
|
||||||
# Dictionary to maintain the user pool
|
# Dictionary to maintain the user pool
|
||||||
|
|
||||||
SESSION_POOL = {}
|
if WEBSOCKET_MANAGER == "redis":
|
||||||
USER_POOL = {}
|
SESSION_POOL = RedisDict("open-webui:session_pool", redis_url=WEBSOCKET_REDIS_URL)
|
||||||
USAGE_POOL = {}
|
USER_POOL = RedisDict("open-webui:user_pool", redis_url=WEBSOCKET_REDIS_URL)
|
||||||
|
USAGE_POOL = RedisDict("open-webui:usage_pool", redis_url=WEBSOCKET_REDIS_URL)
|
||||||
|
else:
|
||||||
|
SESSION_POOL = {}
|
||||||
|
USER_POOL = {}
|
||||||
|
USAGE_POOL = {}
|
||||||
|
|
||||||
|
|
||||||
# Timeout duration in seconds
|
# Timeout duration in seconds
|
||||||
TIMEOUT_DURATION = 3
|
TIMEOUT_DURATION = 3
|
||||||
|
|
||||||
|
|
||||||
|
async def periodic_usage_pool_cleanup():
|
||||||
|
while True:
|
||||||
|
now = int(time.time())
|
||||||
|
print("Cleaning up usage pool", now)
|
||||||
|
for model_id, connections in list(USAGE_POOL.items()):
|
||||||
|
# Creating a list of sids to remove if they have timed out
|
||||||
|
expired_sids = [
|
||||||
|
sid
|
||||||
|
for sid, details in connections.items()
|
||||||
|
if now - details["updated_at"] > TIMEOUT_DURATION
|
||||||
|
]
|
||||||
|
|
||||||
|
for sid in expired_sids:
|
||||||
|
del connections[sid]
|
||||||
|
|
||||||
|
if not connections:
|
||||||
|
del USAGE_POOL[model_id]
|
||||||
|
else:
|
||||||
|
USAGE_POOL[model_id] = connections
|
||||||
|
|
||||||
|
# Emit updated usage information after cleaning
|
||||||
|
await sio.emit("usage", {"models": get_models_in_use()})
|
||||||
|
|
||||||
|
await asyncio.sleep(TIMEOUT_DURATION)
|
||||||
|
|
||||||
|
|
||||||
|
# Start the cleanup task when your app starts
|
||||||
|
asyncio.create_task(periodic_usage_pool_cleanup())
|
||||||
|
|
||||||
|
|
||||||
|
def get_models_in_use():
|
||||||
|
# List models that are currently in use
|
||||||
|
models_in_use = list(USAGE_POOL.keys())
|
||||||
|
return models_in_use
|
||||||
|
|
||||||
|
|
||||||
|
@sio.on("usage")
|
||||||
|
async def usage(sid, data):
|
||||||
|
model_id = data["model"]
|
||||||
|
# Record the timestamp for the last update
|
||||||
|
current_time = int(time.time())
|
||||||
|
|
||||||
|
# Store the new usage data and task
|
||||||
|
USAGE_POOL[model_id] = {
|
||||||
|
**(USAGE_POOL[model_id] if model_id in USAGE_POOL else {}),
|
||||||
|
sid: {"updated_at": current_time},
|
||||||
|
}
|
||||||
|
|
||||||
|
# Broadcast the usage data to all clients
|
||||||
|
await sio.emit("usage", {"models": get_models_in_use()})
|
||||||
|
|
||||||
|
|
||||||
@sio.event
|
@sio.event
|
||||||
async def connect(sid, environ, auth):
|
async def connect(sid, environ, auth):
|
||||||
user = None
|
user = None
|
||||||
@ -62,8 +123,7 @@ async def connect(sid, environ, auth):
|
|||||||
USER_POOL[user.id] = [sid]
|
USER_POOL[user.id] = [sid]
|
||||||
|
|
||||||
# print(f"user {user.name}({user.id}) connected with session ID {sid}")
|
# print(f"user {user.name}({user.id}) connected with session ID {sid}")
|
||||||
|
await sio.emit("user-count", {"count": len(USER_POOL.items())})
|
||||||
await sio.emit("user-count", {"count": len(set(USER_POOL))})
|
|
||||||
await sio.emit("usage", {"models": get_models_in_use()})
|
await sio.emit("usage", {"models": get_models_in_use()})
|
||||||
|
|
||||||
|
|
||||||
@ -91,65 +151,12 @@ async def user_join(sid, data):
|
|||||||
|
|
||||||
# print(f"user {user.name}({user.id}) connected with session ID {sid}")
|
# print(f"user {user.name}({user.id}) connected with session ID {sid}")
|
||||||
|
|
||||||
await sio.emit("user-count", {"count": len(set(USER_POOL))})
|
await sio.emit("user-count", {"count": len(USER_POOL.items())})
|
||||||
|
|
||||||
|
|
||||||
@sio.on("user-count")
|
@sio.on("user-count")
|
||||||
async def user_count(sid):
|
async def user_count(sid):
|
||||||
await sio.emit("user-count", {"count": len(set(USER_POOL))})
|
await sio.emit("user-count", {"count": len(USER_POOL.items())})
|
||||||
|
|
||||||
|
|
||||||
def get_models_in_use():
|
|
||||||
# Aggregate all models in use
|
|
||||||
models_in_use = []
|
|
||||||
for model_id, data in USAGE_POOL.items():
|
|
||||||
models_in_use.append(model_id)
|
|
||||||
|
|
||||||
return models_in_use
|
|
||||||
|
|
||||||
|
|
||||||
@sio.on("usage")
|
|
||||||
async def usage(sid, data):
|
|
||||||
model_id = data["model"]
|
|
||||||
|
|
||||||
# Cancel previous callback if there is one
|
|
||||||
if model_id in USAGE_POOL:
|
|
||||||
USAGE_POOL[model_id]["callback"].cancel()
|
|
||||||
|
|
||||||
# Store the new usage data and task
|
|
||||||
|
|
||||||
if model_id in USAGE_POOL:
|
|
||||||
USAGE_POOL[model_id]["sids"].append(sid)
|
|
||||||
USAGE_POOL[model_id]["sids"] = list(set(USAGE_POOL[model_id]["sids"]))
|
|
||||||
|
|
||||||
else:
|
|
||||||
USAGE_POOL[model_id] = {"sids": [sid]}
|
|
||||||
|
|
||||||
# Schedule a task to remove the usage data after TIMEOUT_DURATION
|
|
||||||
USAGE_POOL[model_id]["callback"] = asyncio.create_task(
|
|
||||||
remove_after_timeout(sid, model_id)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Broadcast the usage data to all clients
|
|
||||||
await sio.emit("usage", {"models": get_models_in_use()})
|
|
||||||
|
|
||||||
|
|
||||||
async def remove_after_timeout(sid, model_id):
|
|
||||||
try:
|
|
||||||
await asyncio.sleep(TIMEOUT_DURATION)
|
|
||||||
if model_id in USAGE_POOL:
|
|
||||||
# print(USAGE_POOL[model_id]["sids"])
|
|
||||||
USAGE_POOL[model_id]["sids"].remove(sid)
|
|
||||||
USAGE_POOL[model_id]["sids"] = list(set(USAGE_POOL[model_id]["sids"]))
|
|
||||||
|
|
||||||
if len(USAGE_POOL[model_id]["sids"]) == 0:
|
|
||||||
del USAGE_POOL[model_id]
|
|
||||||
|
|
||||||
# Broadcast the usage data to all clients
|
|
||||||
await sio.emit("usage", {"models": get_models_in_use()})
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
# Task was cancelled due to new 'usage' event
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@sio.event
|
@sio.event
|
||||||
@ -158,7 +165,7 @@ async def disconnect(sid):
|
|||||||
user_id = SESSION_POOL[sid]
|
user_id = SESSION_POOL[sid]
|
||||||
del SESSION_POOL[sid]
|
del SESSION_POOL[sid]
|
||||||
|
|
||||||
USER_POOL[user_id].remove(sid)
|
USER_POOL[user_id] = [_sid for _sid in USER_POOL[user_id] if _sid != sid]
|
||||||
|
|
||||||
if len(USER_POOL[user_id]) == 0:
|
if len(USER_POOL[user_id]) == 0:
|
||||||
del USER_POOL[user_id]
|
del USER_POOL[user_id]
|
||||||
|
59
backend/open_webui/apps/socket/utils.py
Normal file
59
backend/open_webui/apps/socket/utils.py
Normal file
@ -0,0 +1,59 @@
|
|||||||
|
import json
|
||||||
|
import redis
|
||||||
|
|
||||||
|
|
||||||
|
class RedisDict:
|
||||||
|
def __init__(self, name, redis_url):
|
||||||
|
self.name = name
|
||||||
|
self.redis = redis.Redis.from_url(redis_url, decode_responses=True)
|
||||||
|
|
||||||
|
def __setitem__(self, key, value):
|
||||||
|
serialized_value = json.dumps(value)
|
||||||
|
self.redis.hset(self.name, key, serialized_value)
|
||||||
|
|
||||||
|
def __getitem__(self, key):
|
||||||
|
value = self.redis.hget(self.name, key)
|
||||||
|
if value is None:
|
||||||
|
raise KeyError(key)
|
||||||
|
return json.loads(value)
|
||||||
|
|
||||||
|
def __delitem__(self, key):
|
||||||
|
result = self.redis.hdel(self.name, key)
|
||||||
|
if result == 0:
|
||||||
|
raise KeyError(key)
|
||||||
|
|
||||||
|
def __contains__(self, key):
|
||||||
|
return self.redis.hexists(self.name, key)
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return self.redis.hlen(self.name)
|
||||||
|
|
||||||
|
def keys(self):
|
||||||
|
return self.redis.hkeys(self.name)
|
||||||
|
|
||||||
|
def values(self):
|
||||||
|
return [json.loads(v) for v in self.redis.hvals(self.name)]
|
||||||
|
|
||||||
|
def items(self):
|
||||||
|
return [(k, json.loads(v)) for k, v in self.redis.hgetall(self.name).items()]
|
||||||
|
|
||||||
|
def get(self, key, default=None):
|
||||||
|
try:
|
||||||
|
return self[key]
|
||||||
|
except KeyError:
|
||||||
|
return default
|
||||||
|
|
||||||
|
def clear(self):
|
||||||
|
self.redis.delete(self.name)
|
||||||
|
|
||||||
|
def update(self, other=None, **kwargs):
|
||||||
|
if other is not None:
|
||||||
|
for k, v in other.items() if hasattr(other, "items") else other:
|
||||||
|
self[k] = v
|
||||||
|
for k, v in kwargs.items():
|
||||||
|
self[k] = v
|
||||||
|
|
||||||
|
def setdefault(self, key, default=None):
|
||||||
|
if key not in self:
|
||||||
|
self[key] = default
|
||||||
|
return self[key]
|
Loading…
Reference in New Issue
Block a user