From f93c2e4a8d12104e035ece104d7d0d735a7f210c Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Mon, 30 Dec 2024 23:06:34 -0800 Subject: [PATCH] feat: reactions --- .../3781e22d8b01_update_message_table.py | 70 ++++++++ backend/open_webui/models/channels.py | 8 +- backend/open_webui/models/messages.py | 111 +++++++++++-- backend/open_webui/routers/channels.py | 150 +++++++++++++++++- src/lib/apis/channels/index.ts | 71 +++++++++ src/lib/components/channel/Channel.svelte | 9 +- src/lib/components/channel/Messages.svelte | 29 +++- .../channel/Messages/Message.svelte | 88 +++++----- .../Messages/Message/ReactionPicker.svelte | 12 +- 9 files changed, 479 insertions(+), 69 deletions(-) create mode 100644 backend/open_webui/migrations/versions/3781e22d8b01_update_message_table.py diff --git a/backend/open_webui/migrations/versions/3781e22d8b01_update_message_table.py b/backend/open_webui/migrations/versions/3781e22d8b01_update_message_table.py new file mode 100644 index 000000000..16fb0e85e --- /dev/null +++ b/backend/open_webui/migrations/versions/3781e22d8b01_update_message_table.py @@ -0,0 +1,70 @@ +"""Update message & channel tables + +Revision ID: 3781e22d8b01 +Revises: 7826ab40b532 +Create Date: 2024-12-30 03:00:00.000000 + +""" + +from alembic import op +import sqlalchemy as sa + +revision = "3781e22d8b01" +down_revision = "7826ab40b532" +branch_labels = None +depends_on = None + + +def upgrade(): + # Add 'type' column to the 'channel' table + op.add_column( + "channel", + sa.Column( + "type", + sa.Text(), + nullable=True, + ), + ) + + # Add 'parent_id' column to the 'message' table for threads + op.add_column( + "message", + sa.Column("parent_id", sa.Text(), nullable=True), + ) + + op.create_table( + "message_reaction", + sa.Column( + "id", sa.Text(), nullable=False, primary_key=True, unique=True + ), # Unique reaction ID + sa.Column("user_id", sa.Text(), nullable=False), # User who reacted + sa.Column( + "message_id", sa.Text(), nullable=False + ), # Message that was reacted to + sa.Column( + "name", sa.Text(), nullable=False + ), # Reaction name (e.g. "thumbs_up") + sa.Column( + "created_at", sa.BigInteger(), nullable=True + ), # Timestamp of when the reaction was added + ) + + op.create_table( + "channel_member", + sa.Column( + "id", sa.Text(), nullable=False, primary_key=True, unique=True + ), # Record ID for the membership row + sa.Column("channel_id", sa.Text(), nullable=False), # Associated channel + sa.Column("user_id", sa.Text(), nullable=False), # Associated user + sa.Column( + "created_at", sa.BigInteger(), nullable=True + ), # Timestamp of when the user joined the channel + ) + + +def downgrade(): + # Revert 'type' column addition to the 'channel' table + op.drop_column("channel", "type") + op.drop_column("message", "parent_id") + op.drop_table("message_reaction") + op.drop_table("channel_member") diff --git a/backend/open_webui/models/channels.py b/backend/open_webui/models/channels.py index bc36146cf..92f238c3a 100644 --- a/backend/open_webui/models/channels.py +++ b/backend/open_webui/models/channels.py @@ -21,6 +21,7 @@ class Channel(Base): id = Column(Text, primary_key=True) user_id = Column(Text) + type = Column(Text, nullable=True) name = Column(Text) description = Column(Text, nullable=True) @@ -38,9 +39,11 @@ class ChannelModel(BaseModel): id: str user_id: str - description: Optional[str] = None + type: Optional[str] = None name: str + description: Optional[str] = None + data: Optional[dict] = None meta: Optional[dict] = None access_control: Optional[dict] = None @@ -64,12 +67,13 @@ class ChannelForm(BaseModel): class ChannelTable: def insert_new_channel( - self, form_data: ChannelForm, user_id: str + self, type: Optional[str], form_data: ChannelForm, user_id: str ) -> Optional[ChannelModel]: with get_db() as db: channel = ChannelModel( **{ **form_data.model_dump(), + "type": type, "name": form_data.name.lower(), "id": str(uuid.uuid4()), "user_id": user_id, diff --git a/backend/open_webui/models/messages.py b/backend/open_webui/models/messages.py index 2a4322d0d..68e396bc2 100644 --- a/backend/open_webui/models/messages.py +++ b/backend/open_webui/models/messages.py @@ -17,6 +17,25 @@ from sqlalchemy.sql import exists #################### +class MessageReaction(Base): + __tablename__ = "message_reaction" + id = Column(Text, primary_key=True) + user_id = Column(Text) + message_id = Column(Text) + name = Column(Text) + created_at = Column(BigInteger) + + +class MessageReactionModel(BaseModel): + model_config = ConfigDict(from_attributes=True) + + id: str + user_id: str + message_id: str + name: str + created_at: int # timestamp in epoch + + class Message(Base): __tablename__ = "message" id = Column(Text, primary_key=True) @@ -24,6 +43,8 @@ class Message(Base): user_id = Column(Text) channel_id = Column(Text, nullable=True) + parent_id = Column(Text, nullable=True) + content = Column(Text) data = Column(JSON, nullable=True) meta = Column(JSON, nullable=True) @@ -39,6 +60,8 @@ class MessageModel(BaseModel): user_id: str channel_id: Optional[str] = None + parent_id: Optional[str] = None + content: str data: Optional[dict] = None meta: Optional[dict] = None @@ -54,10 +77,21 @@ class MessageModel(BaseModel): class MessageForm(BaseModel): content: str + parent_id: Optional[str] = None data: Optional[dict] = None meta: Optional[dict] = None +class Reactions(BaseModel): + name: str + user_ids: list[str] + count: int + + +class MessageResponse(MessageModel): + reactions: list[Reactions] + + class MessageTable: def insert_new_message( self, form_data: MessageForm, channel_id: str, user_id: str @@ -71,6 +105,7 @@ class MessageTable: "id": id, "user_id": user_id, "channel_id": channel_id, + "parent_id": form_data.parent_id, "content": form_data.content, "data": form_data.data, "meta": form_data.meta, @@ -85,10 +120,19 @@ class MessageTable: db.refresh(result) return MessageModel.model_validate(result) if result else None - def get_message_by_id(self, id: str) -> Optional[MessageModel]: + def get_message_by_id(self, id: str) -> Optional[MessageResponse]: with get_db() as db: message = db.get(Message, id) - return MessageModel.model_validate(message) if message else None + if not message: + return None + + reactions = self.get_reactions_by_message_id(id) + return MessageResponse( + **{ + **MessageModel.model_validate(message).model_dump(), + "reactions": reactions, + } + ) def get_messages_by_channel_id( self, channel_id: str, skip: int = 0, limit: int = 50 @@ -104,20 +148,6 @@ class MessageTable: ) return [MessageModel.model_validate(message) for message in all_messages] - def get_messages_by_user_id( - self, user_id: str, skip: int = 0, limit: int = 50 - ) -> list[MessageModel]: - with get_db() as db: - all_messages = ( - db.query(Message) - .filter_by(user_id=user_id) - .order_by(Message.created_at.desc()) - .offset(skip) - .limit(limit) - .all() - ) - return [MessageModel.model_validate(message) for message in all_messages] - def update_message_by_id( self, id: str, form_data: MessageForm ) -> Optional[MessageModel]: @@ -131,9 +161,58 @@ class MessageTable: db.refresh(message) return MessageModel.model_validate(message) if message else None + def add_reaction_to_message( + self, id: str, user_id: str, name: str + ) -> Optional[MessageReactionModel]: + with get_db() as db: + reaction_id = str(uuid.uuid4()) + reaction = MessageReactionModel( + id=reaction_id, + user_id=user_id, + message_id=id, + name=name, + created_at=int(time.time_ns()), + ) + result = MessageReaction(**reaction.model_dump()) + db.add(result) + db.commit() + db.refresh(result) + return MessageReactionModel.model_validate(result) if result else None + + def get_reactions_by_message_id(self, id: str) -> list[Reactions]: + with get_db() as db: + all_reactions = db.query(MessageReaction).filter_by(message_id=id).all() + + reactions = {} + for reaction in all_reactions: + if reaction.name not in reactions: + reactions[reaction.name] = { + "name": reaction.name, + "user_ids": [], + "count": 0, + } + reactions[reaction.name]["user_ids"].append(reaction.user_id) + reactions[reaction.name]["count"] += 1 + + return [Reactions(**reaction) for reaction in reactions.values()] + + def remove_reaction_by_id_and_user_id_and_name( + self, id: str, user_id: str, name: str + ) -> bool: + with get_db() as db: + db.query(MessageReaction).filter_by( + message_id=id, user_id=user_id, name=name + ).delete() + db.commit() + return True + def delete_message_by_id(self, id: str) -> bool: with get_db() as db: db.query(Message).filter_by(id=id).delete() + + # Delete all reactions to this message + db.query(MessageReaction).filter_by(message_id=id).delete() + db.commit() return True diff --git a/backend/open_webui/routers/channels.py b/backend/open_webui/routers/channels.py index 292f33e78..c66e637ca 100644 --- a/backend/open_webui/routers/channels.py +++ b/backend/open_webui/routers/channels.py @@ -11,7 +11,12 @@ from open_webui.socket.main import sio, get_user_ids_from_room from open_webui.models.users import Users, UserNameResponse from open_webui.models.channels import Channels, ChannelModel, ChannelForm -from open_webui.models.messages import Messages, MessageModel, MessageForm +from open_webui.models.messages import ( + Messages, + MessageModel, + MessageResponse, + MessageForm, +) from open_webui.config import ENABLE_ADMIN_CHAT_ACCESS, ENABLE_ADMIN_EXPORT @@ -49,7 +54,7 @@ async def get_channels(user=Depends(get_verified_user)): @router.post("/create", response_model=Optional[ChannelModel]) async def create_new_channel(form_data: ChannelForm, user=Depends(get_admin_user)): try: - channel = Channels.insert_new_channel(form_data, user.id) + channel = Channels.insert_new_channel(None, form_data, user.id) return ChannelModel(**channel.model_dump()) except Exception as e: log.exception(e) @@ -134,11 +139,11 @@ async def delete_channel_by_id(id: str, user=Depends(get_admin_user)): ############################ -class MessageUserModel(MessageModel): +class MessageUserResponse(MessageResponse): user: UserNameResponse -@router.get("/{id}/messages", response_model=list[MessageUserModel]) +@router.get("/{id}/messages", response_model=list[MessageUserResponse]) async def get_channel_messages( id: str, skip: int = 0, limit: int = 50, user=Depends(get_verified_user) ): @@ -165,9 +170,10 @@ async def get_channel_messages( users[message.user_id] = user messages.append( - MessageUserModel( + MessageUserResponse( **{ **message.model_dump(), + "reactions": Messages.get_reactions_by_message_id(message.id), "user": UserNameResponse(**users[message.user_id].model_dump()), } ) @@ -333,6 +339,140 @@ async def update_message_by_id( ) +############################ +# AddReactionToMessage +############################ + + +class ReactionForm(BaseModel): + name: str + + +@router.post("/{id}/messages/{message_id}/reactions/add", response_model=bool) +async def add_reaction_to_message( + id: str, message_id: str, form_data: ReactionForm, user=Depends(get_verified_user) +): + channel = Channels.get_channel_by_id(id) + if not channel: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND + ) + + if user.role != "admin" and not has_access( + user.id, type="read", access_control=channel.access_control + ): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() + ) + + message = Messages.get_message_by_id(message_id) + if not message: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND + ) + + if message.channel_id != id: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT() + ) + + try: + Messages.add_reaction_to_message(message_id, user.id, form_data.name) + + message = Messages.get_message_by_id(message_id) + await sio.emit( + "channel-events", + { + "channel_id": channel.id, + "message_id": message.id, + "data": { + "type": "message:reaction", + "data": { + **message.model_dump(), + "user": UserNameResponse(**user.model_dump()).model_dump(), + "name": form_data.name, + }, + }, + "user": UserNameResponse(**user.model_dump()).model_dump(), + "channel": channel.model_dump(), + }, + to=f"channel:{channel.id}", + ) + + return True + except Exception as e: + log.exception(e) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT() + ) + + +############################ +# RemoveReactionById +############################ + + +@router.post("/{id}/messages/{message_id}/reactions/remove", response_model=bool) +async def remove_reaction_by_id_and_user_id_and_name( + id: str, message_id: str, form_data: ReactionForm, user=Depends(get_verified_user) +): + channel = Channels.get_channel_by_id(id) + if not channel: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND + ) + + if user.role != "admin" and not has_access( + user.id, type="read", access_control=channel.access_control + ): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() + ) + + message = Messages.get_message_by_id(message_id) + if not message: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND + ) + + if message.channel_id != id: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT() + ) + + try: + Messages.remove_reaction_by_id_and_user_id_and_name( + message_id, user.id, form_data.name + ) + + message = Messages.get_message_by_id(message_id) + await sio.emit( + "channel-events", + { + "channel_id": channel.id, + "message_id": message.id, + "data": { + "type": "message:reaction", + "data": { + **message.model_dump(), + "user": UserNameResponse(**user.model_dump()).model_dump(), + "name": form_data.name, + }, + }, + "user": UserNameResponse(**user.model_dump()).model_dump(), + "channel": channel.model_dump(), + }, + to=f"channel:{channel.id}", + ) + + return True + except Exception as e: + log.exception(e) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT() + ) + + ############################ # DeleteMessageById ############################ diff --git a/src/lib/apis/channels/index.ts b/src/lib/apis/channels/index.ts index 607a8f900..1b9cd3ebf 100644 --- a/src/lib/apis/channels/index.ts +++ b/src/lib/apis/channels/index.ts @@ -285,6 +285,77 @@ export const updateMessage = async ( return res; }; +export const addReaction = async (token: string = '', channel_id: string, message_id: string, name: string) => { + let error = null; + + const res = await fetch( + `${WEBUI_API_BASE_URL}/channels/${channel_id}/messages/${message_id}/reactions/add`, + { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + }, + body: JSON.stringify({ name }) + } + ) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err.detail; + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +} + + +export const removeReaction = async (token: string = '', channel_id: string, message_id: string, name: string) => { + let error = null; + + const res = await fetch( + `${WEBUI_API_BASE_URL}/channels/${channel_id}/messages/${message_id}/reactions/remove`, + { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + }, + body: JSON.stringify({ name }) + } + ) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err.detail; + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +} + export const deleteMessage = async (token: string = '', channel_id: string, message_id: string) => { let error = null; diff --git a/src/lib/components/channel/Channel.svelte b/src/lib/components/channel/Channel.svelte index a27c88540..36f57e159 100644 --- a/src/lib/components/channel/Channel.svelte +++ b/src/lib/components/channel/Channel.svelte @@ -1,13 +1,13 @@