From 4519ddd0e97b33114d5ad0d012ed4523937f55c7 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Fri, 23 Aug 2024 16:19:04 +0200 Subject: [PATCH] refac: files rbac --- backend/apps/webui/models/files.py | 7 +++++++ backend/apps/webui/routers/files.py | 14 ++++++++------ 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/backend/apps/webui/models/files.py b/backend/apps/webui/models/files.py index 2de5c33b5..e1d1cec9f 100644 --- a/backend/apps/webui/models/files.py +++ b/backend/apps/webui/models/files.py @@ -98,6 +98,13 @@ class FilesTable: return [FileModel.model_validate(file) for file in db.query(File).all()] + def get_files_by_user_id(self, user_id: str) -> list[FileModel]: + with get_db() as db: + return [ + FileModel.model_validate(file) + for file in db.query(File).filter_by(user_id=user_id).all() + ] + def delete_file_by_id(self, id: str) -> bool: with get_db() as db: diff --git a/backend/apps/webui/routers/files.py b/backend/apps/webui/routers/files.py index ba571fc71..48ca366d8 100644 --- a/backend/apps/webui/routers/files.py +++ b/backend/apps/webui/routers/files.py @@ -106,7 +106,10 @@ def upload_file(file: UploadFile = File(...), user=Depends(get_verified_user)): @router.get("/", response_model=list[FileModel]) async def list_files(user=Depends(get_verified_user)): - files = Files.get_files() + if user.role == "admin": + files = Files.get_files() + else: + files = Files.get_files_by_user_id(user.id) return files @@ -156,7 +159,7 @@ async def delete_all_files(user=Depends(get_admin_user)): async def get_file_by_id(id: str, user=Depends(get_verified_user)): file = Files.get_file_by_id(id) - if file: + if file and (file.user_id == user.id or user.role == "admin"): return file else: raise HTTPException( @@ -174,7 +177,7 @@ async def get_file_by_id(id: str, user=Depends(get_verified_user)): async def get_file_content_by_id(id: str, user=Depends(get_verified_user)): file = Files.get_file_by_id(id) - if file: + if file and (file.user_id == user.id or user.role == "admin"): file_path = Path(file.meta["path"]) # Check if the file already exists in the cache @@ -197,7 +200,7 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)): async def get_file_content_by_id(id: str, user=Depends(get_verified_user)): file = Files.get_file_by_id(id) - if file: + if file and (file.user_id == user.id or user.role == "admin"): file_path = Path(file.meta["path"]) # Check if the file already exists in the cache @@ -224,8 +227,7 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)): @router.delete("/{id}") async def delete_file_by_id(id: str, user=Depends(get_verified_user)): file = Files.get_file_by_id(id) - - if file: + if file and (file.user_id == user.id or user.role == "admin"): result = Files.delete_file_by_id(id) if result: return {"message": "File deleted successfully"}