From a5dbd2e8dde5773dce2e48f192f94f92f036eeee Mon Sep 17 00:00:00 2001
From: Timothy Jaeryang Baek <tim@openwebui.com>
Date: Mon, 31 Mar 2025 01:10:18 -0700
Subject: [PATCH] refac: knowledge file ac

---
 backend/open_webui/routers/files.py | 86 ++++++++++++++++-------------
 1 file changed, 49 insertions(+), 37 deletions(-)

diff --git a/backend/open_webui/routers/files.py b/backend/open_webui/routers/files.py
index 5f9704145..22e1269e3 100644
--- a/backend/open_webui/routers/files.py
+++ b/backend/open_webui/routers/files.py
@@ -24,6 +24,8 @@ from open_webui.models.files import (
     FileModelResponse,
     Files,
 )
+from open_webui.models.knowledge import Knowledges
+
 from open_webui.routers.knowledge import get_knowledge, get_knowledge_list
 from open_webui.routers.retrieval import ProcessFileForm, process_file
 from open_webui.routers.audio import transcribe
@@ -37,10 +39,15 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
 
 router = APIRouter()
 
+
 ############################
 # Check if the current user has access to a file through any knowledge bases the user may be in.
 ############################
-async def check_user_has_access_to_file_via_any_knowledge_base(file_id: Optional[str], access_type: str, user=Depends(get_verified_user)) -> bool:
+
+
+def has_access_to_file(
+    file_id: Optional[str], access_type: str, user=Depends(get_verified_user)
+) -> bool:
     file = Files.get_file_by_id(file_id)
     log.debug(f"Checking if user has {access_type} access to file")
 
@@ -49,29 +56,20 @@ async def check_user_has_access_to_file_via_any_knowledge_base(file_id: Optional
             status_code=status.HTTP_404_NOT_FOUND,
             detail=ERROR_MESSAGES.NOT_FOUND,
         )
-    
+
     has_access = False
     knowledge_base_id = file.meta.get("collection_name") if file.meta else None
-    log.debug(f"Knowledge base associated with file: {knowledge_base_id}")
+
     if knowledge_base_id:
-        if access_type == "read":
-            user_access = await get_knowledge(user=user) # get_knowledge checks for read access
-        elif access_type == "write":
-            user_access = await get_knowledge_list(user=user) # get_knowledge_list checks for write access
-        else:
-            user_access = list()
-        
-        for knowledge_base in user_access:
+        knowledge_bases = Knowledges.get_knowledge_bases_by_user_id(
+            user.id, access_type
+        )
+        for knowledge_base in knowledge_bases:
             if knowledge_base.id == knowledge_base_id:
-                log.debug(f"User knowledge base with {access_type} access {knowledge_base.id} == File knowledge base {knowledge_base_id}")
                 has_access = True
                 break
 
-    
-    log.debug(f"Does user have {access_type} access to file: {has_access}")
-
     return has_access
-    
 
 
 ############################
@@ -212,10 +210,12 @@ async def get_file_by_id(id: str, user=Depends(get_verified_user)):
             status_code=status.HTTP_404_NOT_FOUND,
             detail=ERROR_MESSAGES.NOT_FOUND,
         )
-    
-    has_read_access: bool = await check_user_has_access_to_file_via_any_knowledge_base(id, "read", user)
 
-    if file.user_id == user.id or user.role == "admin" or has_read_access:
+    if (
+        file.user_id == user.id
+        or user.role == "admin"
+        or has_access_to_file(id, "read", user)
+    ):
         return file
     else:
         raise HTTPException(
@@ -238,10 +238,12 @@ async def get_file_data_content_by_id(id: str, user=Depends(get_verified_user)):
             status_code=status.HTTP_404_NOT_FOUND,
             detail=ERROR_MESSAGES.NOT_FOUND,
         )
-    
-    has_read_access: bool = await check_user_has_access_to_file_via_any_knowledge_base(id, "read", user)
 
-    if file.user_id == user.id or user.role == "admin" or has_read_access:
+    if (
+        file.user_id == user.id
+        or user.role == "admin"
+        or has_access_to_file(id, "read", user)
+    ):
         return {"content": file.data.get("content", "")}
     else:
         raise HTTPException(
@@ -270,10 +272,12 @@ async def update_file_data_content_by_id(
             status_code=status.HTTP_404_NOT_FOUND,
             detail=ERROR_MESSAGES.NOT_FOUND,
         )
-    
-    has_write_access: bool = await check_user_has_access_to_file_via_any_knowledge_base(id, "write", user)
 
-    if file.user_id == user.id or user.role == "admin" or has_write_access:
+    if (
+        file.user_id == user.id
+        or user.role == "admin"
+        or has_access_to_file(id, "write", user)
+    ):
         try:
             process_file(
                 request,
@@ -309,10 +313,12 @@ async def get_file_content_by_id(
             status_code=status.HTTP_404_NOT_FOUND,
             detail=ERROR_MESSAGES.NOT_FOUND,
         )
-    
-    has_read_access: bool = await check_user_has_access_to_file_via_any_knowledge_base(id, "read", user)
 
-    if file.user_id == user.id or user.role == "admin" or has_read_access:
+    if (
+        file.user_id == user.id
+        or user.role == "admin"
+        or has_access_to_file(id, "read", user)
+    ):
         try:
             file_path = Storage.get_file(file.path)
             file_path = Path(file_path)
@@ -375,10 +381,12 @@ async def get_html_file_content_by_id(id: str, user=Depends(get_verified_user)):
             status_code=status.HTTP_404_NOT_FOUND,
             detail=ERROR_MESSAGES.NOT_FOUND,
         )
-    
-    has_read_access: bool = await check_user_has_access_to_file_via_any_knowledge_base(id, "read", user)
 
-    if file.user_id == user.id or user.role == "admin" or has_read_access:
+    if (
+        file.user_id == user.id
+        or user.role == "admin"
+        or has_access_to_file(id, "read", user)
+    ):
         try:
             file_path = Storage.get_file(file.path)
             file_path = Path(file_path)
@@ -415,10 +423,12 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
             status_code=status.HTTP_404_NOT_FOUND,
             detail=ERROR_MESSAGES.NOT_FOUND,
         )
-    
-    has_read_access: bool = await check_user_has_access_to_file_via_any_knowledge_base(id, "read", user)
 
-    if file.user_id == user.id or user.role == "admin" or has_read_access:
+    if (
+        file.user_id == user.id
+        or user.role == "admin"
+        or has_access_to_file(id, "read", user)
+    ):
         file_path = file.path
 
         # Handle Unicode filenames
@@ -475,10 +485,12 @@ async def delete_file_by_id(id: str, user=Depends(get_verified_user)):
             status_code=status.HTTP_404_NOT_FOUND,
             detail=ERROR_MESSAGES.NOT_FOUND,
         )
-    
-    has_write_access: bool = await check_user_has_access_to_file_via_any_knowledge_base(id, "write", user)
 
-    if file.user_id == user.id or user.role == "admin" or has_write_access:
+    if (
+        file.user_id == user.id
+        or user.role == "admin"
+        or has_access_to_file(id, "write", user)
+    ):
         # We should add Chroma cleanup here
 
         result = Files.delete_file_by_id(id)