diff --git a/backend/open_webui/models/users.py b/backend/open_webui/models/users.py index bceb72572..931711b9e 100644 --- a/backend/open_webui/models/users.py +++ b/backend/open_webui/models/users.py @@ -70,6 +70,13 @@ class UserResponse(BaseModel): profile_image_url: str +class UserNameResponse(BaseModel): + id: str + name: str + role: str + profile_image_url: str + + class UserRoleUpdateForm(BaseModel): id: str role: str diff --git a/backend/open_webui/routers/channels.py b/backend/open_webui/routers/channels.py index 095cd0fcc..295a25a42 100644 --- a/backend/open_webui/routers/channels.py +++ b/backend/open_webui/routers/channels.py @@ -8,6 +8,8 @@ from pydantic import BaseModel from open_webui.socket.main import sio +from open_webui.models.users import Users, UserNameResponse + from open_webui.models.channels import Channels, ChannelModel, ChannelForm from open_webui.models.messages import Messages, MessageModel, MessageForm @@ -60,7 +62,11 @@ async def create_new_channel(form_data: ChannelForm, user=Depends(get_admin_user ############################ -@router.get("/{id}/messages", response_model=list[MessageModel]) +class MessageUserModel(MessageModel): + user: UserNameResponse + + +@router.get("/{id}/messages", response_model=list[MessageUserModel]) async def get_channel_messages(id: str, page: int = 1, user=Depends(get_verified_user)): channel = Channels.get_channel_by_id(id) if not channel: @@ -76,7 +82,25 @@ async def get_channel_messages(id: str, page: int = 1, user=Depends(get_verified limit = 50 skip = (page - 1) * limit - return Messages.get_messages_by_channel_id(id, skip, limit) + message_list = Messages.get_messages_by_channel_id(id, skip, limit) + users = {} + + messages = [] + for message in message_list: + if message.user_id not in users: + user = Users.get_user_by_id(message.user_id) + users[message.user_id] = user + + messages.append( + MessageUserModel( + **{ + **message.model_dump(), + "user": UserNameResponse(**users[message.user_id].model_dump()), + } + ) + ) + + return messages ############################ @@ -108,7 +132,13 @@ async def post_new_message( { "channel_id": channel.id, "message_id": message.id, - "data": {"type": "message", "data": message.model_dump()}, + "data": { + "type": "message", + "data": { + **message.model_dump(), + "user": UserNameResponse(**user.model_dump()).model_dump(), + }, + }, }, to=f"channel:{channel.id}", )