diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py index 2fed66798..779e3d3bc 100644 --- a/backend/open_webui/config.py +++ b/backend/open_webui/config.py @@ -533,6 +533,18 @@ if CUSTOM_NAME: pass +#################################### +# STORAGE PROVIDER +#################################### + +STORAGE_PROVIDER = os.environ.get("STORAGE_PROVIDER", "") # defaults to local, s3 + +S3_ACCESS_KEY_ID = os.environ.get("S3_ACCESS_KEY_ID", None) +S3_SECRET_ACCESS_KEY = os.environ.get("S3_SECRET_ACCESS_KEY", None) +S3_REGION_NAME = os.environ.get("S3_REGION_NAME", None) +S3_BUCKET_NAME = os.environ.get("S3_BUCKET_NAME", None) +S3_ENDPOINT_URL = os.environ.get("S3_ENDPOINT_URL", None) + #################################### # File Upload DIR #################################### diff --git a/backend/open_webui/storage/provider.py b/backend/open_webui/storage/provider.py new file mode 100644 index 000000000..6cb5b5d95 --- /dev/null +++ b/backend/open_webui/storage/provider.py @@ -0,0 +1,117 @@ +import os +import boto3 +from botocore.exceptions import ClientError +from open_webui.config import ( + STORAGE_PROVIDER, + S3_ACCESS_KEY_ID, + S3_SECRET_ACCESS_KEY, + S3_BUCKET_NAME, + S3_REGION_NAME, + S3_ENDPOINT_URL, + UPLOAD_DIR, + AppConfig, +) + + +class StorageProvider: + def __init__(self): + self.storage_provider = None + self.client = None + self.bucket_name = None + + if STORAGE_PROVIDER == "s3": + self.storage_provider = "s3" + self.client = boto3.client( + "s3", + region_name=S3_REGION_NAME, + endpoint_url=S3_ENDPOINT_URL, + aws_access_key_id=S3_ACCESS_KEY_ID, + aws_secret_access_key=S3_SECRET_ACCESS_KEY, + ) + self.bucket_name = S3_BUCKET_NAME + else: + self.storage_provider = "local" + + def get_storage_provider(self): + return self.storage_provider + + def upload_file(self, file, filename): + if self.storage_provider == "s3": + try: + bucket = self.bucket_name + self.client.upload_fileobj(file, bucket, filename) + return filename + except ClientError as e: + raise RuntimeError(f"Error uploading file: {e}") + else: + file_path = os.path.join(UPLOAD_DIR, filename) + os.makedirs(os.path.dirname(file_path), exist_ok=True) + with open(file_path, "wb") as f: + f.write(file.read()) + return filename + + def list_files(self): + if self.storage_provider == "s3": + try: + bucket = self.bucket_name + response = self.client.list_objects_v2(Bucket=bucket) + if "Contents" in response: + return [content["Key"] for content in response["Contents"]] + return [] + except ClientError as e: + raise RuntimeError(f"Error listing files: {e}") + else: + return [ + f + for f in os.listdir(UPLOAD_DIR) + if os.path.isfile(os.path.join(UPLOAD_DIR, f)) + ] + + def get_file(self, filename): + if self.storage_provider == "s3": + try: + bucket = self.bucket_name + file_path = f"/tmp/{filename}" + self.client.download_file(bucket, filename, file_path) + return file_path + except ClientError as e: + raise RuntimeError(f"Error downloading file: {e}") + else: + file_path = os.path.join(UPLOAD_DIR, filename) + if os.path.isfile(file_path): + return file_path + else: + raise FileNotFoundError(f"File {filename} not found in local storage.") + + def delete_file(self, filename): + if self.storage_provider == "s3": + try: + bucket = self.bucket_name + self.client.delete_object(Bucket=bucket, Key=filename) + except ClientError as e: + raise RuntimeError(f"Error deleting file: {e}") + else: + file_path = os.path.join(UPLOAD_DIR, filename) + if os.path.isfile(file_path): + os.remove(file_path) + else: + raise FileNotFoundError(f"File {filename} not found in local storage.") + + def delete_all_files(self): + if self.storage_provider == "s3": + try: + bucket = self.bucket_name + response = self.client.list_objects_v2(Bucket=bucket) + if "Contents" in response: + for content in response["Contents"]: + self.client.delete_object(Bucket=bucket, Key=content["Key"]) + except ClientError as e: + raise RuntimeError(f"Error deleting all files: {e}") + else: + for filename in os.listdir(UPLOAD_DIR): + file_path = os.path.join(UPLOAD_DIR, filename) + if os.path.isfile(file_path): + os.remove(file_path) + + +Storage = StorageProvider()