From c35f8c9673e0a2fe6a65b3804835545176501e85 Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Sat, 16 Nov 2024 18:00:57 -0800 Subject: [PATCH] refac: knowledge access control --- .../apps/webui/routers/knowledge.py | 133 ++++++++++++++---- 1 file changed, 109 insertions(+), 24 deletions(-) diff --git a/backend/open_webui/apps/webui/routers/knowledge.py b/backend/open_webui/apps/webui/routers/knowledge.py index 55bb1316d..dfe233372 100644 --- a/backend/open_webui/apps/webui/routers/knowledge.py +++ b/backend/open_webui/apps/webui/routers/knowledge.py @@ -17,6 +17,9 @@ from open_webui.apps.retrieval.main import process_file, ProcessFileForm from open_webui.constants import ERROR_MESSAGES from open_webui.utils.utils import get_admin_user, get_verified_user +from open_webui.utils.access_control import has_access + + from open_webui.env import SRC_LOG_LEVELS @@ -154,13 +157,20 @@ async def get_knowledge_by_id(id: str, user=Depends(get_verified_user)): knowledge = Knowledges.get_knowledge_by_id(id=id) if knowledge: - file_ids = knowledge.data.get("file_ids", []) if knowledge.data else [] - files = Files.get_files_by_ids(file_ids) - return KnowledgeFilesResponse( - **knowledge.model_dump(), - files=files, - ) + if ( + user.role == "admin" + or knowledge.user_id == user.id + or has_access(user.id, "read", knowledge.access_control) + ): + + file_ids = knowledge.data.get("file_ids", []) if knowledge.data else [] + files = Files.get_files_by_ids(file_ids) + + return KnowledgeFilesResponse( + **knowledge.model_dump(), + files=files, + ) else: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -179,8 +189,20 @@ async def update_knowledge_by_id( form_data: KnowledgeUpdateForm, user=Depends(get_verified_user), ): - knowledge = Knowledges.update_knowledge_by_id(id=id, form_data=form_data) + knowledge = Knowledges.get_knowledge_by_id(id=id) + if not knowledge: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + if knowledge.user_id != user.id and user.role != "admin": + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + ) + + knowledge = Knowledges.update_knowledge_by_id(id=id, form_data=form_data) if knowledge: file_ids = knowledge.data.get("file_ids", []) if knowledge.data else [] files = Files.get_files_by_ids(file_ids) @@ -212,6 +234,19 @@ def add_file_to_knowledge_by_id( user=Depends(get_verified_user), ): knowledge = Knowledges.get_knowledge_by_id(id=id) + + if not knowledge: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + if knowledge.user_id != user.id and user.role != "admin": + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + ) + file = Files.get_file_by_id(form_data.file_id) if not file: raise HTTPException( @@ -277,6 +312,18 @@ def update_file_from_knowledge_by_id( user=Depends(get_verified_user), ): knowledge = Knowledges.get_knowledge_by_id(id=id) + if not knowledge: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + if knowledge.user_id != user.id and user.role != "admin": + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + ) + file = Files.get_file_by_id(form_data.file_id) if not file: raise HTTPException( @@ -327,6 +374,18 @@ def remove_file_from_knowledge_by_id( user=Depends(get_verified_user), ): knowledge = Knowledges.get_knowledge_by_id(id=id) + if not knowledge: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + if knowledge.user_id != user.id and user.role != "admin": + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + ) + file = Files.get_file_by_id(form_data.file_id) if not file: raise HTTPException( @@ -382,13 +441,55 @@ def remove_file_from_knowledge_by_id( ) +############################ +# DeleteKnowledgeById +############################ + + +@router.delete("/{id}/delete", response_model=bool) +async def delete_knowledge_by_id(id: str, user=Depends(get_verified_user)): + knowledge = Knowledges.get_knowledge_by_id(id=id) + if not knowledge: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + if knowledge.user_id != user.id and user.role != "admin": + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + ) + + try: + VECTOR_DB_CLIENT.delete_collection(collection_name=id) + except Exception as e: + log.debug(e) + pass + result = Knowledges.delete_knowledge_by_id(id=id) + return result + + ############################ # ResetKnowledgeById ############################ @router.post("/{id}/reset", response_model=Optional[KnowledgeResponse]) -async def reset_knowledge_by_id(id: str, user=Depends(get_admin_user)): +async def reset_knowledge_by_id(id: str, user=Depends(get_verified_user)): + knowledge = Knowledges.get_knowledge_by_id(id=id) + if not knowledge: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + if knowledge.user_id != user.id and user.role != "admin": + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + ) + try: VECTOR_DB_CLIENT.delete_collection(collection_name=id) except Exception as e: @@ -399,19 +500,3 @@ async def reset_knowledge_by_id(id: str, user=Depends(get_admin_user)): id=id, form_data=KnowledgeUpdateForm(data={"file_ids": []}) ) return knowledge - - -############################ -# DeleteKnowledgeById -############################ - - -@router.delete("/{id}/delete", response_model=bool) -async def delete_knowledge_by_id(id: str, user=Depends(get_verified_user)): - try: - VECTOR_DB_CLIENT.delete_collection(collection_name=id) - except Exception as e: - log.debug(e) - pass - result = Knowledges.delete_knowledge_by_id(id=id) - return result