diff --git a/backend/open_webui/apps/webui/routers/files.py b/backend/open_webui/apps/webui/routers/files.py
index b22377414..d3a117a9c 100644
--- a/backend/open_webui/apps/webui/routers/files.py
+++ b/backend/open_webui/apps/webui/routers/files.py
@@ -1,14 +1,19 @@
 import logging
 import os
-import shutil
 import uuid
 from pathlib import Path
 from typing import Optional

 from pydantic import BaseModel
 import mimetypes

+from open_webui.storage.provider import Storage
-from open_webui.apps.webui.models.files import FileForm, FileModel, Files
+from open_webui.apps.webui.models.files import (
+    FileForm,
+    FileModel,
+    FileModelResponse,
+    Files,
+)
 from open_webui.apps.retrieval.main import process_file, ProcessFileForm

 from open_webui.config import UPLOAD_DIR
@@ -44,18 +49,7 @@ def upload_file(file: UploadFile = File(...), user=Depends(get_verified_user)):
         id = str(uuid.uuid4())
         name = filename
         filename = f"{id}_{filename}"
-        file_path = f"{UPLOAD_DIR}/{filename}"
-
-        contents = file.file.read()
-        if len(contents) == 0:
-            raise HTTPException(
-                status_code=status.HTTP_400_BAD_REQUEST,
-                detail=ERROR_MESSAGES.EMPTY_CONTENT,
-            )
-
-        with open(file_path, "wb") as f:
-            f.write(contents)
-            f.close()
+        contents, file_path = Storage.upload_file(file.file, filename)

         file = Files.insert_new_file(
             user.id,
@@ -101,7 +95,7 @@ def upload_file(file: UploadFile = File(...), user=Depends(get_verified_user)):
 ############################


-@router.get("/", response_model=list[FileModel])
+@router.get("/", response_model=list[FileModelResponse])
 async def list_files(user=Depends(get_verified_user)):
     if user.role == "admin":
         files = Files.get_files()
@@ -118,27 +112,16 @@
 @router.delete("/all")
 async def delete_all_files(user=Depends(get_admin_user)):
     result = Files.delete_all_files()
-
     if result:
-        folder = f"{UPLOAD_DIR}"
         try:
-            # Check if the directory exists
-            if os.path.exists(folder):
-                # Iterate over all the files and directories in the specified directory
-                for filename in os.listdir(folder):
-                    file_path = os.path.join(folder, filename)
-                    try:
-                        if os.path.isfile(file_path) or os.path.islink(file_path):
-                            os.unlink(file_path)  # Remove the file or link
-                        elif os.path.isdir(file_path):
-                            shutil.rmtree(file_path)  # Remove the directory
-                    except Exception as e:
-                        print(f"Failed to delete {file_path}. Reason: {e}")
-            else:
-                print(f"The directory {folder} does not exist")
+            Storage.delete_all_files()
         except Exception as e:
-            print(f"Failed to process the directory {folder}. Reason: {e}")
-
+            log.exception(e)
+            log.error(f"Error deleting files")
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail=ERROR_MESSAGES.DEFAULT("Error deleting files"),
+            )
         return {"message": "All files deleted successfully"}
     else:
         raise HTTPException(
@@ -222,21 +205,29 @@ async def update_file_data_content_by_id(
 @router.get("/{id}/content")
 async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
     file = Files.get_file_by_id(id)
-
     if file and (file.user_id == user.id or user.role == "admin"):
-        file_path = Path(file.path)
+        try:
+            file_path = Storage.get_file(file.path)
+            file_path = Path(file_path)

-        # Check if the file already exists in the cache
-        if file_path.is_file():
-            print(f"file_path: {file_path}")
-            headers = {
-                "Content-Disposition": f'attachment; filename="{file.meta.get("name", file.filename)}"'
-            }
-            return FileResponse(file_path, headers=headers)
-        else:
+            # Check if the file already exists in the cache
+            if file_path.is_file():
+                print(f"file_path: {file_path}")
+                headers = {
+                    "Content-Disposition": f'attachment; filename="{file.meta.get("name", file.filename)}"'
+                }
+                return FileResponse(file_path, headers=headers)
+            else:
+                raise HTTPException(
+                    status_code=status.HTTP_404_NOT_FOUND,
+                    detail=ERROR_MESSAGES.NOT_FOUND,
+                )
+        except Exception as e:
+            log.exception(e)
+            log.error(f"Error getting file content")
             raise HTTPException(
-                status_code=status.HTTP_404_NOT_FOUND,
-                detail=ERROR_MESSAGES.NOT_FOUND,
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail=ERROR_MESSAGES.DEFAULT("Error getting file content"),
             )
     else:
         raise HTTPException(
@@ -252,6 +243,7 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
     if file and (file.user_id == user.id or user.role == "admin"):
         file_path = file.path
         if file_path:
+            file_path = Storage.get_file(file_path)
             file_path = Path(file_path)

             # Check if the file already exists in the cache
@@ -298,6 +290,15 @@ async def delete_file_by_id(id: str, user=Depends(get_verified_user)):
     if file and (file.user_id == user.id or user.role == "admin"):
         result = Files.delete_file_by_id(id)
         if result:
+            try:
+                Storage.delete_file(file.filename)
+            except Exception as e:
+                log.exception(e)
+                log.error(f"Error deleting files")
+                raise HTTPException(
+                    status_code=status.HTTP_400_BAD_REQUEST,
+                    detail=ERROR_MESSAGES.DEFAULT("Error deleting files"),
+                )
             return {"message": "File deleted successfully"}
         else:
             raise HTTPException(
diff --git a/backend/open_webui/storage/provider.py b/backend/open_webui/storage/provider.py
index 6cb5b5d95..5a78ef4f3 100644
--- a/backend/open_webui/storage/provider.py
+++ b/backend/open_webui/storage/provider.py
@@ -1,6 +1,12 @@
 import os
 import boto3
 from botocore.exceptions import ClientError
+import shutil
+
+
+from typing import BinaryIO, Tuple, Optional, Union
+
+from open_webui.constants import ERROR_MESSAGES
 from open_webui.config import (
     STORAGE_PROVIDER,
     S3_ACCESS_KEY_ID,
@@ -9,109 +15,150 @@ from open_webui.config import (
     S3_REGION_NAME,
     S3_ENDPOINT_URL,
     UPLOAD_DIR,
-    AppConfig,
 )


+import boto3
+from boto3.s3 import S3Client
+from botocore.exceptions import ClientError
+from typing import BinaryIO, Tuple, Optional
+
+
 class StorageProvider:
-    def __init__(self):
-        self.storage_provider = None
-        self.client = None
-        self.bucket_name = None
+    def __init__(self, provider: Optional[str] = None):
+        self.storage_provider: str = provider or STORAGE_PROVIDER

-        if STORAGE_PROVIDER == "s3":
-            self.storage_provider = "s3"
-            self.client = boto3.client(
-                "s3",
-                region_name=S3_REGION_NAME,
-                endpoint_url=S3_ENDPOINT_URL,
-                aws_access_key_id=S3_ACCESS_KEY_ID,
-                aws_secret_access_key=S3_SECRET_ACCESS_KEY,
-            )
-            self.bucket_name = S3_BUCKET_NAME
-        else:
-            self.storage_provider = "local"
+        self.s3_client = None
+        self.s3_bucket_name: Optional[str] = None

-    def get_storage_provider(self):
-        return self.storage_provider
-
-    def upload_file(self, file, filename):
         if self.storage_provider == "s3":
-            try:
-                bucket = self.bucket_name
-                self.client.upload_fileobj(file, bucket, filename)
-                return filename
-            except ClientError as e:
-                raise RuntimeError(f"Error uploading file: {e}")
-        else:
-            file_path = os.path.join(UPLOAD_DIR, filename)
-            os.makedirs(os.path.dirname(file_path), exist_ok=True)
-            with open(file_path, "wb") as f:
-                f.write(file.read())
-            return filename
+            self._initialize_s3()

-    def list_files(self):
-        if self.storage_provider == "s3":
-            try:
-                bucket = self.bucket_name
-                response = self.client.list_objects_v2(Bucket=bucket)
-                if "Contents" in response:
-                    return [content["Key"] for content in response["Contents"]]
-                return []
-            except ClientError as e:
-                raise RuntimeError(f"Error listing files: {e}")
-        else:
-            return [
-                f
-                for f in os.listdir(UPLOAD_DIR)
-                if os.path.isfile(os.path.join(UPLOAD_DIR, f))
-            ]
+    def _initialize_s3(self) -> None:
+        """Initializes the S3 client and bucket name if using S3 storage."""
+        self.s3_client = boto3.client(
+            "s3",
+            region_name=S3_REGION_NAME,
+            endpoint_url=S3_ENDPOINT_URL,
+            aws_access_key_id=S3_ACCESS_KEY_ID,
+            aws_secret_access_key=S3_SECRET_ACCESS_KEY,
+        )
+        self.bucket_name = S3_BUCKET_NAME

-    def get_file(self, filename):
-        if self.storage_provider == "s3":
-            try:
-                bucket = self.bucket_name
-                file_path = f"/tmp/{filename}"
-                self.client.download_file(bucket, filename, file_path)
-                return file_path
-            except ClientError as e:
-                raise RuntimeError(f"Error downloading file: {e}")
-        else:
-            file_path = os.path.join(UPLOAD_DIR, filename)
-            if os.path.isfile(file_path):
-                return file_path
-            else:
-                raise FileNotFoundError(f"File {filename} not found in local storage.")
+    def _upload_to_s3(self, file: BinaryIO, filename: str) -> Tuple[bytes, str]:
+        """Handles uploading of the file to S3 storage."""
+        if not self.s3_client:
+            raise RuntimeError("S3 Client is not initialized.")

-    def delete_file(self, filename):
-        if self.storage_provider == "s3":
-            try:
-                bucket = self.bucket_name
-                self.client.delete_object(Bucket=bucket, Key=filename)
-            except ClientError as e:
-                raise RuntimeError(f"Error deleting file: {e}")
-        else:
-            file_path = os.path.join(UPLOAD_DIR, filename)
-            if os.path.isfile(file_path):
-                os.remove(file_path)
-            else:
-                raise FileNotFoundError(f"File {filename} not found in local storage.")
+        try:
+            self.s3_client.upload_fileobj(file, self.bucket_name, filename)
+            return file.read(), f"s3://{self.bucket_name}/{filename}"
+        except ClientError as e:
+            raise RuntimeError(f"Error uploading file to S3: {e}")

-    def delete_all_files(self):
-        if self.storage_provider == "s3":
-            try:
-                bucket = self.bucket_name
-                response = self.client.list_objects_v2(Bucket=bucket)
-                if "Contents" in response:
-                    for content in response["Contents"]:
-                        self.client.delete_object(Bucket=bucket, Key=content["Key"])
-            except ClientError as e:
-                raise RuntimeError(f"Error deleting all files: {e}")
+    def _upload_to_local(self, contents: bytes, filename: str) -> Tuple[bytes, str]:
+        """Handles uploading of the file to local storage."""
+        file_path = f"{UPLOAD_DIR}/{filename}"
+        with open(file_path, "wb") as f:
+            f.write(contents)
+        return contents, file_path
+
+    def _get_file_from_s3(self, file_path: str) -> str:
+        """Handles downloading of the file from S3 storage."""
+        if not self.s3_client:
+            raise RuntimeError("S3 Client is not initialized.")
+
+        try:
+            bucket_name, key = file_path.split("//")[1].split("/")
+            local_file_path = f"{UPLOAD_DIR}/{key}"
+            self.s3_client.download_file(bucket_name, key, local_file_path)
+            return local_file_path
+        except ClientError as e:
+            raise RuntimeError(f"Error downloading file from S3: {e}")
+
+    def _get_file_from_local(self, file_path: str) -> str:
+        """Handles downloading of the file from local storage."""
+        return file_path
+
+    def _delete_from_s3(self, filename: str) -> None:
+        """Handles deletion of the file from S3 storage."""
+        if not self.s3_client:
+            raise RuntimeError("S3 Client is not initialized.")
+
+        try:
+            self.s3_client.delete_object(Bucket=self.bucket_name, Key=filename)
+        except ClientError as e:
+            raise RuntimeError(f"Error deleting file from S3: {e}")
+
+    def _delete_from_local(self, filename: str) -> None:
+        """Handles deletion of the file from local storage."""
+        file_path = f"{UPLOAD_DIR}/{filename}"
+        if os.path.isfile(file_path):
+            os.remove(file_path)
         else:
+            raise FileNotFoundError(f"File {filename} not found in local storage.")
+
+    def _delete_all_from_s3(self) -> None:
+        """Handles deletion of all files from S3 storage."""
+        if not self.s3_client:
+            raise RuntimeError("S3 Client is not initialized.")
+
+        try:
+            response = self.s3_client.list_objects_v2(Bucket=self.bucket_name)
+            if "Contents" in response:
+                for content in response["Contents"]:
+                    self.s3_client.delete_object(
+                        Bucket=self.bucket_name, Key=content["Key"]
+                    )
+        except ClientError as e:
+            raise RuntimeError(f"Error deleting all files from S3: {e}")
+
+    def _delete_all_from_local(self) -> None:
+        """Handles deletion of all files from local storage."""
+        if os.path.exists(UPLOAD_DIR):
             for filename in os.listdir(UPLOAD_DIR):
                 file_path = os.path.join(UPLOAD_DIR, filename)
-                if os.path.isfile(file_path):
-                    os.remove(file_path)
+                try:
+                    if os.path.isfile(file_path) or os.path.islink(file_path):
+                        os.unlink(file_path)  # Remove the file or link
+                    elif os.path.isdir(file_path):
+                        shutil.rmtree(file_path)  # Remove the directory
+                except Exception as e:
+                    print(f"Failed to delete {file_path}. Reason: {e}")
+        else:
+            raise FileNotFoundError(
+                f"Directory {UPLOAD_DIR} not found in local storage."
+            )
+
+    def upload_file(self, file: BinaryIO, filename: str) -> Tuple[bytes, str]:
+        """Uploads a file either to S3 or the local file system."""
+        contents = file.read()
+        if not contents:
+            raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT)
+
+        if self.storage_provider == "s3":
+            return self._upload_to_s3(file, filename)
+        return self._upload_to_local(contents, filename)
+
+    def get_file(self, file_path: str) -> str:
+        """Downloads a file either from S3 or the local file system and returns the file path."""
+        if self.storage_provider == "s3":
+            return self._get_file_from_s3(file_path)
+        return self._get_file_from_local(file_path)
+
+    def delete_file(self, filename: str) -> None:
+        """Deletes a file either from S3 or the local file system."""
+        if self.storage_provider == "s3":
+            self._delete_from_s3(filename)
+        else:
+            self._delete_from_local(filename)
+
+    def delete_all_files(self) -> None:
+        """Deletes all files from the storage."""
+        if self.storage_provider == "s3":
+            self._delete_all_from_s3()
+        else:
+            self._delete_all_from_local()


-Storage = StorageProvider()
+Storage = StorageProvider(provider=STORAGE_PROVIDER)