open-webui/backend/apps/socket/main.py

148 lines
4.2 KiB
Python
Raw Normal View History

2024-06-04 06:39:52 +00:00
import socketio
2024-06-04 18:13:43 +00:00
import asyncio
2024-06-04 06:39:52 +00:00
from apps.webui.models.users import Users
from utils.utils import decode_token
# Socket.IO server exposed as an ASGI app under /ws/socket.io.
# Per python-socketio, an empty cors_allowed_origins list disables its CORS handling.
sio = socketio.AsyncServer(async_mode="asgi", cors_allowed_origins=[])
app = socketio.ASGIApp(sio, socketio_path="/ws/socket.io")

# In-memory bookkeeping for connected sessions (plain module-level dicts).
SESSION_POOL = {}  # session id -> user id
USER_POOL = {}  # user id -> list of that user's session ids
USAGE_POOL = {}  # model id -> {"sids": [session ids], "callback": pending removal task}

# Seconds a session's model usage is kept after its last "usage" report.
TIMEOUT_DURATION = 3
2024-06-04 06:39:52 +00:00
@sio.event
async def connect(sid, environ, auth):
    """Handle a new Socket.IO connection.

    When ``auth`` carries a ``token`` that ``decode_token`` resolves to an
    ``id`` of a known user, the session is recorded in SESSION_POOL (sid ->
    user id) and USER_POOL (user id -> sids). The current user count and the
    list of models in use are then broadcast. Connections without a usable
    token are not tracked.
    """
    print("connect ", sid)

    # Guard clauses: bail out as soon as the session cannot be attributed.
    if not auth or "token" not in auth:
        return

    claims = decode_token(auth["token"])
    if claims is None or "id" not in claims:
        return

    user = Users.get_user_by_id(claims["id"])
    if not user:
        return

    SESSION_POOL[sid] = user.id
    USER_POOL.setdefault(user.id, []).append(sid)

    print(f"user {user.name}({user.id}) connected with session ID {sid}")

    # USER_POOL keys are already unique, so its length is the user count.
    online_users = len(USER_POOL)
    print(online_users)
    await sio.emit("user-count", {"count": online_users})
    await sio.emit("usage", {"models": get_models_in_use()})
2024-06-04 08:10:31 +00:00
2024-06-04 16:52:27 +00:00
@sio.on("user-join")
async def user_join(sid, data):
    """Attach an authenticated user to session ``sid`` and broadcast the user count.

    ``data`` is the client payload; only ``data["auth"]["token"]`` is read.
    The token is decoded with ``decode_token``. If it carries an ``id``, the
    user is looked up with ``Users.get_user_by_id``. Payloads without a usable
    token, and unknown users, are ignored.

    Registration is idempotent. ``connect`` may already have registered this
    same ``sid``, so the sid is never appended to USER_POOL twice. A duplicate
    would survive a single removal on disconnect and keep the user counted.
    """
    print("user-join", sid, data)

    auth = data.get("auth") if isinstance(data, dict) else None
    if not (isinstance(auth, dict) and "token" in auth):
        return

    # Separate name so the event payload ``data`` is not shadowed.
    token_data = decode_token(auth["token"])

    # Initialised up front: previously ``user`` was unbound (UnboundLocalError)
    # whenever the token failed to decode or carried no "id".
    user = None
    if token_data is not None and "id" in token_data:
        user = Users.get_user_by_id(token_data["id"])
    if not user:
        return

    # If this sid was bound to a different user before, detach it from that
    # user's session list so it does not linger there.
    previous_user_id = SESSION_POOL.get(sid)
    if previous_user_id is not None and previous_user_id != user.id:
        remaining = [s for s in USER_POOL.get(previous_user_id, []) if s != sid]
        if remaining:
            USER_POOL[previous_user_id] = remaining
        else:
            USER_POOL.pop(previous_user_id, None)

    SESSION_POOL[sid] = user.id
    user_sids = USER_POOL.setdefault(user.id, [])
    if sid not in user_sids:
        user_sids.append(sid)

    print(f"user {user.name}({user.id}) connected with session ID {sid}")

    # USER_POOL keys are unique user ids, so its length is the user count.
    print(len(USER_POOL))
    await sio.emit("user-count", {"count": len(USER_POOL)})
2024-06-04 08:10:31 +00:00
@sio.on("user-count")
async def user_count(sid):
    """Reply to a "user-count" request by emitting the number of users online.

    The count is the number of user ids in USER_POOL. ``sio.emit`` is called
    without a target, so the event goes to every connected client, not only
    the requesting ``sid``.
    """
    print("user-count", sid)
    online_users = len(USER_POOL)
    await sio.emit("user-count", {"count": online_users})
2024-06-04 06:39:52 +00:00
2024-06-04 18:13:43 +00:00
def get_models_in_use():
    """Return the ids of all models that currently have a USAGE_POOL entry.

    Returns:
        list: model ids, in USAGE_POOL insertion order (empty if none).
    """
    # Iterating a dict yields its keys; the per-model usage data is not needed.
    models_in_use = list(USAGE_POOL)
    print(f"Models in use: {models_in_use}")
    return models_in_use
@sio.on("usage")
async def usage(sid, data):
    """Record that session ``sid`` is currently using model ``data["model"]``.

    Each report (re)arms a per-session timer. If the session does not report
    again within TIMEOUT_DURATION seconds, ``remove_after_timeout`` drops it
    from the model's entry. The list of models in use is broadcast to all
    clients after every report.

    USAGE_POOL[model_id] layout:
        "sids":      session ids currently using the model (no duplicates)
        "callbacks": session id -> that session's pending removal task
        "callback":  the most recently scheduled removal task (still written
                     in case other code reads it)
    """
    print(f'Received "usage" event from {sid}: {data}')

    # The payload comes from the client; ignore anything without a model id.
    if not isinstance(data, dict) or "model" not in data:
        return
    model_id = data["model"]

    entry = USAGE_POOL.setdefault(model_id, {"sids": []})

    # Track timers per session and keep only those still pending.
    # Cancelling the single model-wide "callback" (the old behaviour) killed
    # another session's timer. That sid was then never removed, and the model
    # stayed "in use" forever.
    callbacks = {
        other_sid: task
        for other_sid, task in entry.get("callbacks", {}).items()
        if not task.done()
    }

    # Cancel only this session's previous timer, if any.
    previous = callbacks.pop(sid, None)
    if previous is not None:
        previous.cancel()

    if sid not in entry["sids"]:
        entry["sids"].append(sid)

    # Schedule a task to remove this session's usage after TIMEOUT_DURATION
    task = asyncio.create_task(remove_after_timeout(sid, model_id))
    callbacks[sid] = task
    entry["callbacks"] = callbacks
    entry["callback"] = task

    # Broadcast the usage data to all clients
    await sio.emit("usage", {"models": get_models_in_use()})
2024-06-04 18:13:43 +00:00
async def remove_after_timeout(sid, model_id):
    """After TIMEOUT_DURATION seconds, drop ``sid`` from ``model_id``'s usage entry.

    This coroutine is scheduled as a background task by the "usage" handler.
    When the last session id is removed, the model's entry is deleted from
    USAGE_POOL. The current list of models in use is then broadcast to all
    clients.

    The task is cancelled when a newer "usage" report supersedes it.
    Cancellation is swallowed on purpose because there is nothing to undo.
    """
    try:
        print("remove_after_timeout", sid, model_id)
        await asyncio.sleep(TIMEOUT_DURATION)

        entry = USAGE_POOL.get(model_id)
        if entry is None:
            return

        print(entry["sids"])
        # Filter instead of list.remove(). It removes every occurrence and
        # cannot raise ValueError for a sid that is already gone. Such an
        # error would escape the CancelledError-only handler below and be
        # lost inside this fire-and-forget task.
        entry["sids"] = [s for s in entry["sids"] if s != sid]
        if not entry["sids"]:
            del USAGE_POOL[model_id]

        print(f"Removed usage data for {model_id} due to timeout")
        # Broadcast the usage data to all clients
        await sio.emit("usage", {"models": get_models_in_use()})
    except asyncio.CancelledError:
        # Task was cancelled due to new 'usage' event
        pass
2024-06-04 06:39:52 +00:00
@sio.event
async def disconnect(sid):
    """Forget session ``sid`` and broadcast the updated user count.

    SESSION_POOL maps session id -> user id. USER_POOL maps user id -> list of
    that user's session ids. A user is removed from USER_POOL once their last
    session disconnects.
    """
    # The membership test must use SESSION_POOL. USER_POOL is keyed by user id,
    # so the previous `sid in USER_POOL` check never matched: sessions were
    # never released and the user count only ever grew.
    if sid in SESSION_POOL:
        disconnected_user = SESSION_POOL.pop(sid)

        # Drop every occurrence of sid (the same session may have registered
        # more than once), and tolerate a user entry that is already gone.
        remaining = [s for s in USER_POOL.get(disconnected_user, []) if s != sid]
        if remaining:
            USER_POOL[disconnected_user] = remaining
        else:
            USER_POOL.pop(disconnected_user, None)

        print(f"user {disconnected_user} disconnected with session ID {sid}")

        await sio.emit("user-count", {"count": len(USER_POOL)})
    else:
        print(f"Unknown session ID {sid} disconnected")