Split the storage providers into separate classes in preparation for other storage providers like GCS

This commit is contained in:
Rodrigo Agundez 2025-01-15 23:16:38 +08:00
parent 372658be6d
commit a3f737c0c6

View File

@ -1,121 +1,68 @@
import os
import shutil
from abc import ABC, abstractmethod
from typing import BinaryIO, Tuple

import boto3
from botocore.exceptions import ClientError

from open_webui.config import (
    S3_ACCESS_KEY_ID,
    S3_BUCKET_NAME,
    S3_ENDPOINT_URL,
    S3_REGION_NAME,
    S3_SECRET_ACCESS_KEY,
    STORAGE_PROVIDER,
    UPLOAD_DIR,
)
from open_webui.constants import ERROR_MESSAGES
class StorageProvider(ABC):
    """Abstract interface that every storage backend must implement."""

    @abstractmethod
    def get_file(self, file_path: str) -> str:
        """Return a local filesystem path for the stored file."""
        pass

    @abstractmethod
    def upload_file(self, file: BinaryIO, filename: str) -> Tuple[bytes, str]:
        """Store the file and return its contents plus the storage path."""
        pass

    @abstractmethod
    def delete_all_files(self) -> None:
        """Remove every file held by this storage backend."""
        pass

    @abstractmethod
    def delete_file(self, file_path: str) -> None:
        """Remove a single file from this storage backend."""
        pass
class LocalStorageProvider(StorageProvider):
    """Stores uploaded files on the local filesystem under UPLOAD_DIR."""

    @staticmethod
    def upload_file(file: BinaryIO, filename: str) -> Tuple[bytes, str]:
        """Write the uploaded file to UPLOAD_DIR and return (contents, path).

        Raises:
            ValueError: if the uploaded file is empty.
        """
        contents = file.read()
        if not contents:
            raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT)

        # NOTE(review): the diff rendering replaced the f-string field with
        # "(unknown)"; restored to {filename} to match the surrounding code
        # (cf. the intact f"{UPLOAD_DIR}/{key}" in the S3 provider).
        file_path = f"{UPLOAD_DIR}/{filename}"
        with open(file_path, "wb") as f:
            f.write(contents)
        return contents, file_path

    @staticmethod
    def get_file(file_path: str) -> str:
        """Handles downloading of the file from local storage."""
        # Local files are already on disk; the path itself is the result.
        return file_path

    @staticmethod
    def delete_file(file_path: str) -> None:
        """Handles deletion of the file from local storage."""
        # Only the basename is used, so S3-style paths also resolve locally.
        filename = file_path.split("/")[-1]
        file_path = f"{UPLOAD_DIR}/{filename}"
        if os.path.isfile(file_path):
            os.remove(file_path)
        else:
            print(f"File {file_path} not found in local storage.")
def _delete_all_from_s3(self) -> None: @staticmethod
"""Handles deletion of all files from S3 storage.""" def delete_all_files() -> None:
if not self.s3_client:
raise RuntimeError("S3 Client is not initialized.")
try:
response = self.s3_client.list_objects_v2(Bucket=self.bucket_name)
if "Contents" in response:
for content in response["Contents"]:
self.s3_client.delete_object(
Bucket=self.bucket_name, Key=content["Key"]
)
except ClientError as e:
raise RuntimeError(f"Error deleting all files from S3: {e}")
def _delete_all_from_local(self) -> None:
"""Handles deletion of all files from local storage.""" """Handles deletion of all files from local storage."""
if os.path.exists(UPLOAD_DIR): if os.path.exists(UPLOAD_DIR):
for filename in os.listdir(UPLOAD_DIR): for filename in os.listdir(UPLOAD_DIR):
@ -130,40 +77,70 @@ class StorageProvider:
else: else:
print(f"Directory {UPLOAD_DIR} not found in local storage.") print(f"Directory {UPLOAD_DIR} not found in local storage.")
def upload_file(self, file: BinaryIO, filename: str) -> Tuple[bytes, str]:
"""Uploads a file either to S3 or the local file system."""
contents = file.read()
if not contents:
raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT)
contents, file_path = self._upload_to_local(contents, filename)
class S3StorageProvider(StorageProvider):
    """Stores files in an S3 bucket, mirroring them to local storage."""

    def __init__(self):
        # Client configuration comes from the open_webui config module.
        self.s3_client = boto3.client(
            "s3",
            region_name=S3_REGION_NAME,
            endpoint_url=S3_ENDPOINT_URL,
            aws_access_key_id=S3_ACCESS_KEY_ID,
            aws_secret_access_key=S3_SECRET_ACCESS_KEY,
        )
        self.bucket_name = S3_BUCKET_NAME

    def upload_file(self, file: BinaryIO, filename: str) -> Tuple[bytes, str]:
        """Handles uploading of the file to S3 storage.

        The file is first written to local storage, then pushed to the bucket.

        Raises:
            RuntimeError: if the S3 upload fails.
        """
        _, file_path = LocalStorageProvider.upload_file(file, filename)
        try:
            self.s3_client.upload_file(file_path, self.bucket_name, filename)
            # FIX: the original used open(file_path, "rb").read(), leaking
            # the file handle; a context manager closes it deterministically.
            with open(file_path, "rb") as f:
                contents = f.read()
            return contents, "s3://" + self.bucket_name + "/" + filename
        except ClientError as e:
            raise RuntimeError(f"Error uploading file to S3: {e}")

    def get_file(self, file_path: str) -> str:
        """Handles downloading of the file from S3 storage."""
        try:
            # file_path is "s3://<bucket>/<key>"; assumes a single-segment
            # key (no "/" in the key) — TODO confirm against callers.
            bucket_name, key = file_path.split("//")[1].split("/")
            local_file_path = f"{UPLOAD_DIR}/{key}"
            self.s3_client.download_file(bucket_name, key, local_file_path)
            return local_file_path
        except ClientError as e:
            raise RuntimeError(f"Error downloading file from S3: {e}")

    def delete_file(self, file_path: str) -> None:
        """Handles deletion of the file from S3 storage."""
        filename = file_path.split("/")[-1]
        try:
            self.s3_client.delete_object(Bucket=self.bucket_name, Key=filename)
        except ClientError as e:
            raise RuntimeError(f"Error deleting file from S3: {e}")

        # Always delete from local storage
        LocalStorageProvider.delete_file(file_path)

    def delete_all_files(self) -> None:
        """Handles deletion of all files from S3 storage."""
        try:
            response = self.s3_client.list_objects_v2(Bucket=self.bucket_name)
            if "Contents" in response:
                for content in response["Contents"]:
                    self.s3_client.delete_object(
                        Bucket=self.bucket_name, Key=content["Key"]
                    )
        except ClientError as e:
            raise RuntimeError(f"Error deleting all files from S3: {e}")

        # Always delete from local storage
        LocalStorageProvider.delete_all_files()
# Select the storage backend once at import time based on configuration.
# Unknown values fail fast rather than silently falling back to local.
if STORAGE_PROVIDER == "local":
    Storage = LocalStorageProvider()
elif STORAGE_PROVIDER == "s3":
    Storage = S3StorageProvider()
else:
    raise RuntimeError(f"Unsupported storage provider: {STORAGE_PROVIDER}")