From a3f737c0c6e1c15729db45d7e7990a3891aeacc8 Mon Sep 17 00:00:00 2001 From: Rodrigo Agundez Date: Wed, 15 Jan 2025 23:16:38 +0800 Subject: [PATCH] Split the storage providers into separate classes in preparation for other storage providers like GCS --- backend/open_webui/storage/provider.py | 203 +++++++++++-------------- 1 file changed, 90 insertions(+), 113 deletions(-) diff --git a/backend/open_webui/storage/provider.py b/backend/open_webui/storage/provider.py index ae3347682..3280e6519 100644 --- a/backend/open_webui/storage/provider.py +++ b/backend/open_webui/storage/provider.py @@ -1,121 +1,68 @@ import os +import shutil +from abc import ABC, abstractmethod +from typing import BinaryIO, Tuple + import boto3 from botocore.exceptions import ClientError -import shutil - - -from typing import BinaryIO, Tuple, Optional, Union - -from open_webui.constants import ERROR_MESSAGES from open_webui.config import ( - STORAGE_PROVIDER, S3_ACCESS_KEY_ID, - S3_SECRET_ACCESS_KEY, S3_BUCKET_NAME, - S3_REGION_NAME, S3_ENDPOINT_URL, + S3_REGION_NAME, + S3_SECRET_ACCESS_KEY, + STORAGE_PROVIDER, UPLOAD_DIR, ) +from open_webui.constants import ERROR_MESSAGES -import boto3 -from botocore.exceptions import ClientError -from typing import BinaryIO, Tuple, Optional +class StorageProvider(ABC): + @abstractmethod + def get_file(self, file_path: str) -> str: + pass + + @abstractmethod + def upload_file(self, file: BinaryIO, filename: str) -> Tuple[bytes, str]: + pass + + @abstractmethod + def delete_all_files(self) -> None: + pass + + @abstractmethod + def delete_file(self, file_path: str) -> None: + pass -class StorageProvider: - def __init__(self, provider: Optional[str] = None): - self.storage_provider: str = provider or STORAGE_PROVIDER - - self.s3_client = None - self.s3_bucket_name: Optional[str] = None - - if self.storage_provider == "s3": - self._initialize_s3() - - def _initialize_s3(self) -> None: - """Initializes the S3 client and bucket name if using S3 storage.""" - self.s3_client = boto3.client( - "s3", - region_name=S3_REGION_NAME, - endpoint_url=S3_ENDPOINT_URL, - aws_access_key_id=S3_ACCESS_KEY_ID, - aws_secret_access_key=S3_SECRET_ACCESS_KEY, - ) - self.bucket_name = S3_BUCKET_NAME - - def _upload_to_s3(self, file_path: str, filename: str) -> Tuple[bytes, str]: - """Handles uploading of the file to S3 storage.""" - if not self.s3_client: - raise RuntimeError("S3 Client is not initialized.") - - try: - self.s3_client.upload_file(file_path, self.bucket_name, filename) - return ( - open(file_path, "rb").read(), - "s3://" + self.bucket_name + "/" + filename, - ) - except ClientError as e: - raise RuntimeError(f"Error uploading file to S3: {e}") - - def _upload_to_local(self, contents: bytes, filename: str) -> Tuple[bytes, str]: - """Handles uploading of the file to local storage.""" +class LocalStorageProvider(StorageProvider): + @staticmethod + def upload_file(file: BinaryIO, filename: str) -> Tuple[bytes, str]: + contents = file.read() + if not contents: + raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT) file_path = f"{UPLOAD_DIR}/{filename}" with open(file_path, "wb") as f: f.write(contents) return contents, file_path - def _get_file_from_s3(self, file_path: str) -> str: - """Handles downloading of the file from S3 storage.""" - if not self.s3_client: - raise RuntimeError("S3 Client is not initialized.") - - try: - bucket_name, key = file_path.split("//")[1].split("/") - local_file_path = f"{UPLOAD_DIR}/{key}" - self.s3_client.download_file(bucket_name, key, local_file_path) - return local_file_path - except ClientError as e: - raise RuntimeError(f"Error downloading file from S3: {e}") - - def _get_file_from_local(self, file_path: str) -> str: + @staticmethod + def get_file(file_path: str) -> str: """Handles downloading of the file from local storage.""" return file_path - def _delete_from_s3(self, filename: str) -> None: - """Handles deletion of the file from S3 storage.""" - if not self.s3_client: - raise RuntimeError("S3 Client is not initialized.") - - try: - self.s3_client.delete_object(Bucket=self.bucket_name, Key=filename) - except ClientError as e: - raise RuntimeError(f"Error deleting file from S3: {e}") - - def _delete_from_local(self, filename: str) -> None: + @staticmethod + def delete_file(file_path: str) -> None: """Handles deletion of the file from local storage.""" + filename = file_path.split("/")[-1] file_path = f"{UPLOAD_DIR}/{filename}" if os.path.isfile(file_path): os.remove(file_path) else: print(f"File {file_path} not found in local storage.") - def _delete_all_from_s3(self) -> None: - """Handles deletion of all files from S3 storage.""" - if not self.s3_client: - raise RuntimeError("S3 Client is not initialized.") - - try: - response = self.s3_client.list_objects_v2(Bucket=self.bucket_name) - if "Contents" in response: - for content in response["Contents"]: - self.s3_client.delete_object( - Bucket=self.bucket_name, Key=content["Key"] - ) - except ClientError as e: - raise RuntimeError(f"Error deleting all files from S3: {e}") - - def _delete_all_from_local(self) -> None: + @staticmethod + def delete_all_files() -> None: """Handles deletion of all files from local storage.""" if os.path.exists(UPLOAD_DIR): for filename in os.listdir(UPLOAD_DIR): @@ -130,40 +77,70 @@ class StorageProvider: else: print(f"Directory {UPLOAD_DIR} not found in local storage.") - def upload_file(self, file: BinaryIO, filename: str) -> Tuple[bytes, str]: - """Uploads a file either to S3 or the local file system.""" - contents = file.read() - if not contents: - raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT) - contents, file_path = self._upload_to_local(contents, filename) - if self.storage_provider == "s3": - return self._upload_to_s3(file_path, filename) - return contents, file_path +class S3StorageProvider(StorageProvider): + def __init__(self): + self.s3_client = boto3.client( + "s3", + region_name=S3_REGION_NAME, + endpoint_url=S3_ENDPOINT_URL, + aws_access_key_id=S3_ACCESS_KEY_ID, + aws_secret_access_key=S3_SECRET_ACCESS_KEY, + ) + self.bucket_name = S3_BUCKET_NAME + + def upload_file(self, file: BinaryIO, filename: str) -> Tuple[bytes, str]: + """Handles uploading of the file to S3 storage.""" + _, file_path = LocalStorageProvider.upload_file(file, filename) + try: + self.s3_client.upload_file(file_path, self.bucket_name, filename) + return ( + open(file_path, "rb").read(), + "s3://" + self.bucket_name + "/" + filename, + ) + except ClientError as e: + raise RuntimeError(f"Error uploading file to S3: {e}") def get_file(self, file_path: str) -> str: - """Downloads a file either from S3 or the local file system and returns the file path.""" - if self.storage_provider == "s3": - return self._get_file_from_s3(file_path) - return self._get_file_from_local(file_path) + """Handles downloading of the file from S3 storage.""" + try: + bucket_name, key = file_path.split("//")[1].split("/") + local_file_path = f"{UPLOAD_DIR}/{key}" + self.s3_client.download_file(bucket_name, key, local_file_path) + return local_file_path + except ClientError as e: + raise RuntimeError(f"Error downloading file from S3: {e}") def delete_file(self, file_path: str) -> None: - """Deletes a file either from S3 or the local file system.""" + """Handles deletion of the file from S3 storage.""" filename = file_path.split("/")[-1] - - if self.storage_provider == "s3": - self._delete_from_s3(filename) + try: + self.s3_client.delete_object(Bucket=self.bucket_name, Key=filename) + except ClientError as e: + raise RuntimeError(f"Error deleting file from S3: {e}") # Always delete from local storage - self._delete_from_local(filename) + LocalStorageProvider.delete_file(file_path) def delete_all_files(self) -> None: - """Deletes all files from the storage.""" - if self.storage_provider == "s3": - self._delete_all_from_s3() + """Handles deletion of all files from S3 storage.""" + try: + response = self.s3_client.list_objects_v2(Bucket=self.bucket_name) + if "Contents" in response: + for content in response["Contents"]: + self.s3_client.delete_object( + Bucket=self.bucket_name, Key=content["Key"] + ) + except ClientError as e: + raise RuntimeError(f"Error deleting all files from S3: {e}") # Always delete from local storage - self._delete_all_from_local() + LocalStorageProvider.delete_all_files() -Storage = StorageProvider(provider=STORAGE_PROVIDER) +if STORAGE_PROVIDER == "local": + Storage = LocalStorageProvider() +elif STORAGE_PROVIDER == "s3": + Storage = S3StorageProvider() +else: + raise RuntimeError(f"Unsupported storage provider: {STORAGE_PROVIDER}")