From 2ab5b2fd71a4e08878520bb35e3e11679f2c874e Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Fri, 15 Nov 2024 01:29:07 -0800 Subject: [PATCH] wip: access control backend --- .../open_webui/apps/webui/models/groups.py | 11 +- .../open_webui/apps/webui/models/models.py | 35 ++++++ .../open_webui/apps/webui/models/prompts.py | 60 ++++++++++ .../open_webui/apps/webui/routers/models.py | 106 +++++++++++------- .../open_webui/apps/webui/routers/users.py | 26 ++++- backend/open_webui/config.py | 28 ++++- backend/open_webui/main.py | 2 +- backend/open_webui/utils/utils.py | 66 ++++++++++- 8 files changed, 282 insertions(+), 52 deletions(-) diff --git a/backend/open_webui/apps/webui/models/groups.py b/backend/open_webui/apps/webui/models/groups.py index 5d6258bc7..89717e16b 100644 --- a/backend/open_webui/apps/webui/models/groups.py +++ b/backend/open_webui/apps/webui/models/groups.py @@ -68,7 +68,6 @@ class GroupResponse(BaseModel): permissions: Optional[dict] = None meta: Optional[dict] = None user_ids: list[str] = [] - admin_ids: list[str] = [] created_at: int # timestamp in epoch updated_at: int # timestamp in epoch @@ -119,6 +118,16 @@ class GroupTable: for group in db.query(Group).order_by(Group.updated_at.desc()).all() ] + def get_groups_by_member_id(self, user_id: str) -> list[GroupModel]: + with get_db() as db: + return [ + GroupModel.model_validate(group) + for group in db.query(Group) + .filter(Group.user_ids.contains([user_id])) + .order_by(Group.updated_at.desc()) + .all() + ] + def get_group_by_id(self, id: str) -> Optional[GroupModel]: try: with get_db() as db: diff --git a/backend/open_webui/apps/webui/models/models.py b/backend/open_webui/apps/webui/models/models.py index 66316001c..77b7c5f67 100644 --- a/backend/open_webui/apps/webui/models/models.py +++ b/backend/open_webui/apps/webui/models/models.py @@ -4,9 +4,20 @@ from typing import Optional from open_webui.apps.webui.internal.db import Base, JSONField, get_db from open_webui.env import SRC_LOG_LEVELS + +from open_webui.apps.webui.models.groups import Groups + + from pydantic import BaseModel, ConfigDict + +from sqlalchemy import or_, and_, func +from sqlalchemy.dialects import postgresql, sqlite from sqlalchemy import BigInteger, Column, Text, JSON + +from open_webui.utils.utils import has_access + + log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["MODELS"]) @@ -112,8 +123,14 @@ class ModelModel(BaseModel): class ModelResponse(BaseModel): id: str + user_id: str + base_model_id: Optional[str] = None + name: str + params: ModelParams meta: ModelMeta + + access_control: Optional[dict] = None updated_at: int # timestamp in epoch created_at: int # timestamp in epoch @@ -157,6 +174,24 @@ class ModelsTable: with get_db() as db: return [ModelModel.model_validate(model) for model in db.query(Model).all()] + def get_models(self) -> list[ModelModel]: + with get_db() as db: + return [ + ModelModel.model_validate(model) + for model in db.query(Model).filter(Model.base_model_id != None).all() + ] + + def get_models_by_user_id( + self, user_id: str, permission: str = "write" + ) -> list[ModelModel]: + models = self.get_all_models() + return [ + model + for model in models + if model.user_id == user_id + or has_access(user_id, permission, model.access_control) + ] + def get_model_by_id(self, id: str) -> Optional[ModelModel]: try: with get_db() as db: diff --git a/backend/open_webui/apps/webui/models/prompts.py b/backend/open_webui/apps/webui/models/prompts.py index 689891fe0..7f96d5374 100644 --- a/backend/open_webui/apps/webui/models/prompts.py +++ b/backend/open_webui/apps/webui/models/prompts.py @@ -2,6 +2,8 @@ import time from typing import Optional from open_webui.apps.webui.internal.db import Base, get_db +from open_webui.apps.webui.models.groups import Groups + from pydantic import BaseModel, ConfigDict from sqlalchemy import BigInteger, Column, String, Text, JSON @@ -100,6 +102,64 @@ class PromptsTable: PromptModel.model_validate(prompt) for prompt in db.query(Prompt).all() ] + def get_prompts_by_user_id( + self, user_id: str, permission: str = "write" + ) -> list[PromptModel]: + prompts = self.get_prompts() + + groups = Groups.get_groups_by_member_id(user_id) + group_ids = [group.id for group in groups] + + if permission == "write": + return [ + prompt + for prompt in prompts + if prompt.user_id == user_id + or ( + prompt.access_control + and ( + any( + group_id + in prompt.access_control.get(permission, {}).get( + "group_ids", [] + ) + for group_id in group_ids + ) + or ( + user_id + in prompt.access_control.get(permission, {}).get( + "user_ids", [] + ) + ) + ) + ) + ] + elif permission == "read": + return [ + prompt + for prompt in prompts + if prompt.user_id == user_id + or prompt.access_control is None + or ( + prompt.access_control + and ( + any( + prompt.access_control.get(permission, {}).get( + "group_ids", [] + ) + in group_id + for group_id in group_ids + ) + or ( + user_id + in prompt.access_control.get(permission, {}).get( + "user_ids", [] + ) + ) + ) + ) + ] + def update_prompt_by_command( self, command: str, form_data: PromptForm ) -> Optional[PromptModel]: diff --git a/backend/open_webui/apps/webui/routers/models.py b/backend/open_webui/apps/webui/routers/models.py index a5cb2395e..906defb76 100644 --- a/backend/open_webui/apps/webui/routers/models.py +++ b/backend/open_webui/apps/webui/routers/models.py @@ -8,49 +8,46 @@ from open_webui.apps.webui.models.models import ( ) from open_webui.constants import ERROR_MESSAGES from fastapi import APIRouter, Depends, HTTPException, Request, status -from open_webui.utils.utils import get_admin_user, get_verified_user + + +from open_webui.utils.utils import get_admin_user, get_verified_user, has_access router = APIRouter() + ########################### -# getModels +# GetModels ########################### @router.get("/", response_model=list[ModelResponse]) async def get_models(id: Optional[str] = None, user=Depends(get_verified_user)): - if id: - model = Models.get_model_by_id(id) - if model: - return [model] - else: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail=ERROR_MESSAGES.NOT_FOUND, - ) + if user.role == "admin": + return Models.get_models() else: - return Models.get_all_models() + return Models.get_models_by_user_id(user.id) ############################ -# AddNewModel +# CreateNewModel ############################ -@router.post("/add", response_model=Optional[ModelModel]) -async def add_new_model( - request: Request, +@router.post("/create", response_model=Optional[ModelModel]) +async def create_new_model( form_data: ModelForm, - user=Depends(get_admin_user), + user=Depends(get_verified_user), ): - if form_data.id in request.app.state.MODELS: + + model = Models.get_model_by_id(form_data.id) + if model: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.MODEL_ID_TAKEN, ) + else: model = Models.insert_new_model(form_data, user.id) - if model: return model else: @@ -60,37 +57,49 @@ async def add_new_model( ) +########################### +# GetModelById +########################### + + +@router.get("/id/{id}", response_model=Optional[ModelResponse]) +async def get_model_by_id(id: str, user=Depends(get_verified_user)): + model = Models.get_model_by_id(id) + if model: + if ( + user.role == "admin" + or model.user_id == user.id + or has_access(user.id, "read", model.access_control) + ): + return model + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + ############################ # UpdateModelById ############################ -@router.post("/update", response_model=Optional[ModelModel]) +@router.post("/id/{id}/update", response_model=Optional[ModelModel]) async def update_model_by_id( - request: Request, id: str, form_data: ModelForm, - user=Depends(get_admin_user), + user=Depends(get_verified_user), ): model = Models.get_model_by_id(id) - if model: - model = Models.update_model_by_id(id, form_data) - return model - else: - if form_data.id in request.app.state.MODELS: - model = Models.insert_new_model(form_data, user.id) - if model: - return model - else: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail=ERROR_MESSAGES.DEFAULT(), - ) - else: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail=ERROR_MESSAGES.DEFAULT(), - ) + + if not model: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + model = Models.update_model_by_id(id, form_data) + return model ############################ @@ -98,7 +107,20 @@ async def update_model_by_id( ############################ -@router.delete("/delete", response_model=bool) -async def delete_model_by_id(id: str, user=Depends(get_admin_user)): +@router.delete("/id/{id}/delete", response_model=bool) +async def delete_model_by_id(id: str, user=Depends(get_verified_user)): + model = Models.get_model_by_id(id) + if not model: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + if model.user_id != user.id: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.UNAUTHORIZED, + ) + result = Models.delete_model_by_id(id) return result diff --git a/backend/open_webui/apps/webui/routers/users.py b/backend/open_webui/apps/webui/routers/users.py index abc540efa..3cd7166ad 100644 --- a/backend/open_webui/apps/webui/routers/users.py +++ b/backend/open_webui/apps/webui/routers/users.py @@ -36,16 +36,34 @@ async def get_users(skip: int = 0, limit: int = 50, user=Depends(get_admin_user) ############################ -@router.get("/permissions/user") +class WorkspacePermissions(BaseModel): + models: bool + knowledge: bool + prompts: bool + tools: bool + + +class ChatPermissions(BaseModel): + delete: bool + edit: bool + temporary: bool + + +class UserPermissions(BaseModel): + workspace: WorkspacePermissions + chat: ChatPermissions + + +@router.get("/permissions") async def get_user_permissions(request: Request, user=Depends(get_admin_user)): return request.app.state.config.USER_PERMISSIONS -@router.post("/permissions/user") +@router.post("/permissions") async def update_user_permissions( - request: Request, form_data: dict, user=Depends(get_admin_user) + request: Request, form_data: UserPermissions, user=Depends(get_admin_user) ): - request.app.state.config.USER_PERMISSIONS = form_data + request.app.state.config.USER_PERMISSIONS = form_data.model_dump() return request.app.state.config.USER_PERMISSIONS diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py index 4362eae0b..c1070dbbd 100644 --- a/backend/open_webui/config.py +++ b/backend/open_webui/config.py @@ -739,6 +739,26 @@ DEFAULT_USER_ROLE = PersistentConfig( os.getenv("DEFAULT_USER_ROLE", "pending"), ) + +USER_PERMISSIONS_WORKSPACE_MODELS_ACCESS = ( + os.environ.get("USER_PERMISSIONS_WORKSPACE_MODELS_ACCESS", "False").lower() + == "true" +) + +USER_PERMISSIONS_WORKSPACE_KNOWLEDGE_ACCESS = ( + os.environ.get("USER_PERMISSIONS_WORKSPACE_KNOWLEDGE_ACCESS", "False").lower() + == "true" +) + +USER_PERMISSIONS_WORKSPACE_PROMPTS_ACCESS = ( + os.environ.get("USER_PERMISSIONS_WORKSPACE_PROMPTS_ACCESS", "False").lower() + == "true" +) + +USER_PERMISSIONS_WORKSPACE_TOOLS_ACCESS = ( + os.environ.get("USER_PERMISSIONS_WORKSPACE_TOOLS_ACCESS", "False").lower() == "true" +) + USER_PERMISSIONS_CHAT_DELETE = ( os.environ.get("USER_PERMISSIONS_CHAT_DELETE", "True").lower() == "true" ) @@ -755,11 +775,17 @@ USER_PERMISSIONS = PersistentConfig( "USER_PERMISSIONS", "user.permissions", { + "workspace": { + "models": USER_PERMISSIONS_WORKSPACE_MODELS_ACCESS, + "knowledge": USER_PERMISSIONS_WORKSPACE_KNOWLEDGE_ACCESS, + "prompts": USER_PERMISSIONS_WORKSPACE_PROMPTS_ACCESS, + "tools": USER_PERMISSIONS_WORKSPACE_TOOLS_ACCESS, + }, "chat": { "deletion": USER_PERMISSIONS_CHAT_DELETE, "editing": USER_PERMISSIONS_CHAT_EDIT, "temporary": USER_PERMISSIONS_CHAT_TEMPORARY, - } + }, }, ) diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index faef9e81c..53dc01cf8 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -993,7 +993,7 @@ async def get_all_models(): models.append( { - "id": custom_model.id, + "id": f"open-webui-{custom_model.id}", "name": custom_model.name, "object": "model", "created": custom_model.created_at, diff --git a/backend/open_webui/utils/utils.py b/backend/open_webui/utils/utils.py index 31fe227ed..370a30d6f 100644 --- a/backend/open_webui/utils/utils.py +++ b/backend/open_webui/utils/utils.py @@ -1,12 +1,18 @@ import logging import uuid -from datetime import UTC, datetime, timedelta -from typing import Optional, Union - import jwt + +from datetime import UTC, datetime, timedelta +from typing import Optional, Union, List, Dict + + from open_webui.apps.webui.models.users import Users +from open_webui.apps.webui.models.groups import Groups + from open_webui.constants import ERROR_MESSAGES from open_webui.env import WEBUI_SECRET_KEY + + from fastapi import Depends, HTTPException, Request, Response, status from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer from passlib.context import CryptContext @@ -147,3 +153,57 @@ def get_admin_user(user=Depends(get_current_user)): detail=ERROR_MESSAGES.ACCESS_PROHIBITED, ) return user + + +def has_permission( + user_id: str, + permission_key: str, + default_permissions: Dict[str, bool] = {}, +) -> bool: + """ + Check if a user has a specific permission by checking the group permissions + and falls back to default permissions if not found in any group. + + Permission keys can be hierarchical and separated by dots ('.'). + """ + + def get_permission(permissions: Dict[str, bool], keys: List[str]) -> bool: + """Traverse permissions dict using a list of keys (from dot-split permission_key).""" + for key in keys: + if key not in permissions: + return False # If any part of the hierarchy is missing, deny access + permissions = permissions[key] # Go one level deeper + + return bool(permissions) # Return the boolean at the final level + + permission_hierarchy = permission_key.split(".") + + # Retrieve user group permissions + user_groups = Groups.get_groups_by_member_id(user_id) + + for group in user_groups: + group_permissions = group.permissions + if get_permission(group_permissions, permission_hierarchy): + return True + + # Check default permissions afterwards if the group permissions don't allow it + return get_permission(default_permissions, permission_hierarchy) + + +def has_access( + user_id: str, + action: str = "write", + access_control: Optional[dict] = None, +) -> bool: + if access_control is None: + return action == "read" + + user_groups = Groups.get_groups_by_member_id(user_id) + user_group_ids = [group.id for group in user_groups] + permission_access = access_control.get(action, {}) + permitted_group_ids = permission_access.get("group_ids", []) + permitted_user_ids = permission_access.get("user_ids", []) + + return user_id in permitted_user_ids or any( + group_id in permitted_group_ids for group_id in user_group_ids + )