diff --git a/backend/open_webui/routers/files.py b/backend/open_webui/routers/files.py index 5f9704145..22e1269e3 100644 --- a/backend/open_webui/routers/files.py +++ b/backend/open_webui/routers/files.py @@ -24,6 +24,8 @@ from open_webui.models.files import ( FileModelResponse, Files, ) +from open_webui.models.knowledge import Knowledges + from open_webui.routers.knowledge import get_knowledge, get_knowledge_list from open_webui.routers.retrieval import ProcessFileForm, process_file from open_webui.routers.audio import transcribe @@ -37,10 +39,15 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"]) router = APIRouter() + ############################ # Check if the current user has access to a file through any knowledge bases the user may be in. ############################ -async def check_user_has_access_to_file_via_any_knowledge_base(file_id: Optional[str], access_type: str, user=Depends(get_verified_user)) -> bool: + + +def has_access_to_file( + file_id: Optional[str], access_type: str, user=Depends(get_verified_user) +) -> bool: file = Files.get_file_by_id(file_id) log.debug(f"Checking if user has {access_type} access to file") @@ -49,29 +56,20 @@ async def check_user_has_access_to_file_via_any_knowledge_base(file_id: Optional status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND, ) - + has_access = False knowledge_base_id = file.meta.get("collection_name") if file.meta else None - log.debug(f"Knowledge base associated with file: {knowledge_base_id}") + if knowledge_base_id: - if access_type == "read": - user_access = await get_knowledge(user=user) # get_knowledge checks for read access - elif access_type == "write": - user_access = await get_knowledge_list(user=user) # get_knowledge_list checks for write access - else: - user_access = list() - - for knowledge_base in user_access: + knowledge_bases = Knowledges.get_knowledge_bases_by_user_id( + user.id, access_type + ) + for knowledge_base in knowledge_bases: if knowledge_base.id == knowledge_base_id: - log.debug(f"User knowledge base with {access_type} access {knowledge_base.id} == File knowledge base {knowledge_base_id}") has_access = True break - - log.debug(f"Does user have {access_type} access to file: {has_access}") - return has_access - ############################ @@ -212,10 +210,12 @@ async def get_file_by_id(id: str, user=Depends(get_verified_user)): status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND, ) - - has_read_access: bool = await check_user_has_access_to_file_via_any_knowledge_base(id, "read", user) - if file.user_id == user.id or user.role == "admin" or has_read_access: + if ( + file.user_id == user.id + or user.role == "admin" + or has_access_to_file(id, "read", user) + ): return file else: raise HTTPException( @@ -238,10 +238,12 @@ async def get_file_data_content_by_id(id: str, user=Depends(get_verified_user)): status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND, ) - - has_read_access: bool = await check_user_has_access_to_file_via_any_knowledge_base(id, "read", user) - if file.user_id == user.id or user.role == "admin" or has_read_access: + if ( + file.user_id == user.id + or user.role == "admin" + or has_access_to_file(id, "read", user) + ): return {"content": file.data.get("content", "")} else: raise HTTPException( @@ -270,10 +272,12 @@ async def update_file_data_content_by_id( status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND, ) - - has_write_access: bool = await check_user_has_access_to_file_via_any_knowledge_base(id, "write", user) - if file.user_id == user.id or user.role == "admin" or has_write_access: + if ( + file.user_id == user.id + or user.role == "admin" + or has_access_to_file(id, "write", user) + ): try: process_file( request, @@ -309,10 +313,12 @@ async def get_file_content_by_id( status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND, ) - - has_read_access: bool = await check_user_has_access_to_file_via_any_knowledge_base(id, "read", user) - if file.user_id == user.id or user.role == "admin" or has_read_access: + if ( + file.user_id == user.id + or user.role == "admin" + or has_access_to_file(id, "read", user) + ): try: file_path = Storage.get_file(file.path) file_path = Path(file_path) @@ -375,10 +381,12 @@ async def get_html_file_content_by_id(id: str, user=Depends(get_verified_user)): status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND, ) - - has_read_access: bool = await check_user_has_access_to_file_via_any_knowledge_base(id, "read", user) - if file.user_id == user.id or user.role == "admin" or has_read_access: + if ( + file.user_id == user.id + or user.role == "admin" + or has_access_to_file(id, "read", user) + ): try: file_path = Storage.get_file(file.path) file_path = Path(file_path) @@ -415,10 +423,12 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)): status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND, ) - - has_read_access: bool = await check_user_has_access_to_file_via_any_knowledge_base(id, "read", user) - if file.user_id == user.id or user.role == "admin" or has_read_access: + if ( + file.user_id == user.id + or user.role == "admin" + or has_access_to_file(id, "read", user) + ): file_path = file.path # Handle Unicode filenames @@ -475,10 +485,12 @@ async def delete_file_by_id(id: str, user=Depends(get_verified_user)): status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND, ) - - has_write_access: bool = await check_user_has_access_to_file_via_any_knowledge_base(id, "write", user) - if file.user_id == user.id or user.role == "admin" or has_write_access: + if ( + file.user_id == user.id + or user.role == "admin" + or has_access_to_file(id, "write", user) + ): # We should add Chroma cleanup here result = Files.delete_file_by_id(id)