feat/refac: group members db table (#19239)

* refac: group members table db migration

* refac: group members backend

* refac: group members frontend

* refac: group members frontend integration

* refac: styling
This commit is contained in:
Tim Baek
2025-11-18 03:59:56 -05:00
committed by GitHub
13 changed files with 600 additions and 207 deletions

View File

@@ -0,0 +1,146 @@
"""add_group_member_table
Revision ID: 37f288994c47
Revises: a5c220713937
Create Date: 2025-11-17 03:45:25.123939
"""
import uuid
import time
import json
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "37f288994c47"
down_revision: Union[str, None] = "a5c220713937"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# 1. Create new table
op.create_table(
"group_member",
sa.Column("id", sa.Text(), primary_key=True, unique=True, nullable=False),
sa.Column(
"group_id",
sa.Text(),
sa.ForeignKey("group.id", ondelete="CASCADE"),
nullable=False,
),
sa.Column(
"user_id",
sa.Text(),
sa.ForeignKey("user.id", ondelete="CASCADE"),
nullable=False,
),
sa.Column("created_at", sa.BigInteger(), nullable=True),
sa.Column("updated_at", sa.BigInteger(), nullable=True),
sa.UniqueConstraint("group_id", "user_id", name="uq_group_member_group_user"),
)
connection = op.get_bind()
# 2. Read existing group with user_ids JSON column
group_table = sa.Table(
"group",
sa.MetaData(),
sa.Column("id", sa.Text()),
sa.Column("user_ids", sa.JSON()), # JSON stored as text in SQLite + PG
)
results = connection.execute(
sa.select(group_table.c.id, group_table.c.user_ids)
).fetchall()
print(results)
# 3. Insert members into group_member table
gm_table = sa.Table(
"group_member",
sa.MetaData(),
sa.Column("id", sa.Text()),
sa.Column("group_id", sa.Text()),
sa.Column("user_id", sa.Text()),
sa.Column("created_at", sa.BigInteger()),
sa.Column("updated_at", sa.BigInteger()),
)
now = int(time.time())
for group_id, user_ids in results:
if not user_ids:
continue
if isinstance(user_ids, str):
try:
user_ids = json.loads(user_ids)
except Exception:
continue # skip invalid JSON
if not isinstance(user_ids, list):
continue
rows = [
{
"id": str(uuid.uuid4()),
"group_id": group_id,
"user_id": uid,
"created_at": now,
"updated_at": now,
}
for uid in user_ids
]
if rows:
connection.execute(gm_table.insert(), rows)
# 4. Optionally drop the old column
with op.batch_alter_table("group") as batch:
batch.drop_column("user_ids")
def downgrade():
# Reverse: restore user_ids column
with op.batch_alter_table("group") as batch:
batch.add_column(sa.Column("user_ids", sa.JSON()))
connection = op.get_bind()
gm_table = sa.Table(
"group_member",
sa.MetaData(),
sa.Column("group_id", sa.Text()),
sa.Column("user_id", sa.Text()),
sa.Column("created_at", sa.BigInteger()),
sa.Column("updated_at", sa.BigInteger()),
)
group_table = sa.Table(
"group",
sa.MetaData(),
sa.Column("id", sa.Text()),
sa.Column("user_ids", sa.JSON()),
)
# Build JSON arrays again
results = connection.execute(sa.select(group_table.c.id)).fetchall()
for (group_id,) in results:
members = connection.execute(
sa.select(gm_table.c.user_id).where(gm_table.c.group_id == group_id)
).fetchall()
member_ids = [m[0] for m in members]
connection.execute(
group_table.update()
.where(group_table.c.id == group_id)
.values(user_ids=member_ids)
)
# Drop the new table
op.drop_table("group_member")

View File

@@ -11,7 +11,7 @@ from open_webui.models.files import FileMetadataResponse
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, String, Text, JSON, func
from sqlalchemy import BigInteger, Column, String, Text, JSON, func, ForeignKey
log = logging.getLogger(__name__)
@@ -35,7 +35,6 @@ class Group(Base):
meta = Column(JSON, nullable=True)
permissions = Column(JSON, nullable=True)
user_ids = Column(JSON, nullable=True)
created_at = Column(BigInteger)
updated_at = Column(BigInteger)
@@ -53,12 +52,33 @@ class GroupModel(BaseModel):
meta: Optional[dict] = None
permissions: Optional[dict] = None
user_ids: list[str] = []
created_at: int # timestamp in epoch
updated_at: int # timestamp in epoch
class GroupMember(Base):
__tablename__ = "group_member"
id = Column(Text, unique=True, primary_key=True)
group_id = Column(
Text,
ForeignKey("group.id", ondelete="CASCADE"),
nullable=False,
)
user_id = Column(Text, nullable=False)
created_at = Column(BigInteger, nullable=True)
updated_at = Column(BigInteger, nullable=True)
class GroupMemberModel(BaseModel):
id: str
group_id: str
user_id: str
created_at: Optional[int] = None # timestamp in epoch
updated_at: Optional[int] = None # timestamp in epoch
####################
# Forms
####################
@@ -72,7 +92,7 @@ class GroupResponse(BaseModel):
permissions: Optional[dict] = None
data: Optional[dict] = None
meta: Optional[dict] = None
user_ids: list[str] = []
member_count: Optional[int] = None
created_at: int # timestamp in epoch
updated_at: int # timestamp in epoch
@@ -87,7 +107,7 @@ class UserIdsForm(BaseModel):
user_ids: Optional[list[str]] = None
class GroupUpdateForm(GroupForm, UserIdsForm):
class GroupUpdateForm(GroupForm):
pass
@@ -131,12 +151,8 @@ class GroupTable:
return [
GroupModel.model_validate(group)
for group in db.query(Group)
.filter(
func.json_array_length(Group.user_ids) > 0
) # Ensure array exists
.filter(
Group.user_ids.cast(String).like(f'%"{user_id}"%')
) # String-based check
.join(GroupMember, GroupMember.group_id == Group.id)
.filter(GroupMember.user_id == user_id)
.order_by(Group.updated_at.desc())
.all()
]
@@ -149,12 +165,46 @@ class GroupTable:
except Exception:
return None
def get_group_user_ids_by_id(self, id: str) -> Optional[str]:
group = self.get_group_by_id(id)
if group:
return group.user_ids
else:
return None
def get_group_user_ids_by_id(self, id: str) -> Optional[list[str]]:
with get_db() as db:
members = (
db.query(GroupMember.user_id).filter(GroupMember.group_id == id).all()
)
if not members:
return None
return [m[0] for m in members]
def set_group_user_ids_by_id(self, group_id: str, user_ids: list[str]) -> None:
with get_db() as db:
# Delete existing members
db.query(GroupMember).filter(GroupMember.group_id == group_id).delete()
# Insert new members
now = int(time.time())
new_members = [
GroupMember(
id=str(uuid.uuid4()),
group_id=group_id,
user_id=user_id,
created_at=now,
updated_at=now,
)
for user_id in user_ids
]
db.add_all(new_members)
db.commit()
def get_group_member_count_by_id(self, id: str) -> int:
with get_db() as db:
count = (
db.query(func.count(GroupMember.user_id))
.filter(GroupMember.group_id == id)
.scalar()
)
return count if count else 0
def update_group_by_id(
self, id: str, form_data: GroupUpdateForm, overwrite: bool = False
@@ -195,20 +245,29 @@ class GroupTable:
def remove_user_from_all_groups(self, user_id: str) -> bool:
with get_db() as db:
try:
groups = self.get_groups_by_member_id(user_id)
# Find all groups the user belongs to
groups = (
db.query(Group)
.join(GroupMember, GroupMember.group_id == Group.id)
.filter(GroupMember.user_id == user_id)
.all()
)
# Remove the user from each group
for group in groups:
group.user_ids.remove(user_id)
db.query(Group).filter_by(id=group.id).update(
{
"user_ids": group.user_ids,
"updated_at": int(time.time()),
}
)
db.commit()
db.query(GroupMember).filter(
GroupMember.group_id == group.id, GroupMember.user_id == user_id
).delete()
db.query(Group).filter_by(id=group.id).update(
{"updated_at": int(time.time())}
)
db.commit()
return True
except Exception:
db.rollback()
return False
def create_groups_by_group_names(
@@ -246,37 +305,61 @@ class GroupTable:
def sync_groups_by_group_names(self, user_id: str, group_names: list[str]) -> bool:
with get_db() as db:
try:
groups = db.query(Group).filter(Group.name.in_(group_names)).all()
group_ids = [group.id for group in groups]
now = int(time.time())
# Remove user from groups not in the new list
existing_groups = self.get_groups_by_member_id(user_id)
# 1. Groups that SHOULD contain the user
target_groups = (
db.query(Group).filter(Group.name.in_(group_names)).all()
)
target_group_ids = {g.id for g in target_groups}
for group in existing_groups:
if group.id not in group_ids:
group.user_ids.remove(user_id)
db.query(Group).filter_by(id=group.id).update(
{
"user_ids": group.user_ids,
"updated_at": int(time.time()),
}
# 2. Groups the user is CURRENTLY in
existing_group_ids = {
g.id
for g in db.query(Group)
.join(GroupMember, GroupMember.group_id == Group.id)
.filter(GroupMember.user_id == user_id)
.all()
}
# 3. Determine adds + removals
groups_to_add = target_group_ids - existing_group_ids
groups_to_remove = existing_group_ids - target_group_ids
# 4. Remove in one bulk delete
if groups_to_remove:
db.query(GroupMember).filter(
GroupMember.user_id == user_id,
GroupMember.group_id.in_(groups_to_remove),
).delete(synchronize_session=False)
db.query(Group).filter(Group.id.in_(groups_to_remove)).update(
{"updated_at": now}, synchronize_session=False
)
# 5. Bulk insert missing memberships
for group_id in groups_to_add:
db.add(
GroupMember(
id=str(uuid.uuid4()),
group_id=group_id,
user_id=user_id,
created_at=now,
updated_at=now,
)
)
# Add user to new groups
for group in groups:
if user_id not in group.user_ids:
group.user_ids.append(user_id)
db.query(Group).filter_by(id=group.id).update(
{
"user_ids": group.user_ids,
"updated_at": int(time.time()),
}
)
if groups_to_add:
db.query(Group).filter(Group.id.in_(groups_to_add)).update(
{"updated_at": now}, synchronize_session=False
)
db.commit()
return True
except Exception as e:
log.exception(e)
db.rollback()
return False
def add_users_to_group(
@@ -288,21 +371,31 @@ class GroupTable:
if not group:
return None
group_user_ids = group.user_ids
if not group_user_ids or not isinstance(group_user_ids, list):
group_user_ids = []
now = int(time.time())
group_user_ids = list(set(group_user_ids)) # Deduplicate
for user_id in user_ids or []:
try:
db.add(
GroupMember(
id=str(uuid.uuid4()),
group_id=id,
user_id=user_id,
created_at=now,
updated_at=now,
)
)
db.flush() # Detect unique constraint violation early
except Exception:
db.rollback() # Clear failed INSERT
db.begin() # Start a new transaction
continue # Duplicate → ignore
for user_id in user_ids:
if user_id not in group_user_ids:
group_user_ids.append(user_id)
group.user_ids = group_user_ids
group.updated_at = int(time.time())
group.updated_at = now
db.commit()
db.refresh(group)
return GroupModel.model_validate(group)
except Exception as e:
log.exception(e)
return None
@@ -316,23 +409,22 @@ class GroupTable:
if not group:
return None
group_user_ids = group.user_ids
if not group_user_ids or not isinstance(group_user_ids, list):
if not user_ids:
return GroupModel.model_validate(group)
group_user_ids = list(set(group_user_ids)) # Deduplicate
# Remove each user from group_member
for user_id in user_ids:
if user_id in group_user_ids:
group_user_ids.remove(user_id)
db.query(GroupMember).filter(
GroupMember.group_id == id, GroupMember.user_id == user_id
).delete()
group.user_ids = group_user_ids
# Update group timestamp
group.updated_at = int(time.time())
db.commit()
db.refresh(group)
return GroupModel.model_validate(group)
except Exception as e:
log.exception(e)
return None

View File

@@ -6,7 +6,7 @@ from open_webui.internal.db import Base, JSONField, get_db
from open_webui.env import DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL
from open_webui.models.chats import Chats
from open_webui.models.groups import Groups
from open_webui.models.groups import Groups, GroupMember
from open_webui.utils.misc import throttle
@@ -95,8 +95,12 @@ class UpdateProfileForm(BaseModel):
date_of_birth: Optional[datetime.date] = None
class UserGroupIdsModel(UserModel):
group_ids: list[str] = []
class UserListResponse(BaseModel):
users: list[UserModel]
users: list[UserGroupIdsModel]
total: int
@@ -222,7 +226,10 @@ class UsersTable:
limit: Optional[int] = None,
) -> dict:
with get_db() as db:
query = db.query(User)
# Join GroupMember so we can order by group_id when requested
query = db.query(User).outerjoin(
GroupMember, GroupMember.user_id == User.id
)
if filter:
query_key = filter.get("query")
@@ -237,7 +244,16 @@ class UsersTable:
order_by = filter.get("order_by")
direction = filter.get("direction")
if order_by == "name":
if order_by and order_by.startswith("group_id:"):
group_id = order_by.split(":", 1)[1]
if direction == "asc":
query = query.order_by((GroupMember.group_id == group_id).asc())
else:
query = query.order_by(
(GroupMember.group_id == group_id).desc()
)
elif order_by == "name":
if direction == "asc":
query = query.order_by(User.name.asc())
else:

View File

@@ -33,9 +33,18 @@ router = APIRouter()
@router.get("/", response_model=list[GroupResponse])
async def get_groups(user=Depends(get_verified_user)):
if user.role == "admin":
return Groups.get_groups()
groups = Groups.get_groups()
else:
return Groups.get_groups_by_member_id(user.id)
groups = Groups.get_groups_by_member_id(user.id)
return [
GroupResponse(
**group.model_dump(),
member_count=Groups.get_group_member_count_by_id(group.id),
)
for group in groups
if group
]
############################
@@ -48,7 +57,10 @@ async def create_new_group(form_data: GroupForm, user=Depends(get_admin_user)):
try:
group = Groups.insert_new_group(user.id, form_data)
if group:
return group
return GroupResponse(
**group.model_dump(),
member_count=Groups.get_group_member_count_by_id(group.id),
)
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
@@ -71,7 +83,10 @@ async def create_new_group(form_data: GroupForm, user=Depends(get_admin_user)):
async def get_group_by_id(id: str, user=Depends(get_admin_user)):
group = Groups.get_group_by_id(id)
if group:
return group
return GroupResponse(
**group.model_dump(),
member_count=Groups.get_group_member_count_by_id(group.id),
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
@@ -89,12 +104,12 @@ async def update_group_by_id(
id: str, form_data: GroupUpdateForm, user=Depends(get_admin_user)
):
try:
if form_data.user_ids:
form_data.user_ids = Users.get_valid_user_ids(form_data.user_ids)
group = Groups.update_group_by_id(id, form_data)
if group:
return group
return GroupResponse(
**group.model_dump(),
member_count=Groups.get_group_member_count_by_id(group.id),
)
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
@@ -123,7 +138,10 @@ async def add_user_to_group(
group = Groups.add_users_to_group(id, form_data.user_ids)
if group:
return group
return GroupResponse(
**group.model_dump(),
member_count=Groups.get_group_member_count_by_id(group.id),
)
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
@@ -144,7 +162,10 @@ async def remove_users_from_group(
try:
group = Groups.remove_users_from_group(id, form_data.user_ids)
if group:
return group
return GroupResponse(
**group.model_dump(),
member_count=Groups.get_group_member_count_by_id(group.id),
)
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,

View File

@@ -349,8 +349,10 @@ def user_to_scim(user: UserModel, request: Request) -> SCIMUser:
def group_to_scim(group: GroupModel, request: Request) -> SCIMGroup:
"""Convert internal Group model to SCIM Group"""
member_ids = Groups.get_group_user_ids_by_id(group.id)
members = []
for user_id in group.user_ids:
for user_id in member_ids:
user = Users.get_user_by_id(user_id)
if user:
members.append(
@@ -796,9 +798,11 @@ async def create_group(
update_form = GroupUpdateForm(
name=new_group.name,
description=new_group.description,
user_ids=member_ids,
)
Groups.update_group_by_id(new_group.id, update_form)
Groups.set_group_user_ids_by_id(new_group.id, member_ids)
new_group = Groups.get_group_by_id(new_group.id)
return group_to_scim(new_group, request)
@@ -830,7 +834,7 @@ async def update_group(
# Handle members if provided
if group_data.members is not None:
member_ids = [member.value for member in group_data.members]
update_form.user_ids = member_ids
Groups.set_group_user_ids_by_id(group_id, member_ids)
# Update group
updated_group = Groups.update_group_by_id(group_id, update_form)
@@ -863,7 +867,6 @@ async def patch_group(
update_form = GroupUpdateForm(
name=group.name,
description=group.description,
user_ids=group.user_ids.copy() if group.user_ids else [],
)
for operation in patch_data.Operations:
@@ -876,21 +879,22 @@ async def patch_group(
update_form.name = value
elif path == "members":
# Replace all members
update_form.user_ids = [member["value"] for member in value]
Groups.set_group_user_ids_by_id(
group_id, [member["value"] for member in value]
)
elif op == "add":
if path == "members":
# Add members
if isinstance(value, list):
for member in value:
if isinstance(member, dict) and "value" in member:
if member["value"] not in update_form.user_ids:
update_form.user_ids.append(member["value"])
Groups.add_users_to_group(group_id, [member["value"]])
elif op == "remove":
if path and path.startswith("members[value eq"):
# Remove specific member
member_id = path.split('"')[1]
if member_id in update_form.user_ids:
update_form.user_ids.remove(member_id)
Groups.remove_users_from_group(group_id, [member_id])
# Update group
updated_group = Groups.update_group_by_id(group_id, update_form)

View File

@@ -16,6 +16,7 @@ from open_webui.models.groups import Groups
from open_webui.models.chats import Chats
from open_webui.models.users import (
UserModel,
UserGroupIdsModel,
UserListResponse,
UserInfoListResponse,
UserIdNameListResponse,
@@ -91,7 +92,25 @@ async def get_users(
if direction:
filter["direction"] = direction
return Users.get_users(filter=filter, skip=skip, limit=limit)
result = Users.get_users(filter=filter, skip=skip, limit=limit)
users = result["users"]
total = result["total"]
return {
"users": [
UserGroupIdsModel(
**{
**user.model_dump(),
"group_ids": [
group.id for group in Groups.get_groups_by_member_id(user.id)
],
}
)
for user in users
],
"total": total,
}
@router.get("/all", response_model=UserInfoListResponse)

View File

@@ -1137,22 +1137,21 @@ class OAuthManager:
f"Removing user from group {group_model.name} as it is no longer in their oauth groups"
)
user_ids = group_model.user_ids
user_ids = [i for i in user_ids if i != user.id]
Groups.remove_users_from_group(group_model.id, [user.id])
# In case a group is created, but perms are never assigned to the group by hitting "save"
group_permissions = group_model.permissions
if not group_permissions:
group_permissions = default_permissions
update_form = GroupUpdateForm(
name=group_model.name,
description=group_model.description,
permissions=group_permissions,
user_ids=user_ids,
)
Groups.update_group_by_id(
id=group_model.id, form_data=update_form, overwrite=False
id=group_model.id,
form_data=GroupUpdateForm(
name=group_model.name,
description=group_model.description,
permissions=group_permissions,
),
overwrite=False,
)
# Add user to new groups
@@ -1168,22 +1167,21 @@ class OAuthManager:
f"Adding user to group {group_model.name} as it was found in their oauth groups"
)
user_ids = group_model.user_ids
user_ids.append(user.id)
Groups.add_users_to_group(group_model.id, [user.id])
# In case a group is created, but perms are never assigned to the group by hitting "save"
group_permissions = group_model.permissions
if not group_permissions:
group_permissions = default_permissions
update_form = GroupUpdateForm(
name=group_model.name,
description=group_model.description,
permissions=group_permissions,
user_ids=user_ids,
)
Groups.update_group_by_id(
id=group_model.id, form_data=update_form, overwrite=False
id=group_model.id,
form_data=GroupUpdateForm(
name=group_model.name,
description=group_model.description,
permissions=group_permissions,
),
overwrite=False,
)
async def _process_picture_url(