Merge pull request #8580 from rragundez/split-storage-providers

Split the storage providers into separate classes in preparation for adding more cloud providers.
commit 8b3fb2a8b5
Timothy Jaeryang Baek, 2025-01-15 21:08:04 -08:00, committed by GitHub
5 changed files with 917 additions and 583 deletions
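The refactor replaces the single do-everything StorageProvider class with an abstract base class, one concrete class per backend, and a small factory. For orientation, here is a hypothetical sketch of how a third backend could slot in; GCSStorageProvider, the "gcs" key, and the google-cloud-storage calls are illustrative assumptions, not part of this PR:

from typing import BinaryIO, Tuple

from google.cloud import storage  # assumed extra dependency; NOT added by this PR

from open_webui.storage.provider import (
    LocalStorageProvider,
    S3StorageProvider,
    StorageProvider,
)


class GCSStorageProvider(StorageProvider):
    """Hypothetical GCS backend showing how the new layout is meant to grow."""

    def __init__(self):
        self.client = storage.Client()  # needs application-default credentials
        self.bucket = self.client.bucket("my-bucket")  # placeholder bucket name

    def upload_file(self, file: BinaryIO, filename: str) -> Tuple[bytes, str]:
        # Mirror S3StorageProvider: persist locally first, then push to the cloud.
        contents, file_path = LocalStorageProvider.upload_file(file, filename)
        self.bucket.blob(filename).upload_from_filename(file_path)
        return contents, f"gs://{self.bucket.name}/{filename}"

    def get_file(self, file_path: str) -> str:
        ...  # download the blob into UPLOAD_DIR and return the local path

    def delete_file(self, file_path: str) -> None:
        ...  # delete the blob, then LocalStorageProvider.delete_file(file_path)

    def delete_all_files(self) -> None:
        ...  # iterate self.bucket.list_blobs() and delete each one


def get_storage_provider(storage_provider: str) -> StorageProvider:
    if storage_provider == "local":
        return LocalStorageProvider()
    elif storage_provider == "s3":
        return S3StorageProvider()
    elif storage_provider == "gcs":  # the one extra branch a provider PR adds
        return GCSStorageProvider()
    raise RuntimeError(f"Unsupported storage provider: {storage_provider}")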

open_webui/config.py:

@@ -9,22 +9,22 @@ from urllib.parse import urlparse
 import chromadb
 import requests
 import yaml
-from open_webui.internal.db import Base, get_db
-from pydantic import BaseModel
-from sqlalchemy import JSON, Column, DateTime, Integer, func
 from open_webui.env import (
-    OPEN_WEBUI_DIR,
     DATA_DIR,
+    DATABASE_URL,
     ENV,
     FRONTEND_BUILD_DIR,
+    OFFLINE_MODE,
+    OPEN_WEBUI_DIR,
     WEBUI_AUTH,
     WEBUI_FAVICON_URL,
     WEBUI_NAME,
     log,
-    DATABASE_URL,
-    OFFLINE_MODE,
 )
+from pydantic import BaseModel
+from sqlalchemy import JSON, Column, DateTime, Integer, func
+
+from open_webui.internal.db import Base, get_db


 class EndpointFilter(logging.Filter):
@@ -581,7 +581,7 @@ if CUSTOM_NAME:
 # STORAGE PROVIDER
 ####################################

-STORAGE_PROVIDER = os.environ.get("STORAGE_PROVIDER", "") # defaults to local, s3
+STORAGE_PROVIDER = os.environ.get("STORAGE_PROVIDER", "local") # defaults to local, s3

 S3_ACCESS_KEY_ID = os.environ.get("S3_ACCESS_KEY_ID", None)
 S3_SECRET_ACCESS_KEY = os.environ.get("S3_SECRET_ACCESS_KEY", None)
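Note the default change from "" to "local": with the get_storage_provider() factory introduced in provider.py below, an empty provider string would raise at import time instead of silently behaving as local storage. A minimal sketch of what the new default avoids, assuming only what the hunks in this commit show:

import os

# With the factory introduced in provider.py below:
#   get_storage_provider("")      -> RuntimeError("Unsupported storage provider: ")
#   get_storage_provider("local") -> LocalStorageProvider()
# so an unset STORAGE_PROVIDER must now resolve to "local", not "".
provider_name = os.environ.get("STORAGE_PROVIDER", "local")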

open_webui/storage/provider.py:

@@ -1,121 +1,68 @@
 import os
+import shutil
+from abc import ABC, abstractmethod
+from typing import BinaryIO, Tuple
+
 import boto3
 from botocore.exceptions import ClientError
-import shutil
-from typing import BinaryIO, Tuple, Optional, Union
-
-from open_webui.constants import ERROR_MESSAGES
 from open_webui.config import (
-    STORAGE_PROVIDER,
     S3_ACCESS_KEY_ID,
-    S3_SECRET_ACCESS_KEY,
     S3_BUCKET_NAME,
-    S3_REGION_NAME,
     S3_ENDPOINT_URL,
+    S3_REGION_NAME,
+    S3_SECRET_ACCESS_KEY,
+    STORAGE_PROVIDER,
     UPLOAD_DIR,
 )
+from open_webui.constants import ERROR_MESSAGES
-import boto3
-from botocore.exceptions import ClientError
-from typing import BinaryIO, Tuple, Optional
+
+
+class StorageProvider(ABC):
+    @abstractmethod
+    def get_file(self, file_path: str) -> str:
+        pass
+
+    @abstractmethod
+    def upload_file(self, file: BinaryIO, filename: str) -> Tuple[bytes, str]:
+        pass
+
+    @abstractmethod
+    def delete_all_files(self) -> None:
+        pass
+
+    @abstractmethod
+    def delete_file(self, file_path: str) -> None:
+        pass
-class StorageProvider:
-    def __init__(self, provider: Optional[str] = None):
-        self.storage_provider: str = provider or STORAGE_PROVIDER
-        self.s3_client = None
-        self.s3_bucket_name: Optional[str] = None
-        if self.storage_provider == "s3":
-            self._initialize_s3()
-
-    def _initialize_s3(self) -> None:
-        """Initializes the S3 client and bucket name if using S3 storage."""
-        self.s3_client = boto3.client(
-            "s3",
-            region_name=S3_REGION_NAME,
-            endpoint_url=S3_ENDPOINT_URL,
-            aws_access_key_id=S3_ACCESS_KEY_ID,
-            aws_secret_access_key=S3_SECRET_ACCESS_KEY,
-        )
-        self.bucket_name = S3_BUCKET_NAME
-
-    def _upload_to_s3(self, file_path: str, filename: str) -> Tuple[bytes, str]:
-        """Handles uploading of the file to S3 storage."""
-        if not self.s3_client:
-            raise RuntimeError("S3 Client is not initialized.")
-        try:
-            self.s3_client.upload_file(file_path, self.bucket_name, filename)
-            return (
-                open(file_path, "rb").read(),
-                "s3://" + self.bucket_name + "/" + filename,
-            )
-        except ClientError as e:
-            raise RuntimeError(f"Error uploading file to S3: {e}")
-
-    def _upload_to_local(self, contents: bytes, filename: str) -> Tuple[bytes, str]:
-        """Handles uploading of the file to local storage."""
+
+
+class LocalStorageProvider(StorageProvider):
+    @staticmethod
+    def upload_file(file: BinaryIO, filename: str) -> Tuple[bytes, str]:
+        contents = file.read()
+        if not contents:
+            raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT)
         file_path = f"{UPLOAD_DIR}/{filename}"
         with open(file_path, "wb") as f:
             f.write(contents)
         return contents, file_path

-    def _get_file_from_s3(self, file_path: str) -> str:
-        """Handles downloading of the file from S3 storage."""
-        if not self.s3_client:
-            raise RuntimeError("S3 Client is not initialized.")
-        try:
-            bucket_name, key = file_path.split("//")[1].split("/")
-            local_file_path = f"{UPLOAD_DIR}/{key}"
-            self.s3_client.download_file(bucket_name, key, local_file_path)
-            return local_file_path
-        except ClientError as e:
-            raise RuntimeError(f"Error downloading file from S3: {e}")
-
-    def _get_file_from_local(self, file_path: str) -> str:
+    @staticmethod
+    def get_file(file_path: str) -> str:
         """Handles downloading of the file from local storage."""
         return file_path

-    def _delete_from_s3(self, filename: str) -> None:
-        """Handles deletion of the file from S3 storage."""
-        if not self.s3_client:
-            raise RuntimeError("S3 Client is not initialized.")
-        try:
-            self.s3_client.delete_object(Bucket=self.bucket_name, Key=filename)
-        except ClientError as e:
-            raise RuntimeError(f"Error deleting file from S3: {e}")
-
-    def _delete_from_local(self, filename: str) -> None:
+    @staticmethod
+    def delete_file(file_path: str) -> None:
         """Handles deletion of the file from local storage."""
+        filename = file_path.split("/")[-1]
         file_path = f"{UPLOAD_DIR}/{filename}"
         if os.path.isfile(file_path):
             os.remove(file_path)
         else:
             print(f"File {file_path} not found in local storage.")

-    def _delete_all_from_s3(self) -> None:
-        """Handles deletion of all files from S3 storage."""
-        if not self.s3_client:
-            raise RuntimeError("S3 Client is not initialized.")
-        try:
-            response = self.s3_client.list_objects_v2(Bucket=self.bucket_name)
-            if "Contents" in response:
-                for content in response["Contents"]:
-                    self.s3_client.delete_object(
-                        Bucket=self.bucket_name, Key=content["Key"]
-                    )
-        except ClientError as e:
-            raise RuntimeError(f"Error deleting all files from S3: {e}")
-
-    def _delete_all_from_local(self) -> None:
+    @staticmethod
+    def delete_all_files() -> None:
         """Handles deletion of all files from local storage."""
         if os.path.exists(UPLOAD_DIR):
             for filename in os.listdir(UPLOAD_DIR):
@@ -130,40 +77,75 @@ class StorageProvider:
         else:
             print(f"Directory {UPLOAD_DIR} not found in local storage.")

-    def upload_file(self, file: BinaryIO, filename: str) -> Tuple[bytes, str]:
-        """Uploads a file either to S3 or the local file system."""
-        contents = file.read()
-        if not contents:
-            raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT)
-        contents, file_path = self._upload_to_local(contents, filename)
-        if self.storage_provider == "s3":
-            return self._upload_to_s3(file_path, filename)
-        return contents, file_path
+
+class S3StorageProvider(StorageProvider):
+    def __init__(self):
+        self.s3_client = boto3.client(
+            "s3",
+            region_name=S3_REGION_NAME,
+            endpoint_url=S3_ENDPOINT_URL,
+            aws_access_key_id=S3_ACCESS_KEY_ID,
+            aws_secret_access_key=S3_SECRET_ACCESS_KEY,
+        )
+        self.bucket_name = S3_BUCKET_NAME
+
+    def upload_file(self, file: BinaryIO, filename: str) -> Tuple[bytes, str]:
+        """Handles uploading of the file to S3 storage."""
+        _, file_path = LocalStorageProvider.upload_file(file, filename)
+        try:
+            self.s3_client.upload_file(file_path, self.bucket_name, filename)
+            return (
+                open(file_path, "rb").read(),
+                "s3://" + self.bucket_name + "/" + filename,
+            )
+        except ClientError as e:
+            raise RuntimeError(f"Error uploading file to S3: {e}")

     def get_file(self, file_path: str) -> str:
-        """Downloads a file either from S3 or the local file system and returns the file path."""
-        if self.storage_provider == "s3":
-            return self._get_file_from_s3(file_path)
-        return self._get_file_from_local(file_path)
+        """Handles downloading of the file from S3 storage."""
+        try:
+            bucket_name, key = file_path.split("//")[1].split("/")
+            local_file_path = f"{UPLOAD_DIR}/{key}"
+            self.s3_client.download_file(bucket_name, key, local_file_path)
+            return local_file_path
+        except ClientError as e:
+            raise RuntimeError(f"Error downloading file from S3: {e}")

     def delete_file(self, file_path: str) -> None:
-        """Deletes a file either from S3 or the local file system."""
+        """Handles deletion of the file from S3 storage."""
         filename = file_path.split("/")[-1]
-        if self.storage_provider == "s3":
-            self._delete_from_s3(filename)
+        try:
+            self.s3_client.delete_object(Bucket=self.bucket_name, Key=filename)
+        except ClientError as e:
+            raise RuntimeError(f"Error deleting file from S3: {e}")

         # Always delete from local storage
-        self._delete_from_local(filename)
+        LocalStorageProvider.delete_file(file_path)

     def delete_all_files(self) -> None:
-        """Deletes all files from the storage."""
-        if self.storage_provider == "s3":
-            self._delete_all_from_s3()
+        """Handles deletion of all files from S3 storage."""
+        try:
+            response = self.s3_client.list_objects_v2(Bucket=self.bucket_name)
+            if "Contents" in response:
+                for content in response["Contents"]:
+                    self.s3_client.delete_object(
+                        Bucket=self.bucket_name, Key=content["Key"]
+                    )
+        except ClientError as e:
+            raise RuntimeError(f"Error deleting all files from S3: {e}")

         # Always delete from local storage
-        self._delete_all_from_local()
+        LocalStorageProvider.delete_all_files()

-Storage = StorageProvider(provider=STORAGE_PROVIDER)
+
+def get_storage_provider(storage_provider: str):
+    if storage_provider == "local":
+        Storage = LocalStorageProvider()
+    elif storage_provider == "s3":
+        Storage = S3StorageProvider()
+    else:
+        raise RuntimeError(f"Unsupported storage provider: {storage_provider}")
+    return Storage
+
+
+Storage = get_storage_provider(STORAGE_PROVIDER)
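For orientation, a minimal usage sketch of the module as refactored in this commit; the payload and filename are invented:

import io

from open_webui.storage.provider import Storage, get_storage_provider

# Storage is the module-level instance selected by STORAGE_PROVIDER.
contents, file_path = Storage.upload_file(io.BytesIO(b"hello"), "hello.txt")
local_path = Storage.get_file(file_path)  # resolves to a local path for either backend
Storage.delete_file(file_path)

# A specific backend can also be constructed directly:
local = get_storage_provider("local")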

test_provider.py (new test module for open_webui.storage.provider):

@@ -0,0 +1,173 @@
import io

import boto3
import pytest
from botocore.exceptions import ClientError
from moto import mock_aws
from open_webui.storage import provider


def mock_upload_dir(monkeypatch, tmp_path):
    """Fixture to monkey-patch the UPLOAD_DIR and create a temporary directory."""
    directory = tmp_path / "uploads"
    directory.mkdir()
    monkeypatch.setattr(provider, "UPLOAD_DIR", str(directory))
    return directory


def test_imports():
    provider.StorageProvider
    provider.LocalStorageProvider
    provider.S3StorageProvider
    provider.Storage


def test_get_storage_provider():
    Storage = provider.get_storage_provider("local")
    assert isinstance(Storage, provider.LocalStorageProvider)
    Storage = provider.get_storage_provider("s3")
    assert isinstance(Storage, provider.S3StorageProvider)
    with pytest.raises(RuntimeError):
        provider.get_storage_provider("invalid")


def test_class_instantiation():
    with pytest.raises(TypeError):
        provider.StorageProvider()
    with pytest.raises(TypeError):

        class Test(provider.StorageProvider):
            pass

        Test()
    provider.LocalStorageProvider()
    provider.S3StorageProvider()


class TestLocalStorageProvider:
    Storage = provider.LocalStorageProvider()
    file_content = b"test content"
    file_bytesio = io.BytesIO(file_content)
    filename = "test.txt"
    filename_extra = "test_exyta.txt"
    file_bytesio_empty = io.BytesIO()

    def test_upload_file(self, monkeypatch, tmp_path):
        upload_dir = mock_upload_dir(monkeypatch, tmp_path)
        contents, file_path = self.Storage.upload_file(self.file_bytesio, self.filename)
        assert (upload_dir / self.filename).exists()
        assert (upload_dir / self.filename).read_bytes() == self.file_content
        assert contents == self.file_content
        assert file_path == str(upload_dir / self.filename)
        with pytest.raises(ValueError):
            self.Storage.upload_file(self.file_bytesio_empty, self.filename)

    def test_get_file(self, monkeypatch, tmp_path):
        upload_dir = mock_upload_dir(monkeypatch, tmp_path)
        file_path = str(upload_dir / self.filename)
        file_path_return = self.Storage.get_file(file_path)
        assert file_path == file_path_return

    def test_delete_file(self, monkeypatch, tmp_path):
        upload_dir = mock_upload_dir(monkeypatch, tmp_path)
        (upload_dir / self.filename).write_bytes(self.file_content)
        assert (upload_dir / self.filename).exists()
        file_path = str(upload_dir / self.filename)
        self.Storage.delete_file(file_path)
        assert not (upload_dir / self.filename).exists()

    def test_delete_all_files(self, monkeypatch, tmp_path):
        upload_dir = mock_upload_dir(monkeypatch, tmp_path)
        (upload_dir / self.filename).write_bytes(self.file_content)
        (upload_dir / self.filename_extra).write_bytes(self.file_content)
        self.Storage.delete_all_files()
        assert not (upload_dir / self.filename).exists()
        assert not (upload_dir / self.filename_extra).exists()


@mock_aws
class TestS3StorageProvider:
    Storage = provider.S3StorageProvider()
    Storage.bucket_name = "my-bucket"
    s3_client = boto3.resource("s3", region_name="us-east-1")
    file_content = b"test content"
    filename = "test.txt"
    filename_extra = "test_exyta.txt"
    file_bytesio_empty = io.BytesIO()

    def test_upload_file(self, monkeypatch, tmp_path):
        upload_dir = mock_upload_dir(monkeypatch, tmp_path)
        # S3 checks
        with pytest.raises(Exception):
            self.Storage.upload_file(io.BytesIO(self.file_content), self.filename)
        self.s3_client.create_bucket(Bucket=self.Storage.bucket_name)
        contents, s3_file_path = self.Storage.upload_file(
            io.BytesIO(self.file_content), self.filename
        )
        object = self.s3_client.Object(self.Storage.bucket_name, self.filename)
        assert self.file_content == object.get()["Body"].read()
        # local checks
        assert (upload_dir / self.filename).exists()
        assert (upload_dir / self.filename).read_bytes() == self.file_content
        assert contents == self.file_content
        assert s3_file_path == "s3://" + self.Storage.bucket_name + "/" + self.filename
        with pytest.raises(ValueError):
            self.Storage.upload_file(self.file_bytesio_empty, self.filename)

    def test_get_file(self, monkeypatch, tmp_path):
        upload_dir = mock_upload_dir(monkeypatch, tmp_path)
        self.s3_client.create_bucket(Bucket=self.Storage.bucket_name)
        contents, s3_file_path = self.Storage.upload_file(
            io.BytesIO(self.file_content), self.filename
        )
        file_path = self.Storage.get_file(s3_file_path)
        assert file_path == str(upload_dir / self.filename)
        assert (upload_dir / self.filename).exists()

    def test_delete_file(self, monkeypatch, tmp_path):
        upload_dir = mock_upload_dir(monkeypatch, tmp_path)
        self.s3_client.create_bucket(Bucket=self.Storage.bucket_name)
        contents, s3_file_path = self.Storage.upload_file(
            io.BytesIO(self.file_content), self.filename
        )
        assert (upload_dir / self.filename).exists()
        self.Storage.delete_file(s3_file_path)
        assert not (upload_dir / self.filename).exists()
        with pytest.raises(ClientError) as exc:
            self.s3_client.Object(self.Storage.bucket_name, self.filename).load()
        error = exc.value.response["Error"]
        assert error["Code"] == "404"
        assert error["Message"] == "Not Found"

    def test_delete_all_files(self, monkeypatch, tmp_path):
        upload_dir = mock_upload_dir(monkeypatch, tmp_path)
        # create 2 files
        self.s3_client.create_bucket(Bucket=self.Storage.bucket_name)
        self.Storage.upload_file(io.BytesIO(self.file_content), self.filename)
        object = self.s3_client.Object(self.Storage.bucket_name, self.filename)
        assert self.file_content == object.get()["Body"].read()
        assert (upload_dir / self.filename).exists()
        assert (upload_dir / self.filename).read_bytes() == self.file_content
        self.Storage.upload_file(io.BytesIO(self.file_content), self.filename_extra)
        object = self.s3_client.Object(self.Storage.bucket_name, self.filename_extra)
        assert self.file_content == object.get()["Body"].read()
        assert (upload_dir / self.filename).exists()
        assert (upload_dir / self.filename).read_bytes() == self.file_content

        self.Storage.delete_all_files()

        assert not (upload_dir / self.filename).exists()
        with pytest.raises(ClientError) as exc:
            self.s3_client.Object(self.Storage.bucket_name, self.filename).load()
        error = exc.value.response["Error"]
        assert error["Code"] == "404"
        assert error["Message"] == "Not Found"
        assert not (upload_dir / self.filename_extra).exists()
        with pytest.raises(ClientError) as exc:
            self.s3_client.Object(self.Storage.bucket_name, self.filename_extra).load()
        error = exc.value.response["Error"]
        assert error["Code"] == "404"
        assert error["Message"] == "Not Found"

        self.Storage.delete_all_files()
        assert not (upload_dir / self.filename).exists()
        assert not (upload_dir / self.filename_extra).exists()
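The S3 tests depend on moto (added to pyproject.toml below); its mock_aws decorator patches boto3 in-process so no real AWS calls are made. A minimal standalone illustration of the pattern, with invented bucket and key names:

import boto3
from moto import mock_aws


@mock_aws
def demo() -> None:
    # Every boto3 call below is served by moto's in-memory backend, not AWS.
    s3 = boto3.resource("s3", region_name="us-east-1")
    s3.create_bucket(Bucket="my-bucket")
    s3.Object("my-bucket", "key.txt").put(Body=b"value")
    assert s3.Object("my-bucket", "key.txt").get()["Body"].read() == b"value"


demo()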

pyproject.toml:

@@ -99,6 +99,7 @@ dependencies = [
     "docker~=7.1.0",
     "pytest~=8.3.2",
     "pytest-docker~=3.1.1",
+    "moto[s3]>=5.0.26",

     "googleapis-common-protos==1.63.2",

uv.lock: 1102 lines changed (diff suppressed because it is too large).