From f1d21fc59a52bd7ca117f73411fc65231f02f6eb Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Sun, 22 Dec 2024 19:40:01 -0700 Subject: [PATCH] feat: channel socket integration --- backend/open_webui/models/channels.py | 14 +- backend/open_webui/models/messages.py | 4 +- backend/open_webui/routers/channels.py | 28 +- backend/open_webui/socket/main.py | 8 +- src/lib/apis/channels/index.ts | 71 +++++ src/lib/components/channel/Channel.svelte | 93 +++++- .../components/channel/MessageInput.svelte | 266 ++++++++++++++++++ src/lib/components/channel/Messages.svelte | 27 +- .../channel/Messages/Message.svelte | 13 + src/lib/components/chat/MessageInput.svelte | 2 +- src/lib/components/common/Image.svelte | 1 + src/lib/components/layout/Sidebar.svelte | 5 +- src/routes/+layout.svelte | 8 +- 13 files changed, 509 insertions(+), 31 deletions(-) create mode 100644 src/lib/components/channel/MessageInput.svelte diff --git a/backend/open_webui/models/channels.py b/backend/open_webui/models/channels.py index 0e31d5e8e..cc49953e7 100644 --- a/backend/open_webui/models/channels.py +++ b/backend/open_webui/models/channels.py @@ -4,8 +4,7 @@ import uuid from typing import Optional from open_webui.internal.db import Base, get_db -from open_webui.models.tags import TagModel, Tag, Tags - +from open_webui.utils.access_control import has_access from pydantic import BaseModel, ConfigDict from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON @@ -85,6 +84,17 @@ class ChannelTable: channels = db.query(Channel).all() return [ChannelModel.model_validate(channel) for channel in channels] + def get_channels_by_user_id( + self, user_id: str, permission: str = "read" + ) -> list[ChannelModel]: + channels = self.get_channels() + return [ + channel + for channel in channels + if channel.user_id == user_id + or has_access(user_id, permission, channel.access_control) + ] + def get_channel_by_id(self, id: str) -> Optional[ChannelModel]: with get_db() as db: channel = db.query(Channel).filter(Channel.id == id).first() diff --git a/backend/open_webui/models/messages.py b/backend/open_webui/models/messages.py index c9161da96..8e4306cd5 100644 --- a/backend/open_webui/models/messages.py +++ b/backend/open_webui/models/messages.py @@ -95,7 +95,7 @@ class MessageTable: all_messages = ( db.query(Message) .filter_by(channel_id=channel_id) - .order_by(Message.updated_at.desc()) + .order_by(Message.updated_at.asc()) .limit(limit) .offset(skip) .all() @@ -109,7 +109,7 @@ class MessageTable: all_messages = ( db.query(Message) .filter_by(user_id=user_id) - .order_by(Message.updated_at.desc()) + .order_by(Message.updated_at.asc()) .limit(limit) .offset(skip) .all() diff --git a/backend/open_webui/routers/channels.py b/backend/open_webui/routers/channels.py index b4b458f25..c73c601f7 100644 --- a/backend/open_webui/routers/channels.py +++ b/backend/open_webui/routers/channels.py @@ -2,6 +2,12 @@ import json import logging from typing import Optional + +from fastapi import APIRouter, Depends, HTTPException, Request, status +from pydantic import BaseModel + + +from open_webui.socket.main import sio from open_webui.models.channels import Channels, ChannelModel, ChannelForm from open_webui.models.messages import Messages, MessageModel, MessageForm @@ -9,12 +15,10 @@ from open_webui.models.messages import Messages, MessageModel, MessageForm from open_webui.config import ENABLE_ADMIN_CHAT_ACCESS, ENABLE_ADMIN_EXPORT from open_webui.constants import ERROR_MESSAGES from open_webui.env import SRC_LOG_LEVELS -from fastapi import APIRouter, Depends, HTTPException, Request, status -from pydantic import BaseModel from open_webui.utils.auth import get_admin_user, get_verified_user -from open_webui.utils.access_control import has_permission +from open_webui.utils.access_control import has_access log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["MODELS"]) @@ -53,7 +57,7 @@ async def create_new_channel(form_data: ChannelForm, user=Depends(get_admin_user ############################ -@router.post("/{id}/messages", response_model=list[MessageModel]) +@router.get("/{id}/messages", response_model=list[MessageModel]) async def get_channel_messages(id: str, page: int = 1, user=Depends(get_verified_user)): channel = Channels.get_channel_by_id(id) if not channel: @@ -61,7 +65,7 @@ async def get_channel_messages(id: str, page: int = 1, user=Depends(get_verified status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND ) - if not has_permission(channel.access_control, user): + if not has_access(user.id, type="read", access_control=channel.access_control): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() ) @@ -87,13 +91,25 @@ async def post_new_message( status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND ) - if not has_permission(channel.access_control, user): + if not has_access(user.id, type="read", access_control=channel.access_control): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() ) try: message = Messages.insert_new_message(form_data, channel.id, user.id) + + if message: + await sio.emit( + "channel-events", + { + "channel_id": channel.id, + "message_id": message.id, + "data": {"message": message.model_dump()}, + }, + to=f"channel:{channel.id}", + ) + return MessageModel(**message.model_dump()) except Exception as e: log.exception(e) diff --git a/backend/open_webui/socket/main.py b/backend/open_webui/socket/main.py index 965fb9396..23be163e9 100644 --- a/backend/open_webui/socket/main.py +++ b/backend/open_webui/socket/main.py @@ -5,6 +5,7 @@ import sys import time from open_webui.models.users import Users +from open_webui.models.channels import Channels from open_webui.env import ( ENABLE_WEBSOCKET_SUPPORT, WEBSOCKET_MANAGER, @@ -162,7 +163,6 @@ async def connect(sid, environ, auth): @sio.on("user-join") async def user_join(sid, data): - # print("user-join", sid, data) auth = data["auth"] if "auth" in data else None if not auth or "token" not in auth: @@ -182,6 +182,12 @@ async def user_join(sid, data): else: USER_POOL[user.id] = [sid] + # Join all the channels + channels = Channels.get_channels_by_user_id(user.id) + log.debug(f"{channels=}") + for channel in channels: + await sio.enter_room(sid, f"channel:{channel.id}") + # print(f"user {user.name}({user.id}) connected with session ID {sid}") await sio.emit("user-count", {"count": len(USER_POOL.items())}) diff --git a/src/lib/apis/channels/index.ts b/src/lib/apis/channels/index.ts index 8fd6f24f1..84f372fa0 100644 --- a/src/lib/apis/channels/index.ts +++ b/src/lib/apis/channels/index.ts @@ -69,3 +69,74 @@ export const getChannels = async (token: string = '') => { return res; }; + + +export const getChannelMessages = async (token: string = '', channel_id: string, page: number = 1) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/channels/${channel_id}/messages?page=${page}`, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err.detail; + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +} + +type MessageForm = { + content: string; + data?: object; + meta?: object; + +} + +export const sendMessage = async (token: string = '', channel_id: string, message: MessageForm) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/channels/${channel_id}/messages/post`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + }, + body: JSON.stringify({ ...message }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err.detail; + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +} \ No newline at end of file diff --git a/src/lib/components/channel/Channel.svelte b/src/lib/components/channel/Channel.svelte index 01e35b802..ca5cbda67 100644 --- a/src/lib/components/channel/Channel.svelte +++ b/src/lib/components/channel/Channel.svelte @@ -1,5 +1,96 @@ -{id} +
+ + +
+ +
+
diff --git a/src/lib/components/channel/MessageInput.svelte b/src/lib/components/channel/MessageInput.svelte new file mode 100644 index 000000000..be987f1cc --- /dev/null +++ b/src/lib/components/channel/MessageInput.svelte @@ -0,0 +1,266 @@ + + +
+
+
+ {#if recording} + { + recording = false; + + await tick(); + document.getElementById('chat-input')?.focus(); + }} + on:confirm={async (e) => { + const { text, filename } = e.detail; + content = `${content}${text} `; + recording = false; + + await tick(); + document.getElementById('chat-input')?.focus(); + }} + /> + {:else} +
{ + submitHandler(); + }} + > +
+
+
+ +
+ + {#if $settings?.richTextInput ?? true} +
+ 0 || + navigator.msMaxTouchPoints > 0 + )} + {placeholder} + largeTextAsFile={$settings?.largeTextAsFile ?? false} + on:keydown={async (e) => { + e = e.detail.event; + const isCtrlPressed = e.ctrlKey || e.metaKey; // metaKey is for Cmd key on Mac + if ( + !$mobile || + !( + 'ontouchstart' in window || + navigator.maxTouchPoints > 0 || + navigator.msMaxTouchPoints > 0 + ) + ) { + // Prevent Enter key from creating a new line + // Uses keyCode '13' for Enter key for chinese/japanese keyboards + if (e.keyCode === 13 && !e.shiftKey) { + e.preventDefault(); + } + + // Submit the content when Enter key is pressed + if (content !== '' && e.keyCode === 13 && !e.shiftKey) { + submitHandler(); + } + } + + if (e.key === 'Escape') { + console.log('Escape'); + } + }} + on:paste={async (e) => { + e = e.detail.event; + console.log(e); + }} + /> +
+ {:else} +