From 85484392b2a646dca66926f7181d5e98c5d0d59a Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Mon, 3 Jun 2024 23:39:52 -0700 Subject: [PATCH] feat: websocket --- backend/apps/socket/main.py | 47 ++++++++++++++++++++++ backend/main.py | 5 +++ package-lock.json | 80 +++++++++++++++++++++++++++++++++++++ package.json | 1 + src/lib/constants.ts | 3 +- src/lib/stores/index.ts | 3 ++ src/routes/+layout.svelte | 18 +++++++-- 7 files changed, 153 insertions(+), 4 deletions(-) create mode 100644 backend/apps/socket/main.py diff --git a/backend/apps/socket/main.py b/backend/apps/socket/main.py new file mode 100644 index 000000000..0cf0b5fc1 --- /dev/null +++ b/backend/apps/socket/main.py @@ -0,0 +1,47 @@ +import socketio + +from apps.webui.models.users import Users +from utils.utils import decode_token + +sio = socketio.AsyncServer(cors_allowed_origins=[], async_mode="asgi") +app = socketio.ASGIApp(sio, socketio_path="/ws/socket.io") + +# Dictionary to maintain the user pool +USER_POOL = {} + + +@sio.event +async def connect(sid, environ, auth): + print("connect ", sid) + + user = None + data = decode_token(auth["token"]) + + if data is not None and "id" in data: + user = Users.get_user_by_id(data["id"]) + + if user: + USER_POOL[sid] = { + "id": user.id, + "name": user.name, + "email": user.email, + "role": user.role, + } + print(f"user {user.name}({user.id}) connected with session ID {sid}") + else: + print("Authentication failed. Disconnecting.") + await sio.disconnect(sid) + + +@sio.event +def disconnect(sid): + if sid in USER_POOL: + disconnected_user = USER_POOL.pop(sid) + print(f"user {disconnected_user} disconnected with session ID {sid}") + else: + print(f"Unknown session ID {sid} disconnected") + + +@sio.event +def disconnect(sid): + print("disconnect", sid) diff --git a/backend/main.py b/backend/main.py index 4e9d1adf9..4ab13e98f 100644 --- a/backend/main.py +++ b/backend/main.py @@ -20,6 +20,8 @@ from starlette.exceptions import HTTPException as StarletteHTTPException from starlette.middleware.base import BaseHTTPMiddleware from starlette.responses import StreamingResponse, Response + +from apps.socket.main import app as socket_app from apps.ollama.main import app as ollama_app, get_all_models as get_ollama_models from apps.openai.main import app as openai_app, get_all_models as get_openai_models @@ -376,6 +378,9 @@ async def update_embedding_function(request: Request, call_next): return response +app.mount("/ws", socket_app) + + app.mount("/ollama", ollama_app) app.mount("/openai", openai_app) diff --git a/package-lock.json b/package-lock.json index 097133321..3066e6581 100644 --- a/package-lock.json +++ b/package-lock.json @@ -25,6 +25,7 @@ "marked": "^9.1.0", "mermaid": "^10.9.1", "pyodide": "^0.26.0-alpha.4", + "socket.io-client": "^4.7.5", "sortablejs": "^1.15.2", "svelte-sonner": "^0.3.19", "tippy.js": "^6.3.7", @@ -1214,6 +1215,11 @@ "integrity": "sha512-+Fj43pSMwJs4KRrH/938Uf+uAELIgVBmQzg/q1YG10djyfA3TnrU8N8XzqCh/okZdszqBQTZf96idMfE5lnwTA==", "dev": true }, + "node_modules/@socket.io/component-emitter": { + "version": "3.1.2", + "resolved": "https://registry.npmjs.org/@socket.io/component-emitter/-/component-emitter-3.1.2.tgz", + "integrity": "sha512-9BCxFwvbGg/RsZK9tjXd8s4UcwR0MWeFQ1XEKIQVVvAGJyINdrqKMcTRyLoK8Rse1GjzLV9cwjWV1olXRWEXVA==" + }, "node_modules/@sveltejs/adapter-auto": { "version": "2.1.1", "resolved": "https://registry.npmjs.org/@sveltejs/adapter-auto/-/adapter-auto-2.1.1.tgz", @@ -3800,6 +3806,46 @@ "once": "^1.4.0" } }, + "node_modules/engine.io-client": { + "version": "6.5.3", + "resolved": "https://registry.npmjs.org/engine.io-client/-/engine.io-client-6.5.3.tgz", + "integrity": "sha512-9Z0qLB0NIisTRt1DZ/8U2k12RJn8yls/nXMZLn+/N8hANT3TcYjKFKcwbw5zFQiN4NTde3TSY9zb79e1ij6j9Q==", + "dependencies": { + "@socket.io/component-emitter": "~3.1.0", + "debug": "~4.3.1", + "engine.io-parser": "~5.2.1", + "ws": "~8.11.0", + "xmlhttprequest-ssl": "~2.0.0" + } + }, + "node_modules/engine.io-client/node_modules/ws": { + "version": "8.11.0", + "resolved": "https://registry.npmjs.org/ws/-/ws-8.11.0.tgz", + "integrity": "sha512-HPG3wQd9sNQoT9xHyNCXoDUa+Xw/VevmY9FoHyQ+g+rrMn4j6FB4np7Z0OhdTgjx6MgQLK7jwSy1YecU1+4Asg==", + "engines": { + "node": ">=10.0.0" + }, + "peerDependencies": { + "bufferutil": "^4.0.1", + "utf-8-validate": "^5.0.2" + }, + "peerDependenciesMeta": { + "bufferutil": { + "optional": true + }, + "utf-8-validate": { + "optional": true + } + } + }, + "node_modules/engine.io-parser": { + "version": "5.2.2", + "resolved": "https://registry.npmjs.org/engine.io-parser/-/engine.io-parser-5.2.2.tgz", + "integrity": "sha512-RcyUFKA93/CXH20l4SoVvzZfrSDMOTUS3bWVpTt2FuFP+XYrL8i8oonHP7WInRyVHXh0n/ORtoeiE1os+8qkSw==", + "engines": { + "node": ">=10.0.0" + } + }, "node_modules/enquirer": { "version": "2.4.1", "resolved": "https://registry.npmjs.org/enquirer/-/enquirer-2.4.1.tgz", @@ -7949,6 +7995,32 @@ "node": ">=8" } }, + "node_modules/socket.io-client": { + "version": "4.7.5", + "resolved": "https://registry.npmjs.org/socket.io-client/-/socket.io-client-4.7.5.tgz", + "integrity": "sha512-sJ/tqHOCe7Z50JCBCXrsY3I2k03iOiUe+tj1OmKeD2lXPiGH/RUCdTZFoqVyN7l1MnpIzPrGtLcijffmeouNlQ==", + "dependencies": { + "@socket.io/component-emitter": "~3.1.0", + "debug": "~4.3.2", + "engine.io-client": "~6.5.2", + "socket.io-parser": "~4.2.4" + }, + "engines": { + "node": ">=10.0.0" + } + }, + "node_modules/socket.io-parser": { + "version": "4.2.4", + "resolved": "https://registry.npmjs.org/socket.io-parser/-/socket.io-parser-4.2.4.tgz", + "integrity": "sha512-/GbIKmo8ioc+NIWIhwdecY0ge+qVBSMdgxGygevmdHj24bsfgtCmcUUcQ5ZzcylGFHsN3k4HB4Cgkl96KVnuew==", + "dependencies": { + "@socket.io/component-emitter": "~3.1.0", + "debug": "~4.3.1" + }, + "engines": { + "node": ">=10.0.0" + } + }, "node_modules/sorcery": { "version": "0.11.0", "resolved": "https://registry.npmjs.org/sorcery/-/sorcery-0.11.0.tgz", @@ -10142,6 +10214,14 @@ } } }, + "node_modules/xmlhttprequest-ssl": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/xmlhttprequest-ssl/-/xmlhttprequest-ssl-2.0.0.tgz", + "integrity": "sha512-QKxVRxiRACQcVuQEYFsI1hhkrMlrXHPegbbd1yn9UHOmRxY+si12nQYzri3vbzt8VdTTRviqcKxcyllFas5z2A==", + "engines": { + "node": ">=0.4.0" + } + }, "node_modules/xtend": { "version": "4.0.2", "resolved": "https://registry.npmjs.org/xtend/-/xtend-4.0.2.tgz", diff --git a/package.json b/package.json index a54d75d8a..a13e8f96b 100644 --- a/package.json +++ b/package.json @@ -65,6 +65,7 @@ "marked": "^9.1.0", "mermaid": "^10.9.1", "pyodide": "^0.26.0-alpha.4", + "socket.io-client": "^4.7.5", "sortablejs": "^1.15.2", "svelte-sonner": "^0.3.19", "tippy.js": "^6.3.7", diff --git a/src/lib/constants.ts b/src/lib/constants.ts index 5ae74448f..163309802 100644 --- a/src/lib/constants.ts +++ b/src/lib/constants.ts @@ -2,8 +2,9 @@ import { browser, dev } from '$app/environment'; // import { version } from '../../package.json'; export const APP_NAME = 'Open WebUI'; -export const WEBUI_BASE_URL = browser ? (dev ? `http://${location.hostname}:8080` : ``) : ``; +export const WEBUI_HOSTNAME = browser ? (dev ? `${location.hostname}:8080` : ``) : ''; +export const WEBUI_BASE_URL = browser ? (dev ? `http://${WEBUI_HOSTNAME}` : ``) : ``; export const WEBUI_API_BASE_URL = `${WEBUI_BASE_URL}/api/v1`; export const OLLAMA_API_BASE_URL = `${WEBUI_BASE_URL}/ollama`; diff --git a/src/lib/stores/index.ts b/src/lib/stores/index.ts index 5616045f6..619b175d9 100644 --- a/src/lib/stores/index.ts +++ b/src/lib/stores/index.ts @@ -2,6 +2,7 @@ import { APP_NAME } from '$lib/constants'; import { type Writable, writable } from 'svelte/store'; import type { GlobalModelConfig, ModelConfig } from '$lib/apis'; import type { Banner } from '$lib/types'; +import type { Socket } from 'socket.io-client'; // Backend export const WEBUI_NAME = writable(APP_NAME); @@ -13,6 +14,8 @@ export const MODEL_DOWNLOAD_POOL = writable({}); export const mobile = writable(false); +export const socket: Writable = writable(null); + export const theme = writable('system'); export const chatId = writable(''); diff --git a/src/routes/+layout.svelte b/src/routes/+layout.svelte index c0ede634f..ccccccec3 100644 --- a/src/routes/+layout.svelte +++ b/src/routes/+layout.svelte @@ -1,6 +1,8 @@