feat: channel socket integration

This commit is contained in:
Timothy Jaeryang Baek 2024-12-22 19:40:01 -07:00
parent eaecd15e69
commit f1d21fc59a
13 changed files with 509 additions and 31 deletions

View File

@ -4,8 +4,7 @@ import uuid
from typing import Optional from typing import Optional
from open_webui.internal.db import Base, get_db from open_webui.internal.db import Base, get_db
from open_webui.models.tags import TagModel, Tag, Tags from open_webui.utils.access_control import has_access
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON
@ -85,6 +84,17 @@ class ChannelTable:
channels = db.query(Channel).all() channels = db.query(Channel).all()
return [ChannelModel.model_validate(channel) for channel in channels] return [ChannelModel.model_validate(channel) for channel in channels]
def get_channels_by_user_id(
self, user_id: str, permission: str = "read"
) -> list[ChannelModel]:
channels = self.get_channels()
return [
channel
for channel in channels
if channel.user_id == user_id
or has_access(user_id, permission, channel.access_control)
]
def get_channel_by_id(self, id: str) -> Optional[ChannelModel]: def get_channel_by_id(self, id: str) -> Optional[ChannelModel]:
with get_db() as db: with get_db() as db:
channel = db.query(Channel).filter(Channel.id == id).first() channel = db.query(Channel).filter(Channel.id == id).first()

View File

@ -95,7 +95,7 @@ class MessageTable:
all_messages = ( all_messages = (
db.query(Message) db.query(Message)
.filter_by(channel_id=channel_id) .filter_by(channel_id=channel_id)
.order_by(Message.updated_at.desc()) .order_by(Message.updated_at.asc())
.limit(limit) .limit(limit)
.offset(skip) .offset(skip)
.all() .all()
@ -109,7 +109,7 @@ class MessageTable:
all_messages = ( all_messages = (
db.query(Message) db.query(Message)
.filter_by(user_id=user_id) .filter_by(user_id=user_id)
.order_by(Message.updated_at.desc()) .order_by(Message.updated_at.asc())
.limit(limit) .limit(limit)
.offset(skip) .offset(skip)
.all() .all()

View File

@ -2,6 +2,12 @@ import json
import logging import logging
from typing import Optional from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Request, status
from pydantic import BaseModel
from open_webui.socket.main import sio
from open_webui.models.channels import Channels, ChannelModel, ChannelForm from open_webui.models.channels import Channels, ChannelModel, ChannelForm
from open_webui.models.messages import Messages, MessageModel, MessageForm from open_webui.models.messages import Messages, MessageModel, MessageForm
@ -9,12 +15,10 @@ from open_webui.models.messages import Messages, MessageModel, MessageForm
from open_webui.config import ENABLE_ADMIN_CHAT_ACCESS, ENABLE_ADMIN_EXPORT from open_webui.config import ENABLE_ADMIN_CHAT_ACCESS, ENABLE_ADMIN_EXPORT
from open_webui.constants import ERROR_MESSAGES from open_webui.constants import ERROR_MESSAGES
from open_webui.env import SRC_LOG_LEVELS from open_webui.env import SRC_LOG_LEVELS
from fastapi import APIRouter, Depends, HTTPException, Request, status
from pydantic import BaseModel
from open_webui.utils.auth import get_admin_user, get_verified_user from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.utils.access_control import has_permission from open_webui.utils.access_control import has_access
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"]) log.setLevel(SRC_LOG_LEVELS["MODELS"])
@ -53,7 +57,7 @@ async def create_new_channel(form_data: ChannelForm, user=Depends(get_admin_user
############################ ############################
@router.post("/{id}/messages", response_model=list[MessageModel]) @router.get("/{id}/messages", response_model=list[MessageModel])
async def get_channel_messages(id: str, page: int = 1, user=Depends(get_verified_user)): async def get_channel_messages(id: str, page: int = 1, user=Depends(get_verified_user)):
channel = Channels.get_channel_by_id(id) channel = Channels.get_channel_by_id(id)
if not channel: if not channel:
@ -61,7 +65,7 @@ async def get_channel_messages(id: str, page: int = 1, user=Depends(get_verified
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
) )
if not has_permission(channel.access_control, user): if not has_access(user.id, type="read", access_control=channel.access_control):
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
) )
@ -87,13 +91,25 @@ async def post_new_message(
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
) )
if not has_permission(channel.access_control, user): if not has_access(user.id, type="read", access_control=channel.access_control):
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
) )
try: try:
message = Messages.insert_new_message(form_data, channel.id, user.id) message = Messages.insert_new_message(form_data, channel.id, user.id)
if message:
await sio.emit(
"channel-events",
{
"channel_id": channel.id,
"message_id": message.id,
"data": {"message": message.model_dump()},
},
to=f"channel:{channel.id}",
)
return MessageModel(**message.model_dump()) return MessageModel(**message.model_dump())
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)

View File

@ -5,6 +5,7 @@ import sys
import time import time
from open_webui.models.users import Users from open_webui.models.users import Users
from open_webui.models.channels import Channels
from open_webui.env import ( from open_webui.env import (
ENABLE_WEBSOCKET_SUPPORT, ENABLE_WEBSOCKET_SUPPORT,
WEBSOCKET_MANAGER, WEBSOCKET_MANAGER,
@ -162,7 +163,6 @@ async def connect(sid, environ, auth):
@sio.on("user-join") @sio.on("user-join")
async def user_join(sid, data): async def user_join(sid, data):
# print("user-join", sid, data)
auth = data["auth"] if "auth" in data else None auth = data["auth"] if "auth" in data else None
if not auth or "token" not in auth: if not auth or "token" not in auth:
@ -182,6 +182,12 @@ async def user_join(sid, data):
else: else:
USER_POOL[user.id] = [sid] USER_POOL[user.id] = [sid]
# Join all the channels
channels = Channels.get_channels_by_user_id(user.id)
log.debug(f"{channels=}")
for channel in channels:
await sio.enter_room(sid, f"channel:{channel.id}")
# print(f"user {user.name}({user.id}) connected with session ID {sid}") # print(f"user {user.name}({user.id}) connected with session ID {sid}")
await sio.emit("user-count", {"count": len(USER_POOL.items())}) await sio.emit("user-count", {"count": len(USER_POOL.items())})

View File

@ -69,3 +69,74 @@ export const getChannels = async (token: string = '') => {
return res; return res;
}; };
export const getChannelMessages = async (token: string = '', channel_id: string, page: number = 1) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/channels/${channel_id}/messages?page=${page}`, {
method: 'GET',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
}
type MessageForm = {
content: string;
data?: object;
meta?: object;
}
export const sendMessage = async (token: string = '', channel_id: string, message: MessageForm) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/channels/${channel_id}/messages/post`, {
method: 'POST',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
},
body: JSON.stringify({ ...message })
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
}

View File

@ -1,5 +1,96 @@
<script lang="ts"> <script lang="ts">
import { getChannelMessages, sendMessage } from '$lib/apis/channels';
import { toast } from 'svelte-sonner';
import MessageInput from './MessageInput.svelte';
import Messages from './Messages.svelte';
import { socket } from '$lib/stores';
import { onDestroy, onMount, tick } from 'svelte';
export let id = ''; export let id = '';
let scrollEnd = true;
let messagesContainerElement = null;
let top = false;
let page = 1;
let messages = null;
$: if (id) {
initHandler();
}
const initHandler = async () => {
top = false;
page = 1;
messages = null;
messages = await getChannelMessages(localStorage.token, id, page);
if (messages.length < 50) {
top = true;
}
};
const channelEventHandler = async (data) => {
console.log(data);
};
const submitHandler = async ({ content }) => {
if (!content) {
return;
}
const res = await sendMessage(localStorage.token, id, { content: content }).catch((error) => {
toast.error(error);
return null;
});
if (res) {
messagesContainerElement.scrollTop = messagesContainerElement.scrollHeight;
}
};
onMount(() => {
$socket?.on('channel-events', channelEventHandler);
});
onDestroy(() => {
$socket?.off('channel-events', channelEventHandler);
});
</script> </script>
{id} <div class="h-full md:max-w-[calc(100%-260px)] w-full max-w-full flex flex-col">
<div
class=" pb-2.5 flex flex-col justify-between w-full flex-auto overflow-auto h-0 max-w-full z-10 scrollbar-hidden"
id="messages-container"
bind:this={messagesContainerElement}
on:scroll={(e) => {
scrollEnd =
messagesContainerElement.scrollHeight - messagesContainerElement.scrollTop <=
messagesContainerElement.clientHeight + 5;
}}
>
{#key id}
<Messages
{messages}
onLoad={async () => {
page += 1;
const newMessages = await getChannelMessages(localStorage.token, id, page);
if (newMessages.length === 0) {
top = true;
return;
}
messages = [...newMessages, ...messages];
}}
/>
{/key}
</div>
<div class=" pb-[1rem]">
<MessageInput onSubmit={submitHandler} />
</div>
</div>

View File

@ -0,0 +1,266 @@
<script lang="ts">
import { toast } from 'svelte-sonner';
import { tick, getContext } from 'svelte';
const i18n = getContext('i18n');
import { mobile, settings } from '$lib/stores';
import Tooltip from '../common/Tooltip.svelte';
import RichTextInput from '../common/RichTextInput.svelte';
import VoiceRecording from '../chat/MessageInput/VoiceRecording.svelte';
export let placeholder = $i18n.t('Send a Message');
export let transparentBackground = false;
let recording = false;
let content = '';
export let onSubmit: Function;
let submitHandler = async () => {
if (content === '') {
return;
}
onSubmit({
content
});
content = '';
await tick();
const chatInputElement = document.getElementById('chat-input');
chatInputElement?.focus();
};
</script>
<div class="{transparentBackground ? 'bg-transparent' : 'bg-white dark:bg-gray-900'} ">
<div class="max-w-6xl px-2.5 mx-auto inset-x-0">
<div class="">
{#if recording}
<VoiceRecording
bind:recording
on:cancel={async () => {
recording = false;
await tick();
document.getElementById('chat-input')?.focus();
}}
on:confirm={async (e) => {
const { text, filename } = e.detail;
content = `${content}${text} `;
recording = false;
await tick();
document.getElementById('chat-input')?.focus();
}}
/>
{:else}
<form
class="w-full flex gap-1.5"
on:submit|preventDefault={() => {
submitHandler();
}}
>
<div
class="flex-1 flex flex-col relative w-full rounded-3xl px-1 bg-gray-50 dark:bg-gray-400/5 dark:text-gray-100"
dir={$settings?.chatDirection ?? 'LTR'}
>
<div class=" flex">
<div class="ml-1 self-end mb-1.5 flex space-x-1">
<button
class="bg-transparent hover:bg-white/80 text-gray-800 dark:text-white dark:hover:bg-gray-800 transition rounded-full p-2 outline-none focus:outline-none"
type="button"
aria-label="More"
>
<svg
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 20 20"
fill="currentColor"
class="size-5"
>
<path
d="M10.75 4.75a.75.75 0 0 0-1.5 0v4.5h-4.5a.75.75 0 0 0 0 1.5h4.5v4.5a.75.75 0 0 0 1.5 0v-4.5h4.5a.75.75 0 0 0 0-1.5h-4.5v-4.5Z"
/>
</svg>
</button>
</div>
{#if $settings?.richTextInput ?? true}
<div
class="scrollbar-hidden text-left bg-transparent dark:text-gray-100 outline-none w-full py-2.5 px-1 rounded-xl resize-none h-fit max-h-80 overflow-auto"
>
<RichTextInput
bind:value={content}
id="chat-input"
messageInput={true}
shiftEnter={!$mobile ||
!(
'ontouchstart' in window ||
navigator.maxTouchPoints > 0 ||
navigator.msMaxTouchPoints > 0
)}
{placeholder}
largeTextAsFile={$settings?.largeTextAsFile ?? false}
on:keydown={async (e) => {
e = e.detail.event;
const isCtrlPressed = e.ctrlKey || e.metaKey; // metaKey is for Cmd key on Mac
if (
!$mobile ||
!(
'ontouchstart' in window ||
navigator.maxTouchPoints > 0 ||
navigator.msMaxTouchPoints > 0
)
) {
// Prevent Enter key from creating a new line
// Uses keyCode '13' for Enter key for chinese/japanese keyboards
if (e.keyCode === 13 && !e.shiftKey) {
e.preventDefault();
}
// Submit the content when Enter key is pressed
if (content !== '' && e.keyCode === 13 && !e.shiftKey) {
submitHandler();
}
}
if (e.key === 'Escape') {
console.log('Escape');
}
}}
on:paste={async (e) => {
e = e.detail.event;
console.log(e);
}}
/>
</div>
{:else}
<textarea
id="chat-input"
class="scrollbar-hidden bg-transparent dark:text-gray-100 outline-none w-full py-3 px-1 rounded-xl resize-none h-[48px]"
{placeholder}
bind:value={content}
on:keydown={async (e) => {
e = e.detail.event;
const isCtrlPressed = e.ctrlKey || e.metaKey; // metaKey is for Cmd key on Mac
if (
!$mobile ||
!(
'ontouchstart' in window ||
navigator.maxTouchPoints > 0 ||
navigator.msMaxTouchPoints > 0
)
) {
// Prevent Enter key from creating a new line
// Uses keyCode '13' for Enter key for chinese/japanese keyboards
if (e.keyCode === 13 && !e.shiftKey) {
e.preventDefault();
}
// Submit the content when Enter key is pressed
if (content !== '' && e.keyCode === 13 && !e.shiftKey) {
submitHandler();
}
}
if (e.key === 'Escape') {
console.log('Escape');
}
}}
rows="1"
on:input={async (e) => {
e.target.style.height = '';
e.target.style.height = Math.min(e.target.scrollHeight, 320) + 'px';
}}
on:focus={async (e) => {
e.target.style.height = '';
e.target.style.height = Math.min(e.target.scrollHeight, 320) + 'px';
}}
/>
{/if}
<div class="self-end mb-1.5 flex space-x-1 mr-1">
{#if content === ''}
<Tooltip content={$i18n.t('Record voice')}>
<button
id="voice-input-button"
class=" text-gray-600 dark:text-gray-300 hover:text-gray-700 dark:hover:text-gray-200 transition rounded-full p-1.5 mr-0.5 self-center"
type="button"
on:click={async () => {
try {
let stream = await navigator.mediaDevices
.getUserMedia({ audio: true })
.catch(function (err) {
toast.error(
$i18n.t(`Permission denied when accessing microphone: {{error}}`, {
error: err
})
);
return null;
});
if (stream) {
recording = true;
const tracks = stream.getTracks();
tracks.forEach((track) => track.stop());
}
stream = null;
} catch {
toast.error($i18n.t('Permission denied when accessing microphone'));
}
}}
aria-label="Voice Input"
>
<svg
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 20 20"
fill="currentColor"
class="w-5 h-5 translate-y-[0.5px]"
>
<path d="M7 4a3 3 0 016 0v6a3 3 0 11-6 0V4z" />
<path
d="M5.5 9.643a.75.75 0 00-1.5 0V10c0 3.06 2.29 5.585 5.25 5.954V17.5h-1.5a.75.75 0 000 1.5h4.5a.75.75 0 000-1.5h-1.5v-1.546A6.001 6.001 0 0016 10v-.357a.75.75 0 00-1.5 0V10a4.5 4.5 0 01-9 0v-.357z"
/>
</svg>
</button>
</Tooltip>
{/if}
<div class=" flex items-center">
<div class=" flex items-center">
<Tooltip content={$i18n.t('Send message')}>
<button
id="send-message-button"
class="{content !== ''
? 'bg-black text-white hover:bg-gray-900 dark:bg-white dark:text-black dark:hover:bg-gray-100 '
: 'text-white bg-gray-200 dark:text-gray-900 dark:bg-gray-700 disabled'} transition rounded-full p-1.5 self-center"
type="submit"
disabled={content === ''}
>
<svg
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 16 16"
fill="currentColor"
class="size-6"
>
<path
fill-rule="evenodd"
d="M8 14a.75.75 0 0 1-.75-.75V4.56L4.03 7.78a.75.75 0 0 1-1.06-1.06l4.5-4.5a.75.75 0 0 1 1.06 0l4.5 4.5a.75.75 0 0 1-1.06 1.06L8.75 4.56v8.69A.75.75 0 0 1 8 14Z"
clip-rule="evenodd"
/>
</svg>
</button>
</Tooltip>
</div>
</div>
</div>
</div>
</div>
</form>
{/if}
</div>
</div>
</div>

View File

@ -9,14 +9,15 @@
import Message from './Messages/Message.svelte'; import Message from './Messages/Message.svelte';
import Loader from '../common/Loader.svelte'; import Loader from '../common/Loader.svelte';
import Spinner from '../common/Spinner.svelte'; import Spinner from '../common/Spinner.svelte';
import { getChannelMessages } from '$lib/apis/channels';
const i18n = getContext('i18n'); const i18n = getContext('i18n');
export let channelId; export let messages = [];
export let top = false;
let messages = null; export let onLoad: Function = () => {};
let messagesCount = 50;
let messagesLoading = false; let messagesLoading = false;
const loadMoreMessages = async () => { const loadMoreMessages = async () => {
@ -25,19 +26,19 @@
element.scrollTop = element.scrollTop + 100; element.scrollTop = element.scrollTop + 100;
messagesLoading = true; messagesLoading = true;
messagesCount += 50;
await onLoad();
await tick(); await tick();
messagesLoading = false; messagesLoading = false;
}; };
</script> </script>
<div class="h-full flex pt-8"> {#if messages}
<div class="w-full pt-2"> <div class="h-full w-full flex-1 flex">
{#key channelId} <div class="w-full pt-2">
<div class="w-full"> <div class="w-full">
{#if messages.at(0)?.parentId !== null} {#if !top}
<Loader <Loader
on:visible={(e) => { on:visible={(e) => {
console.log('visible'); console.log('visible');
@ -54,10 +55,10 @@
{/if} {/if}
{#each messages as message, messageIdx (message.id)} {#each messages as message, messageIdx (message.id)}
<Message {channelId} id={message.id} content={message.content} /> <Message {message} />
{/each} {/each}
</div> </div>
<div class="pb-12" /> <div class="pb-6" />
{/key} </div>
</div> </div>
</div> {/if}

View File

@ -0,0 +1,13 @@
<script lang="ts">
import Markdown from '$lib/components/chat/Messages/Markdown.svelte';
export let message;
</script>
{#if message}
<div>
<div>
<Markdown id={message.id} content={message.content} />
</div>
</div>
{/if}

View File

@ -49,7 +49,7 @@
export let autoScroll = false; export let autoScroll = false;
export let atSelectedModel: Model | undefined; export let atSelectedModel: Model | undefined = undefined;
export let selectedModels: ['']; export let selectedModels: [''];
let selectedModelIds = []; let selectedModelIds = [];

View File

@ -19,6 +19,7 @@
on:click={() => { on:click={() => {
showImagePreview = true; showImagePreview = true;
}} }}
type="button"
> >
<img src={_src} {alt} class={imageClassName} draggable="false" data-cy="image" /> <img src={_src} {alt} class={imageClassName} draggable="false" data-cy="image" />
</button> </button>

View File

@ -17,7 +17,8 @@
scrollPaginationEnabled, scrollPaginationEnabled,
currentChatPage, currentChatPage,
temporaryChatEnabled, temporaryChatEnabled,
channels channels,
socket
} from '$lib/stores'; } from '$lib/stores';
import { onMount, getContext, tick, onDestroy } from 'svelte'; import { onMount, getContext, tick, onDestroy } from 'svelte';
@ -151,7 +152,7 @@
}; };
const initChannels = async () => { const initChannels = async () => {
channels.set(await getChannels(localStorage.token)); await channels.set(await getChannels(localStorage.token));
}; };
const initChatList = async () => { const initChatList = async () => {

View File

@ -38,7 +38,7 @@
let loaded = false; let loaded = false;
const BREAKPOINT = 768; const BREAKPOINT = 768;
const setupSocket = (enableWebsocket) => { const setupSocket = async (enableWebsocket) => {
const _socket = io(`${WEBUI_BASE_URL}` || undefined, { const _socket = io(`${WEBUI_BASE_URL}` || undefined, {
reconnection: true, reconnection: true,
reconnectionDelay: 1000, reconnectionDelay: 1000,
@ -49,7 +49,7 @@
auth: { token: localStorage.token } auth: { token: localStorage.token }
}); });
socket.set(_socket); await socket.set(_socket);
_socket.on('connect_error', (err) => { _socket.on('connect_error', (err) => {
console.log('connect_error', err); console.log('connect_error', err);
@ -127,7 +127,7 @@
await WEBUI_NAME.set(backendConfig.name); await WEBUI_NAME.set(backendConfig.name);
if ($config) { if ($config) {
setupSocket($config.features?.enable_websocket ?? true); await setupSocket($config.features?.enable_websocket ?? true);
if (localStorage.token) { if (localStorage.token) {
// Get Session User Info // Get Session User Info
@ -138,6 +138,8 @@
if (sessionUser) { if (sessionUser) {
// Save Session User to Store // Save Session User to Store
$socket.emit('user-join', { auth: { token: sessionUser.token } });
await user.set(sessionUser); await user.set(sessionUser);
await config.set(await getBackendConfig()); await config.set(await getBackendConfig());
} else { } else {