feat: dm channels

This commit is contained in:
Timothy Jaeryang Baek
2025-11-27 07:27:32 -05:00
parent f2c56fc839
commit acccb9afdd
13 changed files with 989 additions and 216 deletions

View File

@@ -113,22 +113,24 @@ class ChannelResponse(ChannelModel):
class ChannelForm(BaseModel):
type: Optional[str] = None
name: str
description: Optional[str] = None
data: Optional[dict] = None
meta: Optional[dict] = None
access_control: Optional[dict] = None
user_ids: Optional[list[str]] = None
class ChannelTable:
def insert_new_channel(
self, type: Optional[str], form_data: ChannelForm, user_id: str
self, form_data: ChannelForm, user_id: str
) -> Optional[ChannelModel]:
with get_db() as db:
channel = ChannelModel(
**{
**form_data.model_dump(),
"type": type,
"type": form_data.type if form_data.type else None,
"name": form_data.name.lower(),
"id": str(uuid.uuid4()),
"user_id": user_id,
@@ -136,9 +138,34 @@ class ChannelTable:
"updated_at": int(time.time_ns()),
}
)
new_channel = Channel(**channel.model_dump())
if form_data.type == "dm":
# For direct message channels, automatically add the specified users as members
user_ids = form_data.user_ids or []
if user_id not in user_ids:
user_ids.append(user_id) # Ensure the creator is also a member
for uid in user_ids:
channel_member = ChannelMemberModel(
**{
"id": str(uuid.uuid4()),
"channel_id": channel.id,
"user_id": uid,
"status": "joined",
"is_active": True,
"is_channel_muted": False,
"is_channel_pinned": False,
"joined_at": int(time.time_ns()),
"left_at": None,
"last_read_at": int(time.time_ns()),
"created_at": int(time.time_ns()),
"updated_at": int(time.time_ns()),
}
)
new_membership = ChannelMember(**channel_member.model_dump())
db.add(new_membership)
db.add(new_channel)
db.commit()
return channel
@@ -152,12 +179,41 @@ class ChannelTable:
self, user_id: str, permission: str = "read"
) -> list[ChannelModel]:
channels = self.get_channels()
return [
channel
for channel in channels
if channel.user_id == user_id
or has_access(user_id, permission, channel.access_control)
]
channel_list = []
for channel in channels:
if channel.type == "dm":
membership = self.get_member_by_channel_and_user_id(channel.id, user_id)
if membership and membership.is_active:
channel_list.append(channel)
else:
if channel.user_id == user_id or has_access(
user_id, permission, channel.access_control
):
channel_list.append(channel)
return channel_list
def get_dm_channel_by_user_ids(self, user_ids: list[str]) -> Optional[ChannelModel]:
with get_db() as db:
subquery = (
db.query(ChannelMember.channel_id)
.filter(ChannelMember.user_id.in_(user_ids))
.group_by(ChannelMember.channel_id)
.having(func.count(ChannelMember.user_id) == len(user_ids))
.subquery()
)
channel = (
db.query(Channel)
.filter(
Channel.id.in_(subquery),
Channel.type == "dm",
)
.first()
)
return ChannelModel.model_validate(channel) if channel else None
def join_channel(
self, channel_id: str, user_id: str
@@ -233,6 +289,18 @@ class ChannelTable:
)
return ChannelMemberModel.model_validate(membership) if membership else None
def get_members_by_channel_id(self, channel_id: str) -> list[ChannelMemberModel]:
with get_db() as db:
memberships = (
db.query(ChannelMember)
.filter(ChannelMember.channel_id == channel_id)
.all()
)
return [
ChannelMemberModel.model_validate(membership)
for membership in memberships
]
def pin_channel(self, channel_id: str, user_id: str, is_pinned: bool) -> bool:
with get_db() as db:
membership = (
@@ -271,6 +339,27 @@ class ChannelTable:
db.commit()
return True
def update_member_active_status(
self, channel_id: str, user_id: str, is_active: bool
) -> bool:
with get_db() as db:
membership = (
db.query(ChannelMember)
.filter(
ChannelMember.channel_id == channel_id,
ChannelMember.user_id == user_id,
)
.first()
)
if not membership:
return False
membership.is_active = is_active
membership.updated_at = int(time.time_ns())
db.commit()
return True
def is_user_channel_member(self, channel_id: str, user_id: str) -> bool:
with get_db() as db:
membership = (
@@ -278,7 +367,6 @@ class ChannelTable:
.filter(
ChannelMember.channel_id == channel_id,
ChannelMember.user_id == user_id,
ChannelMember.is_active == True,
)
.first()
)

View File

@@ -13,6 +13,7 @@ from open_webui.socket.main import (
get_active_status_by_user_id,
)
from open_webui.models.users import (
UserIdNameResponse,
UserListResponse,
UserModelResponse,
Users,
@@ -66,6 +67,9 @@ router = APIRouter()
class ChannelListItemResponse(ChannelModel):
user_ids: Optional[list[str]] = None # 'dm' channels only
users: Optional[list[UserIdNameResponse]] = None # 'dm' channels only
last_message_at: Optional[int] = None # timestamp in epoch (time_ns)
unread_count: int = 0
@@ -85,9 +89,23 @@ async def get_channels(user=Depends(get_verified_user)):
channel.id, user.id, channel_member.last_read_at if channel_member else None
)
user_ids = None
users = None
if channel.type == "dm":
user_ids = [
member.user_id
for member in Channels.get_members_by_channel_id(channel.id)
]
users = [
UserIdNameResponse(**user.model_dump())
for user in Users.get_users_by_user_ids(user_ids)
]
channel_list.append(
ChannelListItemResponse(
**channel.model_dump(),
user_ids=user_ids,
users=users,
last_message_at=last_message_at,
unread_count=unread_count,
)
@@ -111,7 +129,15 @@ async def get_all_channels(user=Depends(get_verified_user)):
@router.post("/create", response_model=Optional[ChannelModel])
async def create_new_channel(form_data: ChannelForm, user=Depends(get_admin_user)):
try:
channel = Channels.insert_new_channel(None, form_data, user.id)
if form_data.type == "dm" and len(form_data.user_ids) == 1:
existing_channel = Channels.get_dm_channel_by_user_ids(
[user.id, form_data.user_ids[0]]
)
if existing_channel:
Channels.update_member_active_status(existing_channel.id, user.id, True)
return ChannelModel(**existing_channel.model_dump())
channel = Channels.insert_new_channel(form_data, user.id)
return ChannelModel(**channel.model_dump())
except Exception as e:
log.exception(e)
@@ -125,7 +151,15 @@ async def create_new_channel(form_data: ChannelForm, user=Depends(get_admin_user
############################
@router.get("/{id}", response_model=Optional[ChannelResponse])
class ChannelFullResponse(ChannelResponse):
user_ids: Optional[list[str]] = None # 'dm' channels only
users: Optional[list[UserIdNameResponse]] = None # 'dm' channels only
last_read_at: Optional[int] = None # timestamp in epoch (time_ns)
unread_count: int = 0
@router.get("/{id}", response_model=Optional[ChannelFullResponse])
async def get_channel_by_id(id: str, user=Depends(get_verified_user)):
channel = Channels.get_channel_by_id(id)
if not channel:
@@ -133,33 +167,82 @@ async def get_channel_by_id(id: str, user=Depends(get_verified_user)):
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
)
if user.role != "admin" and not has_access(
user.id, type="read", access_control=channel.access_control
):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
user_ids = None
users = None
if channel.type == "dm":
if not Channels.is_user_channel_member(channel.id, user.id):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
)
user_ids = [
member.user_id for member in Channels.get_members_by_channel_id(channel.id)
]
users = [
UserIdNameResponse(**user.model_dump())
for user in Users.get_users_by_user_ids(user_ids)
]
channel_member = Channels.get_member_by_channel_and_user_id(channel.id, user.id)
unread_count = Messages.get_unread_message_count(
channel.id, user.id, channel_member.last_read_at if channel_member else None
)
write_access = has_access(
user.id, type="write", access_control=channel.access_control, strict=False
)
return ChannelFullResponse(
**{
**channel.model_dump(),
"user_ids": user_ids,
"users": users,
"write_access": True,
"user_count": len(user_ids),
"last_read_at": channel_member.last_read_at if channel_member else None,
"unread_count": unread_count,
}
)
user_count = len(get_users_with_access("read", channel.access_control))
else:
if user.role != "admin" and not has_access(
user.id, type="read", access_control=channel.access_control
):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
)
return ChannelResponse(
**{
**channel.model_dump(),
"write_access": write_access or user.role == "admin",
"user_count": user_count,
}
)
write_access = has_access(
user.id, type="write", access_control=channel.access_control, strict=False
)
user_count = len(get_users_with_access("read", channel.access_control))
channel_member = Channels.get_member_by_channel_and_user_id(channel.id, user.id)
unread_count = Messages.get_unread_message_count(
channel.id, user.id, channel_member.last_read_at if channel_member else None
)
return ChannelFullResponse(
**{
**channel.model_dump(),
"user_ids": user_ids,
"users": users,
"write_access": write_access or user.role == "admin",
"user_count": user_count,
"last_read_at": channel_member.last_read_at if channel_member else None,
"unread_count": unread_count,
}
)
############################
# GetChannelMembersById
############################
PAGE_ITEM_COUNT = 30
@router.get("/{id}/users", response_model=UserListResponse)
async def get_channel_users_by_id(
@router.get("/{id}/members", response_model=UserListResponse)
async def get_channel_members_by_id(
id: str,
query: Optional[str] = None,
order_by: Optional[str] = None,
@@ -179,36 +262,90 @@ async def get_channel_users_by_id(
page = max(1, page)
skip = (page - 1) * limit
filter = {
"roles": ["!pending"],
}
if query:
filter["query"] = query
if order_by:
filter["order_by"] = order_by
if direction:
filter["direction"] = direction
permitted_ids = get_permitted_group_and_user_ids("read", channel.access_control)
if permitted_ids:
filter["user_ids"] = permitted_ids.get("user_ids")
filter["group_ids"] = permitted_ids.get("group_ids")
result = Users.get_users(filter=filter, skip=skip, limit=limit)
users = result["users"]
total = result["total"]
return {
"users": [
UserModelResponse(
**user.model_dump(), is_active=get_active_status_by_user_id(user.id)
if channel.type == "dm":
if not Channels.is_user_channel_member(channel.id, user.id):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
)
for user in users
],
"total": total,
}
user_ids = [
member.user_id for member in Channels.get_members_by_channel_id(channel.id)
]
users = Users.get_users_by_user_ids(user_ids)
total = len(users)
return {
"users": [
UserModelResponse(
**user.model_dump(), is_active=get_active_status_by_user_id(user.id)
)
for user in users
],
"total": total,
}
else:
filter = {
"roles": ["!pending"],
}
if query:
filter["query"] = query
if order_by:
filter["order_by"] = order_by
if direction:
filter["direction"] = direction
permitted_ids = get_permitted_group_and_user_ids("read", channel.access_control)
if permitted_ids:
filter["user_ids"] = permitted_ids.get("user_ids")
filter["group_ids"] = permitted_ids.get("group_ids")
result = Users.get_users(filter=filter, skip=skip, limit=limit)
users = result["users"]
total = result["total"]
return {
"users": [
UserModelResponse(
**user.model_dump(), is_active=get_active_status_by_user_id(user.id)
)
for user in users
],
"total": total,
}
#################################################
# UpdateIsActiveMemberByIdAndUserId
#################################################
class UpdateActiveMemberForm(BaseModel):
is_active: bool
@router.post("/{id}/members/active", response_model=bool)
async def update_is_active_member_by_id_and_user_id(
id: str,
form_data: UpdateActiveMemberForm,
user=Depends(get_verified_user),
):
channel = Channels.get_channel_by_id(id)
if not channel:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
)
if not Channels.is_user_channel_member(channel.id, user.id):
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
)
Channels.update_member_active_status(channel.id, user.id, form_data.is_active)
return True
############################
@@ -278,16 +415,22 @@ async def get_channel_messages(
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
)
if user.role != "admin" and not has_access(
user.id, type="read", access_control=channel.access_control
):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
)
if channel.type == "dm":
if not Channels.is_user_channel_member(channel.id, user.id):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
)
else:
if user.role != "admin" and not has_access(
user.id, type="read", access_control=channel.access_control
):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
)
channel_member = Channels.join_channel(
id, user.id
) # Ensure user is a member of the channel
channel_member = Channels.join_channel(
id, user.id
) # Ensure user is a member of the channel
message_list = Messages.get_messages_by_channel_id(id, skip, limit)
users = {}
@@ -533,16 +676,30 @@ async def new_message_handler(
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
)
if user.role != "admin" and not has_access(
user.id, type="write", access_control=channel.access_control, strict=False
):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
)
if channel.type == "dm":
if not Channels.is_user_channel_member(channel.id, user.id):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
)
else:
if user.role != "admin" and not has_access(
user.id, type="write", access_control=channel.access_control, strict=False
):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
)
try:
message = Messages.insert_new_message(form_data, channel.id, user.id)
if message:
if channel.type == "dm":
members = Channels.get_members_by_channel_id(channel.id)
for member in members:
if not member.is_active:
Channels.update_member_active_status(
channel.id, member.user_id, True
)
message = Messages.get_message_by_id(message.id)
event_data = {
"channel_id": channel.id,
@@ -641,12 +798,18 @@ async def get_channel_message(
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
)
if user.role != "admin" and not has_access(
user.id, type="read", access_control=channel.access_control
):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
)
if channel.type == "dm":
if not Channels.is_user_channel_member(channel.id, user.id):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
)
else:
if user.role != "admin" and not has_access(
user.id, type="read", access_control=channel.access_control
):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
)
message = Messages.get_message_by_id(message_id)
if not message:
@@ -690,12 +853,18 @@ async def get_channel_thread_messages(
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
)
if user.role != "admin" and not has_access(
user.id, type="read", access_control=channel.access_control
):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
)
if channel.type == "dm":
if not Channels.is_user_channel_member(channel.id, user.id):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
)
else:
if user.role != "admin" and not has_access(
user.id, type="read", access_control=channel.access_control
):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
)
message_list = Messages.get_messages_by_parent_id(id, message_id, skip, limit)
users = {}
@@ -749,14 +918,22 @@ async def update_message_by_id(
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
)
if (
user.role != "admin"
and message.user_id != user.id
and not has_access(user.id, type="read", access_control=channel.access_control)
):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
)
if channel.type == "dm":
if not Channels.is_user_channel_member(channel.id, user.id):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
)
else:
if (
user.role != "admin"
and message.user_id != user.id
and not has_access(
user.id, type="read", access_control=channel.access_control
)
):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
)
try:
message = Messages.update_message_by_id(message_id, form_data)
@@ -805,12 +982,18 @@ async def add_reaction_to_message(
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
)
if user.role != "admin" and not has_access(
user.id, type="write", access_control=channel.access_control, strict=False
):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
)
if channel.type == "dm":
if not Channels.is_user_channel_member(channel.id, user.id):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
)
else:
if user.role != "admin" and not has_access(
user.id, type="write", access_control=channel.access_control, strict=False
):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
)
message = Messages.get_message_by_id(message_id)
if not message:
@@ -868,12 +1051,18 @@ async def remove_reaction_by_id_and_user_id_and_name(
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
)
if user.role != "admin" and not has_access(
user.id, type="write", access_control=channel.access_control, strict=False
):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
)
if channel.type == "dm":
if not Channels.is_user_channel_member(channel.id, user.id):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
)
else:
if user.role != "admin" and not has_access(
user.id, type="write", access_control=channel.access_control, strict=False
):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
)
message = Messages.get_message_by_id(message_id)
if not message:
@@ -945,16 +1134,25 @@ async def delete_message_by_id(
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
)
if (
user.role != "admin"
and message.user_id != user.id
and not has_access(
user.id, type="write", access_control=channel.access_control, strict=False
)
):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
)
if channel.type == "dm":
if not Channels.is_user_channel_member(channel.id, user.id):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
)
else:
if (
user.role != "admin"
and message.user_id != user.id
and not has_access(
user.id,
type="write",
access_control=channel.access_control,
strict=False,
)
):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
)
try:
Messages.delete_message_by_id(message_id)