Split the storage providers into separate classes in preparation for other storage providers like GCS

This commit is contained in:
Rodrigo Agundez 2025-01-15 23:16:38 +08:00
parent 372658be6d
commit a3f737c0c6

View File

@ -1,121 +1,68 @@
import os
import shutil
from abc import ABC, abstractmethod
from typing import BinaryIO, Tuple
import boto3
from botocore.exceptions import ClientError
import shutil
from typing import BinaryIO, Tuple, Optional, Union
from open_webui.constants import ERROR_MESSAGES
from open_webui.config import (
STORAGE_PROVIDER,
S3_ACCESS_KEY_ID,
S3_SECRET_ACCESS_KEY,
S3_BUCKET_NAME,
S3_REGION_NAME,
S3_ENDPOINT_URL,
S3_REGION_NAME,
S3_SECRET_ACCESS_KEY,
STORAGE_PROVIDER,
UPLOAD_DIR,
)
from open_webui.constants import ERROR_MESSAGES
import boto3
from botocore.exceptions import ClientError
from typing import BinaryIO, Tuple, Optional
class StorageProvider(ABC):
@abstractmethod
def get_file(self, file_path: str) -> str:
pass
@abstractmethod
def upload_file(self, file: BinaryIO, filename: str) -> Tuple[bytes, str]:
pass
@abstractmethod
def delete_all_files(self) -> None:
pass
@abstractmethod
def delete_file(self, file_path: str) -> None:
pass
class StorageProvider:
def __init__(self, provider: Optional[str] = None):
self.storage_provider: str = provider or STORAGE_PROVIDER
self.s3_client = None
self.s3_bucket_name: Optional[str] = None
if self.storage_provider == "s3":
self._initialize_s3()
def _initialize_s3(self) -> None:
"""Initializes the S3 client and bucket name if using S3 storage."""
self.s3_client = boto3.client(
"s3",
region_name=S3_REGION_NAME,
endpoint_url=S3_ENDPOINT_URL,
aws_access_key_id=S3_ACCESS_KEY_ID,
aws_secret_access_key=S3_SECRET_ACCESS_KEY,
)
self.bucket_name = S3_BUCKET_NAME
def _upload_to_s3(self, file_path: str, filename: str) -> Tuple[bytes, str]:
"""Handles uploading of the file to S3 storage."""
if not self.s3_client:
raise RuntimeError("S3 Client is not initialized.")
try:
self.s3_client.upload_file(file_path, self.bucket_name, filename)
return (
open(file_path, "rb").read(),
"s3://" + self.bucket_name + "/" + filename,
)
except ClientError as e:
raise RuntimeError(f"Error uploading file to S3: {e}")
def _upload_to_local(self, contents: bytes, filename: str) -> Tuple[bytes, str]:
"""Handles uploading of the file to local storage."""
class LocalStorageProvider(StorageProvider):
@staticmethod
def upload_file(file: BinaryIO, filename: str) -> Tuple[bytes, str]:
contents = file.read()
if not contents:
raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT)
file_path = f"{UPLOAD_DIR}/{filename}"
with open(file_path, "wb") as f:
f.write(contents)
return contents, file_path
def _get_file_from_s3(self, file_path: str) -> str:
"""Handles downloading of the file from S3 storage."""
if not self.s3_client:
raise RuntimeError("S3 Client is not initialized.")
try:
bucket_name, key = file_path.split("//")[1].split("/")
local_file_path = f"{UPLOAD_DIR}/{key}"
self.s3_client.download_file(bucket_name, key, local_file_path)
return local_file_path
except ClientError as e:
raise RuntimeError(f"Error downloading file from S3: {e}")
def _get_file_from_local(self, file_path: str) -> str:
@staticmethod
def get_file(file_path: str) -> str:
"""Handles downloading of the file from local storage."""
return file_path
def _delete_from_s3(self, filename: str) -> None:
"""Handles deletion of the file from S3 storage."""
if not self.s3_client:
raise RuntimeError("S3 Client is not initialized.")
try:
self.s3_client.delete_object(Bucket=self.bucket_name, Key=filename)
except ClientError as e:
raise RuntimeError(f"Error deleting file from S3: {e}")
def _delete_from_local(self, filename: str) -> None:
@staticmethod
def delete_file(file_path: str) -> None:
"""Handles deletion of the file from local storage."""
filename = file_path.split("/")[-1]
file_path = f"{UPLOAD_DIR}/{filename}"
if os.path.isfile(file_path):
os.remove(file_path)
else:
print(f"File {file_path} not found in local storage.")
def _delete_all_from_s3(self) -> None:
"""Handles deletion of all files from S3 storage."""
if not self.s3_client:
raise RuntimeError("S3 Client is not initialized.")
try:
response = self.s3_client.list_objects_v2(Bucket=self.bucket_name)
if "Contents" in response:
for content in response["Contents"]:
self.s3_client.delete_object(
Bucket=self.bucket_name, Key=content["Key"]
)
except ClientError as e:
raise RuntimeError(f"Error deleting all files from S3: {e}")
def _delete_all_from_local(self) -> None:
@staticmethod
def delete_all_files() -> None:
"""Handles deletion of all files from local storage."""
if os.path.exists(UPLOAD_DIR):
for filename in os.listdir(UPLOAD_DIR):
@ -130,40 +77,70 @@ class StorageProvider:
else:
print(f"Directory {UPLOAD_DIR} not found in local storage.")
def upload_file(self, file: BinaryIO, filename: str) -> Tuple[bytes, str]:
"""Uploads a file either to S3 or the local file system."""
contents = file.read()
if not contents:
raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT)
contents, file_path = self._upload_to_local(contents, filename)
if self.storage_provider == "s3":
return self._upload_to_s3(file_path, filename)
return contents, file_path
class S3StorageProvider(StorageProvider):
def __init__(self):
self.s3_client = boto3.client(
"s3",
region_name=S3_REGION_NAME,
endpoint_url=S3_ENDPOINT_URL,
aws_access_key_id=S3_ACCESS_KEY_ID,
aws_secret_access_key=S3_SECRET_ACCESS_KEY,
)
self.bucket_name = S3_BUCKET_NAME
def upload_file(self, file: BinaryIO, filename: str) -> Tuple[bytes, str]:
"""Handles uploading of the file to S3 storage."""
_, file_path = LocalStorageProvider.upload_file(file, filename)
try:
self.s3_client.upload_file(file_path, self.bucket_name, filename)
return (
open(file_path, "rb").read(),
"s3://" + self.bucket_name + "/" + filename,
)
except ClientError as e:
raise RuntimeError(f"Error uploading file to S3: {e}")
def get_file(self, file_path: str) -> str:
"""Downloads a file either from S3 or the local file system and returns the file path."""
if self.storage_provider == "s3":
return self._get_file_from_s3(file_path)
return self._get_file_from_local(file_path)
"""Handles downloading of the file from S3 storage."""
try:
bucket_name, key = file_path.split("//")[1].split("/")
local_file_path = f"{UPLOAD_DIR}/{key}"
self.s3_client.download_file(bucket_name, key, local_file_path)
return local_file_path
except ClientError as e:
raise RuntimeError(f"Error downloading file from S3: {e}")
def delete_file(self, file_path: str) -> None:
"""Deletes a file either from S3 or the local file system."""
"""Handles deletion of the file from S3 storage."""
filename = file_path.split("/")[-1]
if self.storage_provider == "s3":
self._delete_from_s3(filename)
try:
self.s3_client.delete_object(Bucket=self.bucket_name, Key=filename)
except ClientError as e:
raise RuntimeError(f"Error deleting file from S3: {e}")
# Always delete from local storage
self._delete_from_local(filename)
LocalStorageProvider.delete_file(file_path)
def delete_all_files(self) -> None:
"""Deletes all files from the storage."""
if self.storage_provider == "s3":
self._delete_all_from_s3()
"""Handles deletion of all files from S3 storage."""
try:
response = self.s3_client.list_objects_v2(Bucket=self.bucket_name)
if "Contents" in response:
for content in response["Contents"]:
self.s3_client.delete_object(
Bucket=self.bucket_name, Key=content["Key"]
)
except ClientError as e:
raise RuntimeError(f"Error deleting all files from S3: {e}")
# Always delete from local storage
self._delete_all_from_local()
LocalStorageProvider.delete_all_files()
Storage = StorageProvider(provider=STORAGE_PROVIDER)
if STORAGE_PROVIDER == "local":
Storage = LocalStorageProvider()
elif STORAGE_PROVIDER == "s3":
Storage = S3StorageProvider()
else:
raise RuntimeError(f"Unsupported storage provider: {STORAGE_PROVIDER}")