diff --git a/backend/open_webui/models/groups.py b/backend/open_webui/models/groups.py index f2f364c14..0859c053a 100644 --- a/backend/open_webui/models/groups.py +++ b/backend/open_webui/models/groups.py @@ -22,6 +22,7 @@ from sqlalchemy import ( ForeignKey, cast, or_, + select, ) @@ -99,6 +100,16 @@ class GroupResponse(GroupModel): member_count: Optional[int] = None +class GroupInfoResponse(BaseModel): + id: str + user_id: str + name: str + description: str + member_count: Optional[int] = None + created_at: int + updated_at: int + + class GroupForm(BaseModel): name: str description: str @@ -181,14 +192,12 @@ class GroupTable: if member_id: # Also include member-only groups where user is a member - member_groups_subq = ( - db.query(GroupMember.group_id) - .filter(GroupMember.user_id == member_id) - .subquery() + member_groups_select = select(GroupMember.group_id).where( + GroupMember.user_id == member_id ) members_only_and_is_member = and_( json_share_lower == "members", - Group.id.in_(member_groups_subq), + Group.id.in_(member_groups_select), ) query = query.filter( or_(anyone_can_share, members_only_and_is_member) diff --git a/backend/open_webui/routers/channels.py b/backend/open_webui/routers/channels.py index 2c94ddd91..4713e369a 100644 --- a/backend/open_webui/routers/channels.py +++ b/backend/open_webui/routers/channels.py @@ -36,6 +36,7 @@ from open_webui.models.channels import ( ChannelWebhookModel, ChannelWebhookForm, ) +from open_webui.models.access_grants import AccessGrants, has_public_read_access_grant from open_webui.models.messages import ( Messages, MessageModel, @@ -60,12 +61,7 @@ from open_webui.utils.chat import generate_chat_completion from open_webui.utils.auth import get_admin_user, get_verified_user -from open_webui.utils.access_control import ( - has_access, - get_users_with_access, - get_permitted_group_and_user_ids, - has_permission, -) +from open_webui.utils.access_control import has_permission from open_webui.utils.webhook import post_webhook from open_webui.utils.channels import extract_mentions, replace_mentions from open_webui.internal.db import get_session @@ -76,6 +72,66 @@ log = logging.getLogger(__name__) router = APIRouter() +def channel_has_access( + user_id: str, + channel: ChannelModel, + permission: str = "read", + strict: bool = True, + db: Optional[Session] = None, +) -> bool: + if AccessGrants.has_access( + user_id=user_id, + resource_type="channel", + resource_id=channel.id, + permission=permission, + db=db, + ): + return True + + if ( + not strict + and permission == "write" + and has_public_read_access_grant(channel.access_grants) + ): + return True + + return False + + +def get_channel_users_with_access( + channel: ChannelModel, permission: str = "read", db: Optional[Session] = None +): + return AccessGrants.get_users_with_access( + resource_type="channel", + resource_id=channel.id, + permission=permission, + db=db, + ) + + +def get_channel_permitted_group_and_user_ids( + channel: ChannelModel, permission: str = "read" +) -> Optional[dict[str, list[str]]]: + if permission == "read" and has_public_read_access_grant(channel.access_grants): + return None + + user_ids = [] + group_ids = [] + + for grant in channel.access_grants: + if grant.permission != permission: + continue + if grant.principal_type == "group": + group_ids.append(grant.principal_id) + elif grant.principal_type == "user" and grant.principal_id != "*": + user_ids.append(grant.principal_id) + + return { + "user_ids": list(dict.fromkeys(user_ids)), + "group_ids": list(dict.fromkeys(group_ids)), + } + + ############################ # Channels Enabled Dependency ############################ @@ -418,22 +474,22 @@ async def get_channel_by_id( } ) else: - if user.role != "admin" and not has_access( - user.id, type="read", access_control=channel.access_control, db=db + if user.role != "admin" and not channel_has_access( + user.id, channel, permission="read", db=db ): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() ) - write_access = has_access( + write_access = channel_has_access( user.id, - type="write", - access_control=channel.access_control, + channel, + permission="write", strict=False, db=db, ) - user_count = len(get_users_with_access("read", channel.access_control)) + user_count = len(get_channel_users_with_access(channel, "read", db=db)) channel_member = Channels.get_member_by_channel_and_user_id( channel.id, user.id, db=db @@ -527,8 +583,8 @@ async def get_channel_members_by_id( filter["channel_id"] = channel.id else: filter["roles"] = ["!pending"] - permitted_ids = get_permitted_group_and_user_ids( - "read", channel.access_control + permitted_ids = get_channel_permitted_group_and_user_ids( + channel, permission="read" ) if permitted_ids: filter["user_ids"] = permitted_ids.get("user_ids") @@ -811,8 +867,8 @@ async def get_channel_messages( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() ) else: - if user.role != "admin" and not has_access( - user.id, type="read", access_control=channel.access_control, db=db + if user.role != "admin" and not channel_has_access( + user.id, channel, permission="read", db=db ): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() @@ -888,8 +944,8 @@ async def get_pinned_channel_messages( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() ) else: - if user.role != "admin" and not has_access( - user.id, type="read", access_control=channel.access_control, db=db + if user.role != "admin" and not channel_has_access( + user.id, channel, permission="read", db=db ): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() @@ -946,7 +1002,7 @@ async def get_pinned_channel_messages( async def send_notification( name, webui_url, channel, message, active_user_ids, db=None ): - users = get_users_with_access("read", channel.access_control) + users = get_channel_users_with_access(channel, "read", db=db) for user in users: if (user.id not in active_user_ids) and Channels.is_user_channel_member( @@ -1173,10 +1229,10 @@ async def new_message_handler( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() ) else: - if user.role != "admin" and not has_access( + if user.role != "admin" and not channel_has_access( user.id, - type="write", - access_control=channel.access_control, + channel, + permission="write", strict=False, db=db, ): @@ -1318,8 +1374,8 @@ async def get_channel_message( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() ) else: - if user.role != "admin" and not has_access( - user.id, type="read", access_control=channel.access_control, db=db + if user.role != "admin" and not channel_has_access( + user.id, channel, permission="read", db=db ): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() @@ -1372,8 +1428,8 @@ async def get_channel_message_data( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() ) else: - if user.role != "admin" and not has_access( - user.id, type="read", access_control=channel.access_control, db=db + if user.role != "admin" and not channel_has_access( + user.id, channel, permission="read", db=db ): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() @@ -1426,8 +1482,8 @@ async def pin_channel_message( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() ) else: - if user.role != "admin" and not has_access( - user.id, type="read", access_control=channel.access_control, db=db + if user.role != "admin" and not channel_has_access( + user.id, channel, permission="read", db=db ): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() @@ -1492,8 +1548,8 @@ async def get_channel_thread_messages( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() ) else: - if user.role != "admin" and not has_access( - user.id, type="read", access_control=channel.access_control, db=db + if user.role != "admin" and not channel_has_access( + user.id, channel, permission="read", db=db ): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() @@ -1577,8 +1633,8 @@ async def update_message_by_id( if ( user.role != "admin" and message.user_id != user.id - and not has_access( - user.id, type="read", access_control=channel.access_control, db=db + and not channel_has_access( + user.id, channel, permission="read", db=db ) ): raise HTTPException( @@ -1644,10 +1700,10 @@ async def add_reaction_to_message( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() ) else: - if user.role != "admin" and not has_access( + if user.role != "admin" and not channel_has_access( user.id, - type="write", - access_control=channel.access_control, + channel, + permission="write", strict=False, db=db, ): @@ -1723,10 +1779,10 @@ async def remove_reaction_by_id_and_user_id_and_name( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() ) else: - if user.role != "admin" and not has_access( + if user.role != "admin" and not channel_has_access( user.id, - type="write", - access_control=channel.access_control, + channel, + permission="write", strict=False, db=db, ): @@ -1818,10 +1874,10 @@ async def delete_message_by_id( if ( user.role != "admin" and message.user_id != user.id - and not has_access( + and not channel_has_access( user.id, - type="write", - access_control=channel.access_control, + channel, + permission="write", strict=False, db=db, ) diff --git a/backend/open_webui/routers/files.py b/backend/open_webui/routers/files.py index 8b17dc406..4a12db5cd 100644 --- a/backend/open_webui/routers/files.py +++ b/backend/open_webui/routers/files.py @@ -38,6 +38,7 @@ from open_webui.models.files import ( from open_webui.models.chats import Chats from open_webui.models.knowledge import Knowledges from open_webui.models.groups import Groups +from open_webui.models.access_grants import AccessGrants from open_webui.routers.retrieval import ProcessFileForm, process_file @@ -47,7 +48,6 @@ from open_webui.storage.provider import Storage from open_webui.utils.auth import get_admin_user, get_verified_user -from open_webui.utils.access_control import has_access from open_webui.utils.misc import strict_match_mime_type from pydantic import BaseModel @@ -82,8 +82,13 @@ def has_access_to_file( group.id for group in Groups.get_groups_by_member_id(user.id, db=db) } for knowledge_base in knowledge_bases: - if knowledge_base.user_id == user.id or has_access( - user.id, access_type, knowledge_base.access_control, user_group_ids, db=db + if knowledge_base.user_id == user.id or AccessGrants.has_access( + user_id=user.id, + resource_type="knowledge", + resource_id=knowledge_base.id, + permission=access_type, + user_group_ids=user_group_ids, + db=db, ): return True diff --git a/backend/open_webui/routers/groups.py b/backend/open_webui/routers/groups.py index cc0cb8f5a..a46e05473 100755 --- a/backend/open_webui/routers/groups.py +++ b/backend/open_webui/routers/groups.py @@ -7,6 +7,7 @@ from open_webui.models.users import Users, UserInfoResponse from open_webui.models.groups import ( Groups, GroupForm, + GroupInfoResponse, GroupUpdateForm, GroupResponse, UserIdsForm, @@ -104,6 +105,23 @@ async def get_group_by_id( ) +@router.get("/id/{id}/info", response_model=Optional[GroupInfoResponse]) +async def get_group_info_by_id( + id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) +): + group = Groups.get_group_by_id(id, db=db) + if group: + return GroupInfoResponse( + **group.model_dump(), + member_count=Groups.get_group_member_count_by_id(group.id, db=db), + ) + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + ############################ # ExportGroupById ############################ diff --git a/backend/open_webui/routers/knowledge.py b/backend/open_webui/routers/knowledge.py index 96136d689..8350fda03 100644 --- a/backend/open_webui/routers/knowledge.py +++ b/backend/open_webui/routers/knowledge.py @@ -29,7 +29,8 @@ from open_webui.storage.provider import Storage from open_webui.constants import ERROR_MESSAGES from open_webui.utils.auth import get_verified_user, get_admin_user -from open_webui.utils.access_control import has_access, has_permission +from open_webui.utils.access_control import has_permission +from open_webui.models.access_grants import AccessGrants, has_public_read_access_grant from open_webui.config import BYPASS_ADMIN_ACCESS_CONTROL @@ -133,8 +134,12 @@ async def get_knowledge_bases( write_access=( user.id == knowledge_base.user_id or (user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL) - or has_access( - user.id, "write", knowledge_base.access_control, db=db + or AccessGrants.has_access( + user_id=user.id, + resource_type="knowledge", + resource_id=knowledge_base.id, + permission="write", + db=db, ) ), ) @@ -180,8 +185,12 @@ async def search_knowledge_bases( write_access=( user.id == knowledge_base.user_id or (user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL) - or has_access( - user.id, "write", knowledge_base.access_control, db=db + or AccessGrants.has_access( + user_id=user.id, + resource_type="knowledge", + resource_id=knowledge_base.id, + permission="write", + db=db, ) ), ) @@ -243,14 +252,14 @@ async def create_new_knowledge( # Check if user can share publicly if ( user.role != "admin" - and form_data.access_control == None + and has_public_read_access_grant(form_data.access_grants) and not has_permission( user.id, "sharing.public_knowledge", request.app.state.config.USER_PERMISSIONS, ) ): - form_data.access_control = {} + form_data.access_grants = [] knowledge = Knowledges.insert_new_knowledge(user.id, form_data) @@ -387,7 +396,13 @@ async def get_knowledge_by_id( if ( user.role == "admin" or knowledge.user_id == user.id - or has_access(user.id, "read", knowledge.access_control, db=db) + or AccessGrants.has_access( + user_id=user.id, + resource_type="knowledge", + resource_id=knowledge.id, + permission="read", + db=db, + ) ): return KnowledgeFilesResponse( @@ -395,7 +410,13 @@ async def get_knowledge_by_id( write_access=( user.id == knowledge.user_id or (user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL) - or has_access(user.id, "write", knowledge.access_control, db=db) + or AccessGrants.has_access( + user_id=user.id, + resource_type="knowledge", + resource_id=knowledge.id, + permission="write", + db=db, + ) ), ) else: @@ -435,7 +456,12 @@ async def update_knowledge_by_id( # Is the user the original creator, in a group with write access, or an admin if ( knowledge.user_id != user.id - and not has_access(user.id, "write", knowledge.access_control) + and not AccessGrants.has_access( + user_id=user.id, + resource_type="knowledge", + resource_id=knowledge.id, + permission="write", + ) and user.role != "admin" ): raise HTTPException( @@ -446,14 +472,14 @@ async def update_knowledge_by_id( # Check if user can share publicly if ( user.role != "admin" - and form_data.access_control == None + and has_public_read_access_grant(form_data.access_grants) and not has_permission( user.id, "sharing.public_knowledge", request.app.state.config.USER_PERMISSIONS, ) ): - form_data.access_control = {} + form_data.access_grants = [] knowledge = Knowledges.update_knowledge_by_id(id=id, form_data=form_data) if knowledge: @@ -502,7 +528,13 @@ async def get_knowledge_files_by_id( if not ( user.role == "admin" or knowledge.user_id == user.id - or has_access(user.id, "read", knowledge.access_control, db=db) + or AccessGrants.has_access( + user_id=user.id, + resource_type="knowledge", + resource_id=knowledge.id, + permission="read", + db=db, + ) ): raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -555,7 +587,13 @@ def add_file_to_knowledge_by_id( if ( knowledge.user_id != user.id - and not has_access(user.id, "write", knowledge.access_control, db=db) + and not AccessGrants.has_access( + user_id=user.id, + resource_type="knowledge", + resource_id=knowledge.id, + permission="write", + db=db, + ) and user.role != "admin" ): raise HTTPException( @@ -624,7 +662,13 @@ def update_file_from_knowledge_by_id( if ( knowledge.user_id != user.id - and not has_access(user.id, "write", knowledge.access_control, db=db) + and not AccessGrants.has_access( + user_id=user.id, + resource_type="knowledge", + resource_id=knowledge.id, + permission="write", + db=db, + ) and user.role != "admin" ): @@ -693,7 +737,13 @@ def remove_file_from_knowledge_by_id( if ( knowledge.user_id != user.id - and not has_access(user.id, "write", knowledge.access_control, db=db) + and not AccessGrants.has_access( + user_id=user.id, + resource_type="knowledge", + resource_id=knowledge.id, + permission="write", + db=db, + ) and user.role != "admin" ): raise HTTPException( @@ -770,7 +820,13 @@ async def delete_knowledge_by_id( if ( knowledge.user_id != user.id - and not has_access(user.id, "write", knowledge.access_control, db=db) + and not AccessGrants.has_access( + user_id=user.id, + resource_type="knowledge", + resource_id=knowledge.id, + permission="write", + db=db, + ) and user.role != "admin" ): raise HTTPException( @@ -802,7 +858,7 @@ async def delete_knowledge_by_id( base_model_id=model.base_model_id, meta=model.meta, params=model.params, - access_control=model.access_control, + access_grants=model.access_grants, is_active=model.is_active, ) Models.update_model_by_id(model.id, model_form, db=db) @@ -839,7 +895,13 @@ async def reset_knowledge_by_id( if ( knowledge.user_id != user.id - and not has_access(user.id, "write", knowledge.access_control, db=db) + and not AccessGrants.has_access( + user_id=user.id, + resource_type="knowledge", + resource_id=knowledge.id, + permission="write", + db=db, + ) and user.role != "admin" ): raise HTTPException( @@ -882,7 +944,13 @@ async def add_files_to_knowledge_batch( if ( knowledge.user_id != user.id - and not has_access(user.id, "write", knowledge.access_control, db=db) + and not AccessGrants.has_access( + user_id=user.id, + resource_type="knowledge", + resource_id=knowledge.id, + permission="write", + db=db, + ) and user.role != "admin" ): raise HTTPException( diff --git a/backend/open_webui/routers/models.py b/backend/open_webui/routers/models.py index 85f0fb4f6..b2bccc195 100644 --- a/backend/open_webui/routers/models.py +++ b/backend/open_webui/routers/models.py @@ -15,6 +15,7 @@ from open_webui.models.models import ( ModelAccessResponse, Models, ) +from open_webui.models.access_grants import AccessGrants from pydantic import BaseModel from open_webui.constants import ERROR_MESSAGES @@ -30,7 +31,7 @@ from fastapi.responses import FileResponse, StreamingResponse from open_webui.utils.auth import get_admin_user, get_verified_user -from open_webui.utils.access_control import has_access, has_permission +from open_webui.utils.access_control import has_permission from open_webui.config import BYPASS_ADMIN_ACCESS_CONTROL, STATIC_DIR from open_webui.internal.db import get_session from sqlalchemy.orm import Session @@ -98,7 +99,13 @@ async def get_models( write_access=( (user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL) or user.id == model.user_id - or has_access(user.id, "write", model.access_control, db=db) + or AccessGrants.has_access( + user_id=user.id, + resource_type="model", + resource_id=model.id, + permission="write", + db=db, + ) ), ) for model in result.items @@ -315,14 +322,26 @@ async def get_model_by_id( if ( (user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL) or model.user_id == user.id - or has_access(user.id, "read", model.access_control, db=db) + or AccessGrants.has_access( + user_id=user.id, + resource_type="model", + resource_id=model.id, + permission="read", + db=db, + ) ): return ModelAccessResponse( **model.model_dump(), write_access=( (user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL) or user.id == model.user_id - or has_access(user.id, "write", model.access_control, db=db) + or AccessGrants.has_access( + user_id=user.id, + resource_type="model", + resource_id=model.id, + permission="write", + db=db, + ) ), ) else: @@ -393,7 +412,13 @@ async def toggle_model_by_id( if ( user.role == "admin" or model.user_id == user.id - or has_access(user.id, "write", model.access_control, db=db) + or AccessGrants.has_access( + user_id=user.id, + resource_type="model", + resource_id=model.id, + permission="write", + db=db, + ) ): model = Models.toggle_model_by_id(id, db=db) @@ -436,7 +461,13 @@ async def update_model_by_id( if ( model.user_id != user.id - and not has_access(user.id, "write", model.access_control, db=db) + and not AccessGrants.has_access( + user_id=user.id, + resource_type="model", + resource_id=model.id, + permission="write", + db=db, + ) and user.role != "admin" ): raise HTTPException( @@ -471,7 +502,13 @@ async def delete_model_by_id( if ( user.role != "admin" and model.user_id != user.id - and not has_access(user.id, "write", model.access_control, db=db) + and not AccessGrants.has_access( + user_id=user.id, + resource_type="model", + resource_id=model.id, + permission="write", + db=db, + ) ): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, diff --git a/backend/open_webui/routers/notes.py b/backend/open_webui/routers/notes.py index 56730e2b6..321b06fcd 100644 --- a/backend/open_webui/routers/notes.py +++ b/backend/open_webui/routers/notes.py @@ -27,7 +27,8 @@ from open_webui.constants import ERROR_MESSAGES from open_webui.utils.auth import get_admin_user, get_verified_user -from open_webui.utils.access_control import has_access, has_permission +from open_webui.utils.access_control import has_permission +from open_webui.models.access_grants import AccessGrants, has_public_read_access_grant from open_webui.internal.db import get_session from sqlalchemy.orm import Session @@ -200,8 +201,12 @@ async def get_note_by_id( if user.role != "admin" and ( user.id != note.user_id and ( - not has_access( - user.id, type="read", access_control=note.access_control, db=db + not AccessGrants.has_access( + user_id=user.id, + resource_type="note", + resource_id=note.id, + permission="read", + db=db, ) ) ): @@ -212,13 +217,14 @@ async def get_note_by_id( write_access = ( user.role == "admin" or (user.id == note.user_id) - or has_access( - user.id, - type="write", - access_control=note.access_control, - strict=False, + or AccessGrants.has_access( + user_id=user.id, + resource_type="note", + resource_id=note.id, + permission="write", db=db, ) + or has_public_read_access_grant(note.access_grants) ) return NoteResponse(**note.model_dump(), write_access=write_access) @@ -253,8 +259,12 @@ async def update_note_by_id( if user.role != "admin" and ( user.id != note.user_id - and not has_access( - user.id, type="write", access_control=note.access_control, db=db + and not AccessGrants.has_access( + user_id=user.id, + resource_type="note", + resource_id=note.id, + permission="write", + db=db, ) ): raise HTTPException( @@ -264,7 +274,7 @@ async def update_note_by_id( # Check if user can share publicly if ( user.role != "admin" - and form_data.access_control == None + and has_public_read_access_grant(form_data.access_grants) and not has_permission( user.id, "sharing.public_notes", @@ -272,7 +282,7 @@ async def update_note_by_id( db=db, ) ): - form_data.access_control = {} + form_data.access_grants = [] try: note = Notes.update_note_by_id(id, form_data, db=db) @@ -318,8 +328,12 @@ async def delete_note_by_id( if user.role != "admin" and ( user.id != note.user_id - and not has_access( - user.id, type="write", access_control=note.access_control, db=db + and not AccessGrants.has_access( + user_id=user.id, + resource_type="note", + resource_id=note.id, + permission="write", + db=db, ) ): raise HTTPException( diff --git a/backend/open_webui/routers/ollama.py b/backend/open_webui/routers/ollama.py index cfbdbcac0..16e0d2795 100644 --- a/backend/open_webui/routers/ollama.py +++ b/backend/open_webui/routers/ollama.py @@ -44,6 +44,7 @@ from open_webui.internal.db import get_session from open_webui.models.models import Models +from open_webui.models.access_grants import AccessGrants from open_webui.utils.misc import ( calculate_sha256, ) @@ -53,9 +54,6 @@ from open_webui.utils.payload import ( apply_system_prompt_to_body, ) from open_webui.utils.auth import get_admin_user, get_verified_user -from open_webui.utils.access_control import has_access - - from open_webui.config import ( UPLOAD_DIR, ) @@ -430,8 +428,12 @@ async def get_filtered_models(models, user, db=None): for model in models.get("models", []): model_info = Models.get_model_by_id(model["model"], db=db) if model_info: - if user.id == model_info.user_id or has_access( - user.id, type="read", access_control=model_info.access_control, db=db + if user.id == model_info.user_id or AccessGrants.has_access( + user_id=user.id, + resource_type="model", + resource_id=model_info.id, + permission="read", + db=db, ): filtered_models.append(model) return filtered_models @@ -1259,7 +1261,7 @@ async def generate_chat_completion( bypass_system_prompt: bool = False, ): # NOTE: We intentionally do NOT use Depends(get_session) here. - # Database operations (get_model_by_id, has_access) manage their own short-lived sessions. + # Database operations (get_model_by_id, AccessGrants.has_access) manage their own short-lived sessions. # This prevents holding a connection during the entire LLM call (30-60+ seconds), # which would exhaust the connection pool under concurrent load. if BYPASS_MODEL_ACCESS_CONTROL: @@ -1306,10 +1308,11 @@ async def generate_chat_completion( if not bypass_filter and user.role == "user": if not ( user.id == model_info.user_id - or has_access( - user.id, - type="read", - access_control=model_info.access_control, + or AccessGrants.has_access( + user_id=user.id, + resource_type="model", + resource_id=model_info.id, + permission="read", ) ): raise HTTPException( @@ -1383,7 +1386,7 @@ async def generate_openai_completion( user=Depends(get_verified_user), ): # NOTE: We intentionally do NOT use Depends(get_session) here. - # Database operations (get_model_by_id, has_access) manage their own short-lived sessions. + # Database operations (get_model_by_id, AccessGrants.has_access) manage their own short-lived sessions. # This prevents holding a connection during the entire LLM call (30-60+ seconds), # which would exhaust the connection pool under concurrent load. metadata = form_data.pop("metadata", None) @@ -1418,10 +1421,11 @@ async def generate_openai_completion( if user.role == "user": if not ( user.id == model_info.user_id - or has_access( - user.id, - type="read", - access_control=model_info.access_control, + or AccessGrants.has_access( + user_id=user.id, + resource_type="model", + resource_id=model_info.id, + permission="read", ) ): raise HTTPException( @@ -1468,7 +1472,7 @@ async def generate_openai_chat_completion( user=Depends(get_verified_user), ): # NOTE: We intentionally do NOT use Depends(get_session) here. - # Database operations (get_model_by_id, has_access) manage their own short-lived sessions. + # Database operations (get_model_by_id, AccessGrants.has_access) manage their own short-lived sessions. # This prevents holding a connection during the entire LLM call (30-60+ seconds), # which would exhaust the connection pool under concurrent load. metadata = form_data.pop("metadata", None) @@ -1507,10 +1511,11 @@ async def generate_openai_chat_completion( if user.role == "user": if not ( user.id == model_info.user_id - or has_access( - user.id, - type="read", - access_control=model_info.access_control, + or AccessGrants.has_access( + user_id=user.id, + resource_type="model", + resource_id=model_info.id, + permission="read", ) ): raise HTTPException( @@ -1608,10 +1613,11 @@ async def get_openai_models( for model in models: model_info = Models.get_model_by_id(model["id"], db=db) if model_info: - if user.id == model_info.user_id or has_access( - user.id, - type="read", - access_control=model_info.access_control, + if user.id == model_info.user_id or AccessGrants.has_access( + user_id=user.id, + resource_type="model", + resource_id=model_info.id, + permission="read", db=db, ): filtered_models.append(model) diff --git a/backend/open_webui/routers/openai.py b/backend/open_webui/routers/openai.py index b1d31afb8..cfb55ee80 100644 --- a/backend/open_webui/routers/openai.py +++ b/backend/open_webui/routers/openai.py @@ -24,6 +24,7 @@ from sqlalchemy.orm import Session from open_webui.internal.db import get_session from open_webui.models.models import Models +from open_webui.models.access_grants import AccessGrants from open_webui.config import ( CACHE_DIR, ) @@ -50,7 +51,6 @@ from open_webui.utils.misc import ( ) from open_webui.utils.auth import get_admin_user, get_verified_user -from open_webui.utils.access_control import has_access from open_webui.utils.headers import include_user_info_headers @@ -462,8 +462,12 @@ async def get_filtered_models(models, user, db=None): for model in models.get("data", []): model_info = Models.get_model_by_id(model["id"], db=db) if model_info: - if user.id == model_info.user_id or has_access( - user.id, type="read", access_control=model_info.access_control, db=db + if user.id == model_info.user_id or AccessGrants.has_access( + user_id=user.id, + resource_type="model", + resource_id=model_info.id, + permission="read", + db=db, ): filtered_models.append(model) return filtered_models @@ -876,7 +880,7 @@ async def generate_chat_completion( bypass_system_prompt: bool = False, ): # NOTE: We intentionally do NOT use Depends(get_session) here. - # Database operations (get_model_by_id, has_access) manage their own short-lived sessions. + # Database operations (get_model_by_id, AccessGrants.has_access) manage their own short-lived sessions. # This prevents holding a connection during the entire LLM call (30-60+ seconds), # which would exhaust the connection pool under concurrent load. if BYPASS_MODEL_ACCESS_CONTROL: @@ -914,10 +918,11 @@ async def generate_chat_completion( if not bypass_filter and user.role == "user": if not ( user.id == model_info.user_id - or has_access( - user.id, - type="read", - access_control=model_info.access_control, + or AccessGrants.has_access( + user_id=user.id, + resource_type="model", + resource_id=model_info.id, + permission="read", ) ): raise HTTPException( diff --git a/backend/open_webui/routers/prompts.py b/backend/open_webui/routers/prompts.py index fc24ccaf4..77c2e84fc 100644 --- a/backend/open_webui/routers/prompts.py +++ b/backend/open_webui/routers/prompts.py @@ -9,6 +9,7 @@ from open_webui.models.prompts import ( PromptModel, Prompts, ) +from open_webui.models.access_grants import AccessGrants from open_webui.models.groups import Groups from open_webui.models.prompt_history import ( PromptHistories, @@ -17,7 +18,7 @@ from open_webui.models.prompt_history import ( ) from open_webui.constants import ERROR_MESSAGES from open_webui.utils.auth import get_admin_user, get_verified_user -from open_webui.utils.access_control import has_access, has_permission +from open_webui.utils.access_control import has_permission from open_webui.config import BYPASS_ADMIN_ACCESS_CONTROL from open_webui.internal.db import get_session from sqlalchemy.orm import Session @@ -115,7 +116,13 @@ async def get_prompt_list( write_access=( (user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL) or user.id == prompt.user_id - or has_access(user.id, "write", prompt.access_control, db=db) + or AccessGrants.has_access( + user_id=user.id, + resource_type="prompt", + resource_id=prompt.id, + permission="write", + db=db, + ) ), ) for prompt in result.items @@ -186,14 +193,26 @@ async def get_prompt_by_command( if ( user.role == "admin" or prompt.user_id == user.id - or has_access(user.id, "read", prompt.access_control, db=db) + or AccessGrants.has_access( + user_id=user.id, + resource_type="prompt", + resource_id=prompt.id, + permission="read", + db=db, + ) ): return PromptAccessResponse( **prompt.model_dump(), write_access=( (user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL) or user.id == prompt.user_id - or has_access(user.id, "write", prompt.access_control, db=db) + or AccessGrants.has_access( + user_id=user.id, + resource_type="prompt", + resource_id=prompt.id, + permission="write", + db=db, + ) ), ) @@ -218,14 +237,26 @@ async def get_prompt_by_id( if ( user.role == "admin" or prompt.user_id == user.id - or has_access(user.id, "read", prompt.access_control, db=db) + or AccessGrants.has_access( + user_id=user.id, + resource_type="prompt", + resource_id=prompt.id, + permission="read", + db=db, + ) ): return PromptAccessResponse( **prompt.model_dump(), write_access=( (user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL) or user.id == prompt.user_id - or has_access(user.id, "write", prompt.access_control, db=db) + or AccessGrants.has_access( + user_id=user.id, + resource_type="prompt", + resource_id=prompt.id, + permission="write", + db=db, + ) ), ) @@ -258,7 +289,13 @@ async def update_prompt_by_id( # Is the user the original creator, in a group with write access, or an admin if ( prompt.user_id != user.id - and not has_access(user.id, "write", prompt.access_control, db=db) + and not AccessGrants.has_access( + user_id=user.id, + resource_type="prompt", + resource_id=prompt.id, + permission="write", + db=db, + ) and user.role != "admin" ): raise HTTPException( @@ -311,7 +348,13 @@ async def update_prompt_metadata( if ( prompt.user_id != user.id - and not has_access(user.id, "write", prompt.access_control, db=db) + and not AccessGrants.has_access( + user_id=user.id, + resource_type="prompt", + resource_id=prompt.id, + permission="write", + db=db, + ) and user.role != "admin" ): raise HTTPException( @@ -356,7 +399,13 @@ async def set_prompt_version( if ( prompt.user_id != user.id - and not has_access(user.id, "write", prompt.access_control, db=db) + and not AccessGrants.has_access( + user_id=user.id, + resource_type="prompt", + resource_id=prompt.id, + permission="write", + db=db, + ) and user.role != "admin" ): raise HTTPException( @@ -395,7 +444,13 @@ async def delete_prompt_by_id( if ( prompt.user_id != user.id - and not has_access(user.id, "write", prompt.access_control, db=db) + and not AccessGrants.has_access( + user_id=user.id, + resource_type="prompt", + resource_id=prompt.id, + permission="write", + db=db, + ) and user.role != "admin" ): raise HTTPException( @@ -434,7 +489,13 @@ async def get_prompt_history( if not ( user.role == "admin" or prompt.user_id == user.id - or has_access(user.id, "read", prompt.access_control, db=db) + or AccessGrants.has_access( + user_id=user.id, + resource_type="prompt", + resource_id=prompt.id, + permission="read", + db=db, + ) ): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -469,7 +530,13 @@ async def get_prompt_history_entry( if not ( user.role == "admin" or prompt.user_id == user.id - or has_access(user.id, "read", prompt.access_control, db=db) + or AccessGrants.has_access( + user_id=user.id, + resource_type="prompt", + resource_id=prompt.id, + permission="read", + db=db, + ) ): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -508,7 +575,13 @@ async def delete_prompt_history_entry( if not ( user.role == "admin" or prompt.user_id == user.id - or has_access(user.id, "write", prompt.access_control, db=db) + or AccessGrants.has_access( + user_id=user.id, + resource_type="prompt", + resource_id=prompt.id, + permission="write", + db=db, + ) ): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -553,7 +626,13 @@ async def get_prompt_diff( if not ( user.role == "admin" or prompt.user_id == user.id - or has_access(user.id, "read", prompt.access_control, db=db) + or AccessGrants.has_access( + user_id=user.id, + resource_type="prompt", + resource_id=prompt.id, + permission="read", + db=db, + ) ): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, diff --git a/backend/open_webui/routers/tools.py b/backend/open_webui/routers/tools.py index 7f9b23c7c..015bde232 100644 --- a/backend/open_webui/routers/tools.py +++ b/backend/open_webui/routers/tools.py @@ -21,6 +21,7 @@ from open_webui.models.tools import ( ToolAccessResponse, Tools, ) +from open_webui.models.access_grants import AccessGrants from open_webui.utils.plugin import ( load_tool_module_by_id, replace_imports, @@ -156,7 +157,24 @@ async def get_tools( tool for tool in tools if tool.user_id == user.id - or has_access(user.id, "read", tool.access_control, user_group_ids, db=db) + or ( + has_access( + user.id, + "read", + getattr(tool, "access_control", None), + user_group_ids, + db=db, + ) + if str(tool.id).startswith("server:") + else AccessGrants.has_access( + user_id=user.id, + resource_type="tool", + resource_id=tool.id, + permission="read", + user_group_ids=user_group_ids, + db=db, + ) + ) ] return tools @@ -181,7 +199,13 @@ async def get_tool_list( write_access=( (user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL) or user.id == tool.user_id - or has_access(user.id, "write", tool.access_control, db=db) + or AccessGrants.has_access( + user_id=user.id, + resource_type="tool", + resource_id=tool.id, + permission="write", + db=db, + ) ), ) for tool in tools @@ -382,14 +406,26 @@ async def get_tools_by_id( if ( user.role == "admin" or tools.user_id == user.id - or has_access(user.id, "read", tools.access_control, db=db) + or AccessGrants.has_access( + user_id=user.id, + resource_type="tool", + resource_id=tools.id, + permission="read", + db=db, + ) ): return ToolAccessResponse( **tools.model_dump(), write_access=( (user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL) or user.id == tools.user_id - or has_access(user.id, "write", tools.access_control, db=db) + or AccessGrants.has_access( + user_id=user.id, + resource_type="tool", + resource_id=tools.id, + permission="write", + db=db, + ) ), ) else: @@ -427,7 +463,13 @@ async def update_tools_by_id( # Is the user the original creator, in a group with write access, or an admin if ( tools.user_id != user.id - and not has_access(user.id, "write", tools.access_control, db=db) + and not AccessGrants.has_access( + user_id=user.id, + resource_type="tool", + resource_id=tools.id, + permission="write", + db=db, + ) and user.role != "admin" ): raise HTTPException( @@ -489,7 +531,13 @@ async def delete_tools_by_id( if ( tools.user_id != user.id - and not has_access(user.id, "write", tools.access_control, db=db) + and not AccessGrants.has_access( + user_id=user.id, + resource_type="tool", + resource_id=tools.id, + permission="write", + db=db, + ) and user.role != "admin" ): raise HTTPException( @@ -588,7 +636,13 @@ async def update_tools_valves_by_id( if ( tools.user_id != user.id - and not has_access(user.id, "write", tools.access_control, db=db) + and not AccessGrants.has_access( + user_id=user.id, + resource_type="tool", + resource_id=tools.id, + permission="write", + db=db, + ) and user.role != "admin" ): raise HTTPException( diff --git a/backend/open_webui/routers/users.py b/backend/open_webui/routers/users.py index 20b69bcdf..5d8a1057d 100644 --- a/backend/open_webui/routers/users.py +++ b/backend/open_webui/routers/users.py @@ -19,7 +19,7 @@ from open_webui.models.users import ( UserModel, UserGroupIdsModel, UserGroupIdsListResponse, - UserInfoListResponse, + UserInfoResponse, UserInfoListResponse, UserRoleUpdateForm, UserStatus, @@ -446,7 +446,7 @@ class UserActiveResponse(UserStatus): @router.get("/{user_id}", response_model=UserActiveResponse) async def get_user_by_id( - user_id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) + user_id: str, user=Depends(get_admin_user), db: Session = Depends(get_session) ): # Check if user_id is a shared chat # If it is, get the user_id from the chat @@ -478,6 +478,20 @@ async def get_user_by_id( ) +@router.get("/{user_id}/info", response_model=UserInfoResponse) +async def get_user_info_by_id( + user_id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) +): + user = Users.get_user_by_id(user_id, db=db) + if user: + return user + else: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.USER_NOT_FOUND, + ) + + @router.get("/{user_id}/oauth/sessions") async def get_user_oauth_sessions_by_id( user_id: str, user=Depends(get_admin_user), db: Session = Depends(get_session) diff --git a/backend/open_webui/socket/main.py b/backend/open_webui/socket/main.py index 67e04e69c..e987f9c29 100644 --- a/backend/open_webui/socket/main.py +++ b/backend/open_webui/socket/main.py @@ -42,7 +42,7 @@ from open_webui.utils.auth import decode_token from open_webui.socket.utils import RedisDict, RedisLock, YdocManager from open_webui.tasks import create_task, stop_item_tasks from open_webui.utils.redis import get_redis_connection -from open_webui.utils.access_control import has_access, get_users_with_access +from open_webui.models.access_grants import AccessGrants from open_webui.env import ( @@ -405,7 +405,12 @@ async def join_note(sid, data): if ( user.role != "admin" and user.id != note.user_id - and not has_access(user.id, type="read", access_control=note.access_control) + and not AccessGrants.has_access( + user_id=user.id, + resource_type="note", + resource_id=note.id, + permission="read", + ) ): log.error(f"User {user.id} does not have access to note {data['note_id']}") return @@ -467,8 +472,11 @@ async def ydoc_document_join(sid, data): if ( user.get("role") != "admin" and user.get("id") != note.user_id - and not has_access( - user.get("id"), type="read", access_control=note.access_control + and not AccessGrants.has_access( + user_id=user.get("id"), + resource_type="note", + resource_id=note.id, + permission="read", ) ): log.error( @@ -537,8 +545,11 @@ async def document_save_handler(document_id, data, user): if ( user.get("role") != "admin" and user.get("id") != note.user_id - and not has_access( - user.get("id"), type="read", access_control=note.access_control + and not AccessGrants.has_access( + user_id=user.get("id"), + resource_type="note", + resource_id=note.id, + permission="read", ) ): log.error(f"User {user.get('id')} does not have access to note {note_id}") diff --git a/backend/open_webui/tools/builtin.py b/backend/open_webui/tools/builtin.py index 78d64faaf..f02fbf60a 100644 --- a/backend/open_webui/tools/builtin.py +++ b/backend/open_webui/tools/builtin.py @@ -743,10 +743,14 @@ async def view_note( user_id = __user__.get("id") user_group_ids = [group.id for group in Groups.get_groups_by_member_id(user_id)] - from open_webui.utils.access_control import has_access + from open_webui.models.access_grants import AccessGrants - if note.user_id != user_id and not has_access( - user_id, "read", note.access_control, user_group_ids + if note.user_id != user_id and not AccessGrants.has_access( + user_id=user_id, + resource_type="note", + resource_id=note.id, + permission="read", + user_group_ids=set(user_group_ids), ): return json.dumps({"error": "Access denied"}) @@ -797,7 +801,7 @@ async def write_note( form = NoteForm( title=title, data={"content": {"md": content}}, - access_control={}, # Private by default - only owner can access + access_grants=[], # Private by default - only owner can access ) new_note = Notes.insert_new_note(user_id, form) @@ -852,10 +856,14 @@ async def replace_note_content( user_id = __user__.get("id") user_group_ids = [group.id for group in Groups.get_groups_by_member_id(user_id)] - from open_webui.utils.access_control import has_access + from open_webui.models.access_grants import AccessGrants - if note.user_id != user_id and not has_access( - user_id, "write", note.access_control, user_group_ids + if note.user_id != user_id and not AccessGrants.has_access( + user_id=user_id, + resource_type="note", + resource_id=note.id, + permission="write", + user_group_ids=set(user_group_ids), ): return json.dumps({"error": "Write access denied"}) @@ -1532,7 +1540,7 @@ async def view_knowledge_file( try: from open_webui.models.files import Files from open_webui.models.knowledge import Knowledges - from open_webui.utils.access_control import has_access + from open_webui.models.access_grants import AccessGrants user_id = __user__.get("id") user_role = __user__.get("role", "user") @@ -1551,8 +1559,12 @@ async def view_knowledge_file( if ( user_role == "admin" or knowledge_base.user_id == user_id - or has_access( - user_id, "read", knowledge_base.access_control, user_group_ids + or AccessGrants.has_access( + user_id=user_id, + resource_type="knowledge", + resource_id=knowledge_base.id, + permission="read", + user_group_ids=set(user_group_ids), ) ): has_knowledge_access = True @@ -1631,7 +1643,7 @@ async def query_knowledge_files( from open_webui.models.files import Files from open_webui.models.notes import Notes from open_webui.retrieval.utils import query_collection - from open_webui.utils.access_control import has_access + from open_webui.models.access_grants import AccessGrants user_id = __user__.get("id") user_role = __user__.get("role", "user") @@ -1656,8 +1668,12 @@ async def query_knowledge_files( if knowledge and ( user_role == "admin" or knowledge.user_id == user_id - or has_access( - user_id, "read", knowledge.access_control, user_group_ids + or AccessGrants.has_access( + user_id=user_id, + resource_type="knowledge", + resource_id=knowledge.id, + permission="read", + user_group_ids=set(user_group_ids), ) ): collection_names.append(item_id) @@ -1674,7 +1690,12 @@ async def query_knowledge_files( if note and ( user_role == "admin" or note.user_id == user_id - or has_access(user_id, "read", note.access_control) + or AccessGrants.has_access( + user_id=user_id, + resource_type="note", + resource_id=note.id, + permission="read", + ) ): content = note.data.get("content", {}).get("md", "") note_results.append( @@ -1693,8 +1714,12 @@ async def query_knowledge_files( if knowledge and ( user_role == "admin" or knowledge.user_id == user_id - or has_access( - user_id, "read", knowledge.access_control, user_group_ids + or AccessGrants.has_access( + user_id=user_id, + resource_type="knowledge", + resource_id=knowledge.id, + permission="read", + user_group_ids=set(user_group_ids), ) ): collection_names.append(knowledge_id) diff --git a/backend/open_webui/utils/models.py b/backend/open_webui/utils/models.py index b3a332ade..4224605f1 100644 --- a/backend/open_webui/utils/models.py +++ b/backend/open_webui/utils/models.py @@ -13,6 +13,7 @@ from open_webui.functions import get_function_models from open_webui.models.functions import Functions from open_webui.models.models import Models +from open_webui.models.access_grants import AccessGrants from open_webui.models.groups import Groups @@ -354,8 +355,12 @@ def check_model_access(user, model, db=None): raise Exception("Model not found") elif not ( user.id == model_info.user_id - or has_access( - user.id, type="read", access_control=model_info.access_control, db=db + or AccessGrants.has_access( + user_id=user.id, + resource_type="model", + resource_id=model_info.id, + permission="read", + db=db, ) ): raise Exception("Model not found") @@ -395,11 +400,13 @@ def get_filtered_models(models, user, db=None): if ( (user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL) or user.id == model_info.user_id - or has_access( - user.id, - type="read", - access_control=model_info.access_control, + or AccessGrants.has_access( + user_id=user.id, + resource_type="model", + resource_id=model_info.id, + permission="read", user_group_ids=user_group_ids, + db=db, ) ): filtered_models.append(model) diff --git a/backend/open_webui/utils/tools.py b/backend/open_webui/utils/tools.py index b2f95a829..645c7799e 100644 --- a/backend/open_webui/utils/tools.py +++ b/backend/open_webui/utils/tools.py @@ -38,6 +38,7 @@ from open_webui.utils.misc import is_string_allowed from open_webui.models.tools import Tools from open_webui.models.users import UserModel from open_webui.models.groups import Groups +from open_webui.models.access_grants import AccessGrants from open_webui.utils.plugin import load_tool_module_by_id from open_webui.utils.access_control import has_access from open_webui.config import BYPASS_ADMIN_ACCESS_CONTROL @@ -165,7 +166,13 @@ async def get_tools( if ( not (user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL) and tool.user_id != user.id - and not has_access(user.id, "read", tool.access_control, user_group_ids) + and not AccessGrants.has_access( + user_id=user.id, + resource_type="tool", + resource_id=tool.id, + permission="read", + user_group_ids=user_group_ids, + ) ): log.warning(f"Access denied to tool {tool_id} for user {user.id}") continue