diff --git a/backend/open_webui/models/messages.py b/backend/open_webui/models/messages.py index 03a24abf3..4f095a8b5 100644 --- a/backend/open_webui/models/messages.py +++ b/backend/open_webui/models/messages.py @@ -89,6 +89,8 @@ class Reactions(BaseModel): class MessageResponse(MessageModel): + latest_reply_at: Optional[int] + reply_count: int reactions: list[Reactions] @@ -127,13 +129,34 @@ class MessageTable: return None reactions = self.get_reactions_by_message_id(id) + replies = self.get_replies_by_message_id(id) + return MessageResponse( **{ **MessageModel.model_validate(message).model_dump(), + "latest_reply_at": replies[0].created_at if replies else None, + "reply_count": len(replies), "reactions": reactions, } ) + def get_replies_by_message_id(self, id: str) -> list[MessageModel]: + with get_db() as db: + all_messages = ( + db.query(Message) + .filter_by(parent_id=id) + .order_by(Message.created_at.desc()) + .all() + ) + return [MessageModel.model_validate(message) for message in all_messages] + + def get_reply_user_ids_by_message_id(self, id: str) -> list[str]: + with get_db() as db: + return [ + message.user_id + for message in db.query(Message).filter_by(parent_id=id).all() + ] + def get_messages_by_channel_id( self, channel_id: str, skip: int = 0, limit: int = 50 ) -> list[MessageModel]: @@ -166,9 +189,9 @@ class MessageTable: .all() ) - return [MessageModel.model_validate(message)] + [ + return [ MessageModel.model_validate(message) for message in all_messages - ] + ] + [MessageModel.model_validate(message)] def update_message_by_id( self, id: str, form_data: MessageForm diff --git a/backend/open_webui/routers/channels.py b/backend/open_webui/routers/channels.py index 9e12911cc..d2bd99e6f 100644 --- a/backend/open_webui/routers/channels.py +++ b/backend/open_webui/routers/channels.py @@ -169,10 +169,15 @@ async def get_channel_messages( user = Users.get_user_by_id(message.user_id) users[message.user_id] = user + replies = Messages.get_replies_by_message_id(message.id) + latest_reply_at = replies[0].created_at if replies else None + messages.append( MessageUserResponse( **{ **message.model_dump(), + "reply_count": len(replies), + "latest_reply_at": latest_reply_at, "reactions": Messages.get_reactions_by_message_id(message.id), "user": UserNameResponse(**users[message.user_id].model_dump()), } @@ -242,10 +247,17 @@ async def post_new_message( "message_id": message.id, "data": { "type": "message", - "data": { - **message.model_dump(), - "user": UserNameResponse(**user.model_dump()).model_dump(), - }, + "data": MessageUserResponse( + **{ + **message.model_dump(), + "reply_count": 0, + "latest_reply_at": None, + "reactions": Messages.get_reactions_by_message_id( + message.id + ), + "user": UserNameResponse(**user.model_dump()), + } + ).model_dump(), }, "user": UserNameResponse(**user.model_dump()).model_dump(), "channel": channel.model_dump(), @@ -257,6 +269,35 @@ async def post_new_message( to=f"channel:{channel.id}", ) + if message.parent_id: + # If this message is a reply, emit to the parent message as well + parent_message = Messages.get_message_by_id(message.parent_id) + + if parent_message: + await sio.emit( + "channel-events", + { + "channel_id": channel.id, + "message_id": parent_message.id, + "data": { + "type": "message:reply", + "data": MessageUserResponse( + **{ + **parent_message.model_dump(), + "user": UserNameResponse( + **Users.get_user_by_id( + parent_message.user_id + ).model_dump() + ), + } + ).model_dump(), + }, + "user": UserNameResponse(**user.model_dump()).model_dump(), + "channel": channel.model_dump(), + }, + to=f"channel:{channel.id}", + ) + active_user_ids = get_user_ids_from_room(f"channel:{channel.id}") background_tasks.add_task( @@ -275,6 +316,49 @@ async def post_new_message( ) +############################ +# GetChannelMessage +############################ + + +@router.get("/{id}/messages/{message_id}", response_model=Optional[MessageUserResponse]) +async def get_channel_message( + id: str, message_id: str, user=Depends(get_verified_user) +): + channel = Channels.get_channel_by_id(id) + if not channel: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND + ) + + if user.role != "admin" and not has_access( + user.id, type="read", access_control=channel.access_control + ): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() + ) + + message = Messages.get_message_by_id(message_id) + if not message: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND + ) + + if message.channel_id != id: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT() + ) + + return MessageUserResponse( + **{ + **message.model_dump(), + "user": UserNameResponse( + **Users.get_user_by_id(message.user_id).model_dump() + ), + } + ) + + ############################ # GetChannelThreadMessages ############################ @@ -316,6 +400,8 @@ async def get_channel_thread_messages( MessageUserResponse( **{ **message.model_dump(), + "reply_count": 0, + "latest_reply_at": None, "reactions": Messages.get_reactions_by_message_id(message.id), "user": UserNameResponse(**users[message.user_id].model_dump()), } @@ -372,10 +458,14 @@ async def update_message_by_id( "message_id": message.id, "data": { "type": "message:update", - "data": { - **message.model_dump(), - "user": UserNameResponse(**user.model_dump()).model_dump(), - }, + "data": MessageUserResponse( + **{ + **message.model_dump(), + "user": UserNameResponse( + **user.model_dump() + ).model_dump(), + } + ).model_dump(), }, "user": UserNameResponse(**user.model_dump()).model_dump(), "channel": channel.model_dump(), @@ -430,18 +520,17 @@ async def add_reaction_to_message( try: Messages.add_reaction_to_message(message_id, user.id, form_data.name) - message = Messages.get_message_by_id(message_id) + await sio.emit( "channel-events", { "channel_id": channel.id, "message_id": message.id, "data": { - "type": "message:reaction", + "type": "message:reaction:add", "data": { **message.model_dump(), - "user": UserNameResponse(**user.model_dump()).model_dump(), "name": form_data.name, }, }, @@ -505,10 +594,9 @@ async def remove_reaction_by_id_and_user_id_and_name( "channel_id": channel.id, "message_id": message.id, "data": { - "type": "message:reaction", + "type": "message:reaction:remove", "data": { **message.model_dump(), - "user": UserNameResponse(**user.model_dump()).model_dump(), "name": form_data.name, }, }, diff --git a/src/lib/apis/channels/index.ts b/src/lib/apis/channels/index.ts index 99ea44614..965d24331 100644 --- a/src/lib/apis/channels/index.ts +++ b/src/lib/apis/channels/index.ts @@ -250,6 +250,7 @@ export const getChannelThreadMessages = async ( } type MessageForm = { + parent_id?: string; content: string; data?: object; meta?: object; diff --git a/src/lib/components/channel/Channel.svelte b/src/lib/components/channel/Channel.svelte index 93cc6e590..776f5605f 100644 --- a/src/lib/components/channel/Channel.svelte +++ b/src/lib/components/channel/Channel.svelte @@ -74,15 +74,17 @@ const data = event?.data?.data ?? null; if (type === 'message') { - messages = [data, ...messages]; + if ((data?.parent_id ?? null) === null) { + messages = [data, ...messages]; - if (typingUsers.find((user) => user.id === event.user.id)) { - typingUsers = typingUsers.filter((user) => user.id !== event.user.id); - } + if (typingUsers.find((user) => user.id === event.user.id)) { + typingUsers = typingUsers.filter((user) => user.id !== event.user.id); + } - await tick(); - if (scrollEnd) { - messagesContainerElement.scrollTop = messagesContainerElement.scrollHeight; + await tick(); + if (scrollEnd) { + messagesContainerElement.scrollTop = messagesContainerElement.scrollHeight; + } } } else if (type === 'message:update') { const idx = messages.findIndex((message) => message.id === data.id); @@ -92,7 +94,7 @@ } } else if (type === 'message:delete') { messages = messages.filter((message) => message.id !== data.id); - } else if (type === 'message:reaction') { + } else if (type.includes('message:reaction')) { const idx = messages.findIndex((message) => message.id === data.id); if (idx !== -1) { messages[idx] = data; diff --git a/src/lib/components/channel/Messages/Message.svelte b/src/lib/components/channel/Messages/Message.svelte index d9f8389ec..19281aeea 100644 --- a/src/lib/components/channel/Messages/Message.svelte +++ b/src/lib/components/channel/Messages/Message.svelte @@ -29,6 +29,7 @@ import ChatBubbleOvalEllipsis from '$lib/components/icons/ChatBubbleOvalEllipsis.svelte'; import FaceSmile from '$lib/components/icons/FaceSmile.svelte'; import ReactionPicker from './Message/ReactionPicker.svelte'; + import ChevronRight from '$lib/components/icons/ChevronRight.svelte'; export let message; export let showUserProfile = true; @@ -324,6 +325,29 @@ {/if} + + {#if message.reply_count > 0} +
+ +
+ {/if} {/if} diff --git a/src/lib/components/channel/Messages/Message/ReactionPicker.svelte b/src/lib/components/channel/Messages/Message/ReactionPicker.svelte index 59211b73f..a14971b4e 100644 --- a/src/lib/components/channel/Messages/Message/ReactionPicker.svelte +++ b/src/lib/components/channel/Messages/Message/ReactionPicker.svelte @@ -75,7 +75,7 @@
-
- {#if Object.keys(emojis).length === 0} -
No results
- {:else} - {#each Object.keys(emojiGroups) as group} - {@const groupEmojis = emojiGroups[group].filter((emoji) => emojis[emoji])} - {#if groupEmojis.length > 0} -
-
- {group} -
+ {#if show} +
+ {#if Object.keys(emojis).length === 0} +
No results
+ {:else} + {#each Object.keys(emojiGroups) as group} + {@const groupEmojis = emojiGroups[group].filter((emoji) => emojis[emoji])} + {#if groupEmojis.length > 0} +
+
+ {group} +
-
- {#each groupEmojis as emoji (emoji)} - `:${code}:`) - .join(', ')} - placement="top" - > - - - {/each} + + + {/each} +
-
- {/if} - {/each} - {/if} -
+ {/if} + {/each} + {/if} +
+ {/if}
diff --git a/src/lib/components/channel/Thread.svelte b/src/lib/components/channel/Thread.svelte index 69a658f33..9ea2a37d7 100644 --- a/src/lib/components/channel/Thread.svelte +++ b/src/lib/components/channel/Thread.svelte @@ -3,12 +3,13 @@ import { socket } from '$lib/stores'; - import { getChannelThreadMessages } from '$lib/apis/channels'; + import { getChannelThreadMessages, sendMessage } from '$lib/apis/channels'; import XMark from '$lib/components/icons/XMark.svelte'; import MessageInput from './MessageInput.svelte'; import Messages from './Messages.svelte'; import { onMount } from 'svelte'; + import { toast } from 'svelte-sonner'; export let threadId = null; export let channel = null; @@ -43,10 +44,19 @@ } }; - const submitHandler = async (message) => { - // if (message) { - // await sendMessage(localStorage.token, channel.id, message, threadId); - // } + const submitHandler = async ({ content, data }) => { + if (!content) { + return; + } + + const res = await sendMessage(localStorage.token, channel.id, { + parent_id: threadId, + content: content, + data: data + }).catch((error) => { + toast.error(error); + return null; + }); }; const onChange = async () => {