feat: user list in channels
This commit is contained in:
@@ -99,7 +99,16 @@ class UserGroupIdsModel(UserModel):
|
||||
group_ids: list[str] = []
|
||||
|
||||
|
||||
class UserModelResponse(UserModel):
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class UserListResponse(BaseModel):
|
||||
users: list[UserModelResponse]
|
||||
total: int
|
||||
|
||||
|
||||
class UserGroupIdsListResponse(BaseModel):
|
||||
users: list[UserGroupIdsModel]
|
||||
total: int
|
||||
|
||||
@@ -239,6 +248,31 @@ class UsersTable:
|
||||
)
|
||||
)
|
||||
|
||||
user_ids = filter.get("user_ids")
|
||||
if user_ids:
|
||||
query = query.filter(User.id.in_(user_ids))
|
||||
|
||||
group_ids = filter.get("group_ids")
|
||||
if group_ids:
|
||||
query = query.filter(
|
||||
exists(
|
||||
select(GroupMember.id).where(
|
||||
GroupMember.user_id == User.id,
|
||||
GroupMember.group_id.in_(group_ids),
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
roles = filter.get("roles")
|
||||
if roles:
|
||||
include_roles = [role for role in roles if not role.startswith("!")]
|
||||
exclude_roles = [role[1:] for role in roles if role.startswith("!")]
|
||||
|
||||
if include_roles:
|
||||
query = query.filter(User.role.in_(include_roles))
|
||||
if exclude_roles:
|
||||
query = query.filter(~User.role.in_(exclude_roles))
|
||||
|
||||
order_by = filter.get("order_by")
|
||||
direction = filter.get("direction")
|
||||
|
||||
|
||||
@@ -7,8 +7,17 @@ from fastapi import APIRouter, Depends, HTTPException, Request, status, Backgrou
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
from open_webui.socket.main import sio, get_user_ids_from_room
|
||||
from open_webui.models.users import Users, UserNameResponse
|
||||
from open_webui.socket.main import (
|
||||
sio,
|
||||
get_user_ids_from_room,
|
||||
get_active_status_by_user_id,
|
||||
)
|
||||
from open_webui.models.users import (
|
||||
UserListResponse,
|
||||
UserModelResponse,
|
||||
Users,
|
||||
UserNameResponse,
|
||||
)
|
||||
|
||||
from open_webui.models.groups import Groups
|
||||
from open_webui.models.channels import (
|
||||
@@ -38,7 +47,11 @@ from open_webui.utils.chat import generate_chat_completion
|
||||
|
||||
|
||||
from open_webui.utils.auth import get_admin_user, get_verified_user
|
||||
from open_webui.utils.access_control import has_access, get_users_with_access
|
||||
from open_webui.utils.access_control import (
|
||||
has_access,
|
||||
get_users_with_access,
|
||||
get_permitted_group_and_user_ids,
|
||||
)
|
||||
from open_webui.utils.webhook import post_webhook
|
||||
from open_webui.utils.channels import extract_mentions, replace_mentions
|
||||
|
||||
@@ -116,6 +129,64 @@ async def get_channel_by_id(id: str, user=Depends(get_verified_user)):
|
||||
)
|
||||
|
||||
|
||||
PAGE_ITEM_COUNT = 30
|
||||
|
||||
|
||||
@router.get("/{id}/users", response_model=UserListResponse)
|
||||
async def get_channel_users_by_id(
|
||||
id: str,
|
||||
query: Optional[str] = None,
|
||||
order_by: Optional[str] = None,
|
||||
direction: Optional[str] = None,
|
||||
page: Optional[int] = 1,
|
||||
user=Depends(get_verified_user),
|
||||
):
|
||||
|
||||
channel = Channels.get_channel_by_id(id)
|
||||
if not channel:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
|
||||
)
|
||||
|
||||
limit = PAGE_ITEM_COUNT
|
||||
|
||||
page = max(1, page)
|
||||
skip = (page - 1) * limit
|
||||
|
||||
filter = {
|
||||
"roles": ["!pending"],
|
||||
}
|
||||
|
||||
if query:
|
||||
filter["query"] = query
|
||||
if order_by:
|
||||
filter["order_by"] = order_by
|
||||
if direction:
|
||||
filter["direction"] = direction
|
||||
|
||||
permitted_ids = get_permitted_group_and_user_ids("read", channel.access_control)
|
||||
if permitted_ids:
|
||||
if permitted_ids.get("user_ids"):
|
||||
filter["user_ids"] = permitted_ids.get("user_ids")
|
||||
if permitted_ids.get("group_ids"):
|
||||
filter["group_ids"] = permitted_ids.get("group_ids")
|
||||
|
||||
result = Users.get_users(filter=filter, skip=skip, limit=limit)
|
||||
|
||||
users = result["users"]
|
||||
total = result["total"]
|
||||
|
||||
return {
|
||||
"users": [
|
||||
UserModelResponse(
|
||||
**user.model_dump(), is_active=get_active_status_by_user_id(user.id)
|
||||
)
|
||||
for user in users
|
||||
],
|
||||
"total": total,
|
||||
}
|
||||
|
||||
|
||||
############################
|
||||
# UpdateChannelById
|
||||
############################
|
||||
|
||||
@@ -17,7 +17,7 @@ from open_webui.models.chats import Chats
|
||||
from open_webui.models.users import (
|
||||
UserModel,
|
||||
UserGroupIdsModel,
|
||||
UserListResponse,
|
||||
UserGroupIdsListResponse,
|
||||
UserInfoListResponse,
|
||||
UserIdNameListResponse,
|
||||
UserRoleUpdateForm,
|
||||
@@ -76,7 +76,7 @@ async def get_active_users(
|
||||
PAGE_ITEM_COUNT = 30
|
||||
|
||||
|
||||
@router.get("/", response_model=UserListResponse)
|
||||
@router.get("/", response_model=UserGroupIdsListResponse)
|
||||
async def get_users(
|
||||
query: Optional[str] = None,
|
||||
order_by: Optional[str] = None,
|
||||
|
||||
@@ -105,6 +105,22 @@ def has_permission(
|
||||
return get_permission(default_permissions, permission_hierarchy)
|
||||
|
||||
|
||||
def get_permitted_group_and_user_ids(
|
||||
type: str = "write", access_control: Optional[dict] = None
|
||||
) -> Union[Dict[str, List[str]], None]:
|
||||
if access_control is None:
|
||||
return None
|
||||
|
||||
permission_access = access_control.get(type, {})
|
||||
permitted_group_ids = permission_access.get("group_ids", [])
|
||||
permitted_user_ids = permission_access.get("user_ids", [])
|
||||
|
||||
return {
|
||||
"group_ids": permitted_group_ids,
|
||||
"user_ids": permitted_user_ids,
|
||||
}
|
||||
|
||||
|
||||
def has_access(
|
||||
user_id: str,
|
||||
type: str = "write",
|
||||
@@ -122,9 +138,12 @@ def has_access(
|
||||
user_groups = Groups.get_groups_by_member_id(user_id)
|
||||
user_group_ids = {group.id for group in user_groups}
|
||||
|
||||
permission_access = access_control.get(type, {})
|
||||
permitted_group_ids = permission_access.get("group_ids", [])
|
||||
permitted_user_ids = permission_access.get("user_ids", [])
|
||||
permitted_ids = get_permitted_group_and_user_ids(type, access_control)
|
||||
if permitted_ids is None:
|
||||
return False
|
||||
|
||||
permitted_group_ids = permitted_ids.get("group_ids", [])
|
||||
permitted_user_ids = permitted_ids.get("user_ids", [])
|
||||
|
||||
return user_id in permitted_user_ids or any(
|
||||
group_id in permitted_group_ids for group_id in user_group_ids
|
||||
@@ -136,18 +155,20 @@ def get_users_with_access(
|
||||
type: str = "write", access_control: Optional[dict] = None
|
||||
) -> list[UserModel]:
|
||||
if access_control is None:
|
||||
result = Users.get_users()
|
||||
result = Users.get_users(filter={"roles": ["!pending"]})
|
||||
return result.get("users", [])
|
||||
|
||||
permission_access = access_control.get(type, {})
|
||||
permitted_group_ids = permission_access.get("group_ids", [])
|
||||
permitted_user_ids = permission_access.get("user_ids", [])
|
||||
permitted_ids = get_permitted_group_and_user_ids(type, access_control)
|
||||
if permitted_ids is None:
|
||||
return []
|
||||
|
||||
permitted_group_ids = permitted_ids.get("group_ids", [])
|
||||
permitted_user_ids = permitted_ids.get("user_ids", [])
|
||||
|
||||
user_ids_with_access = set(permitted_user_ids)
|
||||
|
||||
for group_id in permitted_group_ids:
|
||||
group_user_ids = Groups.get_group_user_ids_by_id(group_id)
|
||||
if group_user_ids:
|
||||
user_ids_with_access.update(group_user_ids)
|
||||
group_user_ids_map = Groups.get_group_user_ids_by_ids(permitted_group_ids)
|
||||
for user_ids in group_user_ids_map.values():
|
||||
user_ids_with_access.update(user_ids)
|
||||
|
||||
return Users.get_users_by_user_ids(list(user_ids_with_access))
|
||||
|
||||
Reference in New Issue
Block a user