import os
import shutil
import json

from abc import ABC, abstractmethod
from typing import BinaryIO, Tuple

import boto3
from botocore.exceptions import ClientError

from open_webui.config import (
    S3_ACCESS_KEY_ID,
    S3_BUCKET_NAME,
    S3_ENDPOINT_URL,
    S3_KEY_PREFIX,
    S3_REGION_NAME,
    S3_SECRET_ACCESS_KEY,
    GCS_BUCKET_NAME,
    GOOGLE_APPLICATION_CREDENTIALS_JSON,
    AZURE_STORAGE_ENDPOINT,
    AZURE_STORAGE_CONTAINER_NAME,
    AZURE_STORAGE_KEY,
    STORAGE_PROVIDER,
    UPLOAD_DIR,
)

from google.cloud import storage
from google.cloud.exceptions import GoogleCloudError, NotFound
from open_webui.constants import ERROR_MESSAGES

from azure.identity import DefaultAzureCredential
from azure.storage.blob import BlobServiceClient
from azure.core.exceptions import ResourceNotFoundError


class StorageProvider(ABC):
    @abstractmethod
    def get_file(self, file_path: str) -> str:
        pass

    @abstractmethod
    def upload_file(self, file: BinaryIO, filename: str) -> Tuple[bytes, str]:
        pass

    @abstractmethod
    def delete_all_files(self) -> None:
        pass

    @abstractmethod
    def delete_file(self, file_path: str) -> None:
        pass


class LocalStorageProvider(StorageProvider):
    @staticmethod
    def upload_file(file: BinaryIO, filename: str) -> Tuple[bytes, str]:
        """Handles uploading of the file to local storage."""
        contents = file.read()
        if not contents:
            raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT)

        file_path = f"{UPLOAD_DIR}/{filename}"
        with open(file_path, "wb") as f:
            f.write(contents)
        return contents, file_path

    @staticmethod
    def get_file(file_path: str) -> str:
        """Handles downloading of the file from local storage."""
        return file_path

    @staticmethod
    def delete_file(file_path: str) -> None:
        """Handles deletion of the file from local storage."""
        filename = file_path.split("/")[-1]
        file_path = f"{UPLOAD_DIR}/{filename}"
        if os.path.isfile(file_path):
            os.remove(file_path)
        else:
            print(f"File {file_path} not found in local storage.")

    @staticmethod
    def delete_all_files() -> None:
        """Handles deletion of all files from local storage."""
        if os.path.exists(UPLOAD_DIR):
            for filename in os.listdir(UPLOAD_DIR):
                file_path = os.path.join(UPLOAD_DIR, filename)
                try:
                    if os.path.isfile(file_path) or os.path.islink(file_path):
                        os.unlink(file_path)  # Remove the file or link
                    elif os.path.isdir(file_path):
                        shutil.rmtree(file_path)  # Remove the directory
                except Exception as e:
                    print(f"Failed to delete {file_path}. Reason: {e}")
        else:
            print(f"Directory {UPLOAD_DIR} not found in local storage.")
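
# Illustrative sketch (hypothetical helper, not part of the provider API and
# never called here): how a caller might round-trip a file through
# LocalStorageProvider, assuming UPLOAD_DIR exists and is writable.
def _example_local_round_trip() -> None:
    import io

    # Any binary file-like object works; BytesIO stands in for an upload.
    contents, file_path = LocalStorageProvider.upload_file(
        io.BytesIO(b"hello"), "example.txt"
    )
    assert contents == b"hello"
    # For local storage, get_file simply echoes the local path back.
    assert LocalStorageProvider.get_file(file_path) == file_path
    LocalStorageProvider.delete_file(file_path)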

class S3StorageProvider(StorageProvider):
    def __init__(self):
        self.s3_client = boto3.client(
            "s3",
            region_name=S3_REGION_NAME,
            endpoint_url=S3_ENDPOINT_URL,
            aws_access_key_id=S3_ACCESS_KEY_ID,
            aws_secret_access_key=S3_SECRET_ACCESS_KEY,
        )
        self.bucket_name = S3_BUCKET_NAME
        self.key_prefix = S3_KEY_PREFIX if S3_KEY_PREFIX else ""

    def upload_file(self, file: BinaryIO, filename: str) -> Tuple[bytes, str]:
        """Handles uploading of the file to S3 storage."""
        contents, file_path = LocalStorageProvider.upload_file(file, filename)
        try:
            s3_key = os.path.join(self.key_prefix, filename)
            self.s3_client.upload_file(file_path, self.bucket_name, s3_key)
            # Reuse the bytes already read during the local upload instead of
            # re-opening the file (the previous open(...).read() leaked a
            # file handle).
            return (
                contents,
                "s3://" + self.bucket_name + "/" + s3_key,
            )
        except ClientError as e:
            raise RuntimeError(f"Error uploading file to S3: {e}")

    def get_file(self, file_path: str) -> str:
        """Handles downloading of the file from S3 storage."""
        try:
            s3_key = self._extract_s3_key(file_path)
            local_file_path = self._get_local_file_path(s3_key)
            self.s3_client.download_file(self.bucket_name, s3_key, local_file_path)
            return local_file_path
        except ClientError as e:
            raise RuntimeError(f"Error downloading file from S3: {e}")

    def delete_file(self, file_path: str) -> None:
        """Handles deletion of the file from S3 storage."""
        try:
            s3_key = self._extract_s3_key(file_path)
            self.s3_client.delete_object(Bucket=self.bucket_name, Key=s3_key)
        except ClientError as e:
            raise RuntimeError(f"Error deleting file from S3: {e}")

        # Always delete from local storage
        LocalStorageProvider.delete_file(file_path)

    def delete_all_files(self) -> None:
        """Handles deletion of all files from S3 storage."""
        try:
            # Note: list_objects_v2 returns at most 1000 keys per call; very
            # large buckets would need a paginator.
            response = self.s3_client.list_objects_v2(Bucket=self.bucket_name)
            if "Contents" in response:
                for content in response["Contents"]:
                    # Skip objects that were not uploaded from open-webui in the first place
                    if not content["Key"].startswith(self.key_prefix):
                        continue
                    self.s3_client.delete_object(
                        Bucket=self.bucket_name, Key=content["Key"]
                    )
        except ClientError as e:
            raise RuntimeError(f"Error deleting all files from S3: {e}")

        # Always delete from local storage
        LocalStorageProvider.delete_all_files()

    # The S3 key is the name assigned to an object. It excludes the bucket
    # name, but includes the internal path and the file name.
    def _extract_s3_key(self, full_file_path: str) -> str:
        return "/".join(full_file_path.split("//")[1].split("/")[1:])

    def _get_local_file_path(self, s3_key: str) -> str:
        return f"{UPLOAD_DIR}/{s3_key.split('/')[-1]}"
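
# Illustrative sketch (hypothetical, never called): how the "s3://" paths
# returned by upload_file map back to object keys, assuming S3_KEY_PREFIX is
# "open-webui" and the bucket is named "my-bucket".
def _example_s3_key_mapping(provider: S3StorageProvider) -> None:
    # upload_file would return a full path like this:
    full_path = "s3://my-bucket/open-webui/doc.pdf"
    # The key excludes the bucket name but keeps the prefix and filename:
    assert provider._extract_s3_key(full_path) == "open-webui/doc.pdf"
    # Downloads land in UPLOAD_DIR under the bare filename:
    assert (
        provider._get_local_file_path("open-webui/doc.pdf")
        == f"{UPLOAD_DIR}/doc.pdf"
    )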

class GCSStorageProvider(StorageProvider):
    def __init__(self):
        self.bucket_name = GCS_BUCKET_NAME

        if GOOGLE_APPLICATION_CREDENTIALS_JSON:
            self.gcs_client = storage.Client.from_service_account_info(
                info=json.loads(GOOGLE_APPLICATION_CREDENTIALS_JSON)
            )
        else:
            # If no credentials JSON is provided, credentials are picked up
            # from the environment: user credentials on a local machine, or
            # the Google metadata server on a Compute Engine instance.
            self.gcs_client = storage.Client()
        self.bucket = self.gcs_client.bucket(GCS_BUCKET_NAME)

    def upload_file(self, file: BinaryIO, filename: str) -> Tuple[bytes, str]:
        """Handles uploading of the file to GCS storage."""
        contents, file_path = LocalStorageProvider.upload_file(file, filename)
        try:
            blob = self.bucket.blob(filename)
            blob.upload_from_filename(file_path)
            return contents, "gs://" + self.bucket_name + "/" + filename
        except GoogleCloudError as e:
            raise RuntimeError(f"Error uploading file to GCS: {e}")

    def get_file(self, file_path: str) -> str:
        """Handles downloading of the file from GCS storage."""
        try:
            filename = file_path.removeprefix("gs://").split("/")[1]
            local_file_path = f"{UPLOAD_DIR}/{filename}"
            # blob() builds a reference without an API call; a missing object
            # surfaces as NotFound on download. (get_blob() returns None for
            # missing objects, which would crash with an AttributeError
            # instead of being caught below.)
            blob = self.bucket.blob(filename)
            blob.download_to_filename(local_file_path)
            return local_file_path
        except NotFound as e:
            raise RuntimeError(f"Error downloading file from GCS: {e}")

    def delete_file(self, file_path: str) -> None:
        """Handles deletion of the file from GCS storage."""
        try:
            filename = file_path.removeprefix("gs://").split("/")[1]
            blob = self.bucket.blob(filename)
            blob.delete()
        except NotFound as e:
            raise RuntimeError(f"Error deleting file from GCS: {e}")

        # Always delete from local storage
        LocalStorageProvider.delete_file(file_path)

    def delete_all_files(self) -> None:
        """Handles deletion of all files from GCS storage."""
        try:
            blobs = self.bucket.list_blobs()
            for blob in blobs:
                blob.delete()
        except NotFound as e:
            raise RuntimeError(f"Error deleting all files from GCS: {e}")

        # Always delete from local storage
        LocalStorageProvider.delete_all_files()
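
# Illustrative sketch (hypothetical, never called): how GCSStorageProvider
# derives blob names from the "gs://" paths it returns, assuming a bucket
# named "my-bucket".
def _example_gcs_path_parsing() -> None:
    file_path = "gs://my-bucket/doc.pdf"
    # get_file/delete_file take the segment after the bucket name as the
    # blob name; note this assumes flat (un-nested) object names.
    filename = file_path.removeprefix("gs://").split("/")[1]
    assert filename == "doc.pdf"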

class AzureStorageProvider(StorageProvider):
    def __init__(self):
        self.endpoint = AZURE_STORAGE_ENDPOINT
        self.container_name = AZURE_STORAGE_CONTAINER_NAME
        storage_key = AZURE_STORAGE_KEY

        if storage_key:
            # Configure using the Azure Storage Account Endpoint and Key
            self.blob_service_client = BlobServiceClient(
                account_url=self.endpoint, credential=storage_key
            )
        else:
            # Configure using the Azure Storage Account Endpoint and
            # DefaultAzureCredential. If no key is configured,
            # DefaultAzureCredential is used to support Managed Identity
            # authentication.
            self.blob_service_client = BlobServiceClient(
                account_url=self.endpoint, credential=DefaultAzureCredential()
            )
        self.container_client = self.blob_service_client.get_container_client(
            self.container_name
        )

    def upload_file(self, file: BinaryIO, filename: str) -> Tuple[bytes, str]:
        """Handles uploading of the file to Azure Blob Storage."""
        contents, file_path = LocalStorageProvider.upload_file(file, filename)
        try:
            blob_client = self.container_client.get_blob_client(filename)
            blob_client.upload_blob(contents, overwrite=True)
            return contents, f"{self.endpoint}/{self.container_name}/{filename}"
        except Exception as e:
            raise RuntimeError(f"Error uploading file to Azure Blob Storage: {e}")

    def get_file(self, file_path: str) -> str:
        """Handles downloading of the file from Azure Blob Storage."""
        try:
            filename = file_path.split("/")[-1]
            local_file_path = f"{UPLOAD_DIR}/{filename}"
            blob_client = self.container_client.get_blob_client(filename)
            with open(local_file_path, "wb") as download_file:
                download_file.write(blob_client.download_blob().readall())
            return local_file_path
        except ResourceNotFoundError as e:
            raise RuntimeError(f"Error downloading file from Azure Blob Storage: {e}")

    def delete_file(self, file_path: str) -> None:
        """Handles deletion of the file from Azure Blob Storage."""
        try:
            filename = file_path.split("/")[-1]
            blob_client = self.container_client.get_blob_client(filename)
            blob_client.delete_blob()
        except ResourceNotFoundError as e:
            raise RuntimeError(f"Error deleting file from Azure Blob Storage: {e}")

        # Always delete from local storage
        LocalStorageProvider.delete_file(file_path)

    def delete_all_files(self) -> None:
        """Handles deletion of all files from Azure Blob Storage."""
        try:
            blobs = self.container_client.list_blobs()
            for blob in blobs:
                self.container_client.delete_blob(blob.name)
        except Exception as e:
            raise RuntimeError(
                f"Error deleting all files from Azure Blob Storage: {e}"
            )

        # Always delete from local storage
        LocalStorageProvider.delete_all_files()


def get_storage_provider(storage_provider: str):
    if storage_provider == "local":
        Storage = LocalStorageProvider()
    elif storage_provider == "s3":
        Storage = S3StorageProvider()
    elif storage_provider == "gcs":
        Storage = GCSStorageProvider()
    elif storage_provider == "azure":
        Storage = AzureStorageProvider()
    else:
        raise RuntimeError(f"Unsupported storage provider: {storage_provider}")
    return Storage


Storage = get_storage_provider(STORAGE_PROVIDER)
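
# Illustrative sketch (hypothetical, never called): callers are expected to
# use the module-level Storage singleton, whose concrete provider is chosen
# once at import time from STORAGE_PROVIDER. The filename and payload below
# are made up for the example.
def _example_storage_usage() -> None:
    import io

    contents, file_path = Storage.upload_file(io.BytesIO(b"data"), "report.pdf")
    # get_file returns a local path under UPLOAD_DIR, downloading from the
    # remote backend first when one is configured.
    local_path = Storage.get_file(file_path)
    # delete_file removes the remote object (if any) and the local copy.
    Storage.delete_file(file_path)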