Merge pull request #8580 from rragundez/split-storage-providers

Split the storage providers into separate classes in preparation for adding more cloud providers
commit 8b3fb2a8b5
Timothy Jaeryang Baek authored 2025-01-15 21:08:04 -08:00, committed by GitHub
No known key found for this signature in database. GPG Key ID: B5690EEEBB952194
5 changed files with 917 additions and 583 deletions
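As the title says, the split is groundwork for additional cloud backends: a new provider now only has to implement the four abstract methods of the StorageProvider ABC introduced in provider.py below and be wired into get_storage_provider. Purely as an illustration (not part of this PR; the GCS calls are elided), a third backend could look like:

from typing import BinaryIO, Tuple

from open_webui.storage.provider import LocalStorageProvider, StorageProvider


class GCSStorageProvider(StorageProvider):
    """Hypothetical Google Cloud Storage backend enabled by this refactor."""

    def upload_file(self, file: BinaryIO, filename: str) -> Tuple[bytes, str]:
        # Mirror S3StorageProvider: stage the file locally, then push it.
        contents, file_path = LocalStorageProvider.upload_file(file, filename)
        # ... upload file_path to the GCS bucket here ...
        return contents, f"gs://my-bucket/{filename}"

    def get_file(self, file_path: str) -> str:
        # ... download the object into UPLOAD_DIR, return the local path ...
        raise NotImplementedError

    def delete_file(self, file_path: str) -> None:
        # ... delete the object, then clean up the local copy ...
        LocalStorageProvider.delete_file(file_path)

    def delete_all_files(self) -> None:
        # ... delete every object in the bucket, then the local copies ...
        LocalStorageProvider.delete_all_files()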

backend/open_webui/config.py

@@ -9,22 +9,22 @@ from urllib.parse import urlparse

 import chromadb
 import requests
-import yaml
-from open_webui.internal.db import Base, get_db
+from pydantic import BaseModel
+from sqlalchemy import JSON, Column, DateTime, Integer, func
 from open_webui.env import (
-    OPEN_WEBUI_DIR,
     DATA_DIR,
+    DATABASE_URL,
     ENV,
     FRONTEND_BUILD_DIR,
+    OFFLINE_MODE,
+    OPEN_WEBUI_DIR,
     WEBUI_AUTH,
     WEBUI_FAVICON_URL,
     WEBUI_NAME,
     log,
-    DATABASE_URL,
-    OFFLINE_MODE,
 )
-from pydantic import BaseModel
-from sqlalchemy import JSON, Column, DateTime, Integer, func
+from open_webui.internal.db import Base, get_db
 
 
 class EndpointFilter(logging.Filter):
@@ -581,7 +581,7 @@ if CUSTOM_NAME:
 # STORAGE PROVIDER
 ####################################
 
-STORAGE_PROVIDER = os.environ.get("STORAGE_PROVIDER", "")  # defaults to local, s3
+STORAGE_PROVIDER = os.environ.get("STORAGE_PROVIDER", "local")  # defaults to local, s3
 
 S3_ACCESS_KEY_ID = os.environ.get("S3_ACCESS_KEY_ID", None)
 S3_SECRET_ACCESS_KEY = os.environ.get("S3_SECRET_ACCESS_KEY", None)
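The default flips from "" to "local". The old StorageProvider treated anything other than "s3" as local storage; the new get_storage_provider factory (in provider.py, next file) raises on values it does not recognize, so an unset STORAGE_PROVIDER must now resolve to an explicit "local". An illustrative sketch of the new behavior:

from open_webui.storage.provider import get_storage_provider

storage = get_storage_provider("local")  # -> LocalStorageProvider
storage = get_storage_provider("s3")     # -> S3StorageProvider

try:
    get_storage_provider("")  # the old default value
except RuntimeError as e:
    print(e)  # Unsupported storage provider: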

backend/open_webui/storage/provider.py

@@ -1,121 +1,68 @@
 import os
-import boto3
-from botocore.exceptions import ClientError
 import shutil
-from typing import BinaryIO, Tuple, Optional, Union
+from abc import ABC, abstractmethod
+from typing import BinaryIO, Tuple
 
-from open_webui.constants import ERROR_MESSAGES
+import boto3
+from botocore.exceptions import ClientError
 from open_webui.config import (
-    STORAGE_PROVIDER,
     S3_ACCESS_KEY_ID,
-    S3_SECRET_ACCESS_KEY,
     S3_BUCKET_NAME,
-    S3_REGION_NAME,
     S3_ENDPOINT_URL,
+    S3_REGION_NAME,
+    S3_SECRET_ACCESS_KEY,
+    STORAGE_PROVIDER,
     UPLOAD_DIR,
 )
-
-import boto3
-from botocore.exceptions import ClientError
-from typing import BinaryIO, Tuple, Optional
+from open_webui.constants import ERROR_MESSAGES
 
 
-class StorageProvider:
-    def __init__(self, provider: Optional[str] = None):
-        self.storage_provider: str = provider or STORAGE_PROVIDER
-
-        self.s3_client = None
-        self.s3_bucket_name: Optional[str] = None
-
-        if self.storage_provider == "s3":
-            self._initialize_s3()
-
-    def _initialize_s3(self) -> None:
-        """Initializes the S3 client and bucket name if using S3 storage."""
-        self.s3_client = boto3.client(
-            "s3",
-            region_name=S3_REGION_NAME,
-            endpoint_url=S3_ENDPOINT_URL,
-            aws_access_key_id=S3_ACCESS_KEY_ID,
-            aws_secret_access_key=S3_SECRET_ACCESS_KEY,
-        )
-        self.bucket_name = S3_BUCKET_NAME
-
-    def _upload_to_s3(self, file_path: str, filename: str) -> Tuple[bytes, str]:
-        """Handles uploading of the file to S3 storage."""
-        if not self.s3_client:
-            raise RuntimeError("S3 Client is not initialized.")
-
-        try:
-            self.s3_client.upload_file(file_path, self.bucket_name, filename)
-            return (
-                open(file_path, "rb").read(),
-                "s3://" + self.bucket_name + "/" + filename,
-            )
-        except ClientError as e:
-            raise RuntimeError(f"Error uploading file to S3: {e}")
-
-    def _upload_to_local(self, contents: bytes, filename: str) -> Tuple[bytes, str]:
-        """Handles uploading of the file to local storage."""
+class StorageProvider(ABC):
+    @abstractmethod
+    def get_file(self, file_path: str) -> str:
+        pass
+
+    @abstractmethod
+    def upload_file(self, file: BinaryIO, filename: str) -> Tuple[bytes, str]:
+        pass
+
+    @abstractmethod
+    def delete_all_files(self) -> None:
+        pass
+
+    @abstractmethod
+    def delete_file(self, file_path: str) -> None:
+        pass
+
+
+class LocalStorageProvider(StorageProvider):
+    @staticmethod
+    def upload_file(file: BinaryIO, filename: str) -> Tuple[bytes, str]:
+        contents = file.read()
+        if not contents:
+            raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT)
+
         file_path = f"{UPLOAD_DIR}/{filename}"
         with open(file_path, "wb") as f:
             f.write(contents)
         return contents, file_path
 
-    def _get_file_from_s3(self, file_path: str) -> str:
-        """Handles downloading of the file from S3 storage."""
-        if not self.s3_client:
-            raise RuntimeError("S3 Client is not initialized.")
-
-        try:
-            bucket_name, key = file_path.split("//")[1].split("/")
-            local_file_path = f"{UPLOAD_DIR}/{key}"
-            self.s3_client.download_file(bucket_name, key, local_file_path)
-            return local_file_path
-        except ClientError as e:
-            raise RuntimeError(f"Error downloading file from S3: {e}")
-
-    def _get_file_from_local(self, file_path: str) -> str:
+    @staticmethod
+    def get_file(file_path: str) -> str:
         """Handles downloading of the file from local storage."""
         return file_path
 
-    def _delete_from_s3(self, filename: str) -> None:
-        """Handles deletion of the file from S3 storage."""
-        if not self.s3_client:
-            raise RuntimeError("S3 Client is not initialized.")
-
-        try:
-            self.s3_client.delete_object(Bucket=self.bucket_name, Key=filename)
-        except ClientError as e:
-            raise RuntimeError(f"Error deleting file from S3: {e}")
-
-    def _delete_from_local(self, filename: str) -> None:
+    @staticmethod
+    def delete_file(file_path: str) -> None:
         """Handles deletion of the file from local storage."""
+        filename = file_path.split("/")[-1]
         file_path = f"{UPLOAD_DIR}/{filename}"
         if os.path.isfile(file_path):
             os.remove(file_path)
         else:
             print(f"File {file_path} not found in local storage.")
 
-    def _delete_all_from_s3(self) -> None:
-        """Handles deletion of all files from S3 storage."""
-        if not self.s3_client:
-            raise RuntimeError("S3 Client is not initialized.")
-
-        try:
-            response = self.s3_client.list_objects_v2(Bucket=self.bucket_name)
-            if "Contents" in response:
-                for content in response["Contents"]:
-                    self.s3_client.delete_object(
-                        Bucket=self.bucket_name, Key=content["Key"]
-                    )
-        except ClientError as e:
-            raise RuntimeError(f"Error deleting all files from S3: {e}")
-
-    def _delete_all_from_local(self) -> None:
+    @staticmethod
+    def delete_all_files() -> None:
        """Handles deletion of all files from local storage."""
         if os.path.exists(UPLOAD_DIR):
             for filename in os.listdir(UPLOAD_DIR):

@@ -130,40 +77,75 @@ class StorageProvider:
         else:
             print(f"Directory {UPLOAD_DIR} not found in local storage.")
 
-    def upload_file(self, file: BinaryIO, filename: str) -> Tuple[bytes, str]:
-        """Uploads a file either to S3 or the local file system."""
-        contents = file.read()
-
-        if not contents:
-            raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT)
-        contents, file_path = self._upload_to_local(contents, filename)
-
-        if self.storage_provider == "s3":
-            return self._upload_to_s3(file_path, filename)
-        return contents, file_path
+
+class S3StorageProvider(StorageProvider):
+    def __init__(self):
+        self.s3_client = boto3.client(
+            "s3",
+            region_name=S3_REGION_NAME,
+            endpoint_url=S3_ENDPOINT_URL,
+            aws_access_key_id=S3_ACCESS_KEY_ID,
+            aws_secret_access_key=S3_SECRET_ACCESS_KEY,
+        )
+        self.bucket_name = S3_BUCKET_NAME
+
+    def upload_file(self, file: BinaryIO, filename: str) -> Tuple[bytes, str]:
+        """Handles uploading of the file to S3 storage."""
+        _, file_path = LocalStorageProvider.upload_file(file, filename)
+        try:
+            self.s3_client.upload_file(file_path, self.bucket_name, filename)
+            return (
+                open(file_path, "rb").read(),
+                "s3://" + self.bucket_name + "/" + filename,
+            )
+        except ClientError as e:
+            raise RuntimeError(f"Error uploading file to S3: {e}")
 
     def get_file(self, file_path: str) -> str:
-        """Downloads a file either from S3 or the local file system and returns the file path."""
-        if self.storage_provider == "s3":
-            return self._get_file_from_s3(file_path)
-        return self._get_file_from_local(file_path)
+        """Handles downloading of the file from S3 storage."""
+        try:
+            bucket_name, key = file_path.split("//")[1].split("/")
+            local_file_path = f"{UPLOAD_DIR}/{key}"
+            self.s3_client.download_file(bucket_name, key, local_file_path)
+            return local_file_path
+        except ClientError as e:
+            raise RuntimeError(f"Error downloading file from S3: {e}")
 
     def delete_file(self, file_path: str) -> None:
-        """Deletes a file either from S3 or the local file system."""
+        """Handles deletion of the file from S3 storage."""
         filename = file_path.split("/")[-1]
-
-        if self.storage_provider == "s3":
-            self._delete_from_s3(filename)
+        try:
+            self.s3_client.delete_object(Bucket=self.bucket_name, Key=filename)
+        except ClientError as e:
+            raise RuntimeError(f"Error deleting file from S3: {e}")
 
         # Always delete from local storage
-        self._delete_from_local(filename)
+        LocalStorageProvider.delete_file(file_path)
 
     def delete_all_files(self) -> None:
-        """Deletes all files from the storage."""
-        if self.storage_provider == "s3":
-            self._delete_all_from_s3()
+        """Handles deletion of all files from S3 storage."""
+        try:
+            response = self.s3_client.list_objects_v2(Bucket=self.bucket_name)
+            if "Contents" in response:
+                for content in response["Contents"]:
+                    self.s3_client.delete_object(
+                        Bucket=self.bucket_name, Key=content["Key"]
+                    )
+        except ClientError as e:
+            raise RuntimeError(f"Error deleting all files from S3: {e}")
 
         # Always delete from local storage
-        self._delete_all_from_local()
-
-
-Storage = StorageProvider(provider=STORAGE_PROVIDER)
+        LocalStorageProvider.delete_all_files()
+
+
+def get_storage_provider(storage_provider: str):
+    if storage_provider == "local":
+        Storage = LocalStorageProvider()
+    elif storage_provider == "s3":
+        Storage = S3StorageProvider()
+    else:
+        raise RuntimeError(f"Unsupported storage provider: {storage_provider}")
+    return Storage
+
+
+Storage = get_storage_provider(STORAGE_PROVIDER)
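Call sites elsewhere in the codebase are unaffected: they keep importing the module-level Storage object, which is now whichever concrete provider get_storage_provider(STORAGE_PROVIDER) selected at import time. A minimal sketch of the round trip (file name and contents are illustrative):

import io

from open_webui.storage.provider import Storage

# upload_file returns the raw bytes plus a provider-specific path:
# a filesystem path under UPLOAD_DIR for local, an s3:// URI for S3.
contents, path = Storage.upload_file(io.BytesIO(b"hello"), "hello.txt")

# get_file resolves the stored path to a local file path; the S3
# provider downloads into UPLOAD_DIR first, the local one is a no-op.
local_path = Storage.get_file(path)

# delete_file removes the object (S3) and always cleans up the local copy.
Storage.delete_file(path)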

backend/open_webui/test/apps/webui/storage/test_provider.py (new file)

@@ -0,0 +1,173 @@
import io
import boto3
import pytest
from botocore.exceptions import ClientError
from moto import mock_aws
from open_webui.storage import provider


def mock_upload_dir(monkeypatch, tmp_path):
    """Fixture to monkey-patch the UPLOAD_DIR and create a temporary directory."""
    directory = tmp_path / "uploads"
    directory.mkdir()
    monkeypatch.setattr(provider, "UPLOAD_DIR", str(directory))
    return directory


def test_imports():
    provider.StorageProvider
    provider.LocalStorageProvider
    provider.S3StorageProvider
    provider.Storage


def test_get_storage_provider():
    Storage = provider.get_storage_provider("local")
    assert isinstance(Storage, provider.LocalStorageProvider)
    Storage = provider.get_storage_provider("s3")
    assert isinstance(Storage, provider.S3StorageProvider)
    with pytest.raises(RuntimeError):
        provider.get_storage_provider("invalid")


def test_class_instantiation():
    with pytest.raises(TypeError):
        provider.StorageProvider()
    with pytest.raises(TypeError):

        class Test(provider.StorageProvider):
            pass

        Test()
    provider.LocalStorageProvider()
    provider.S3StorageProvider()


class TestLocalStorageProvider:
    Storage = provider.LocalStorageProvider()
    file_content = b"test content"
    file_bytesio = io.BytesIO(file_content)
    filename = "test.txt"
    filename_extra = "test_exyta.txt"
    file_bytesio_empty = io.BytesIO()

    def test_upload_file(self, monkeypatch, tmp_path):
        upload_dir = mock_upload_dir(monkeypatch, tmp_path)
        contents, file_path = self.Storage.upload_file(self.file_bytesio, self.filename)
        assert (upload_dir / self.filename).exists()
        assert (upload_dir / self.filename).read_bytes() == self.file_content
        assert contents == self.file_content
        assert file_path == str(upload_dir / self.filename)
        with pytest.raises(ValueError):
            self.Storage.upload_file(self.file_bytesio_empty, self.filename)

    def test_get_file(self, monkeypatch, tmp_path):
        upload_dir = mock_upload_dir(monkeypatch, tmp_path)
        file_path = str(upload_dir / self.filename)
        file_path_return = self.Storage.get_file(file_path)
        assert file_path == file_path_return

    def test_delete_file(self, monkeypatch, tmp_path):
        upload_dir = mock_upload_dir(monkeypatch, tmp_path)
        (upload_dir / self.filename).write_bytes(self.file_content)
        assert (upload_dir / self.filename).exists()
        file_path = str(upload_dir / self.filename)
        self.Storage.delete_file(file_path)
        assert not (upload_dir / self.filename).exists()

    def test_delete_all_files(self, monkeypatch, tmp_path):
        upload_dir = mock_upload_dir(monkeypatch, tmp_path)
        (upload_dir / self.filename).write_bytes(self.file_content)
        (upload_dir / self.filename_extra).write_bytes(self.file_content)
        self.Storage.delete_all_files()
        assert not (upload_dir / self.filename).exists()
        assert not (upload_dir / self.filename_extra).exists()


@mock_aws
class TestS3StorageProvider:
    Storage = provider.S3StorageProvider()
    Storage.bucket_name = "my-bucket"
    s3_client = boto3.resource("s3", region_name="us-east-1")
    file_content = b"test content"
    filename = "test.txt"
    filename_extra = "test_exyta.txt"
    file_bytesio_empty = io.BytesIO()

    def test_upload_file(self, monkeypatch, tmp_path):
        upload_dir = mock_upload_dir(monkeypatch, tmp_path)
        # S3 checks
        with pytest.raises(Exception):
            self.Storage.upload_file(io.BytesIO(self.file_content), self.filename)
        self.s3_client.create_bucket(Bucket=self.Storage.bucket_name)
        contents, s3_file_path = self.Storage.upload_file(
            io.BytesIO(self.file_content), self.filename
        )
        object = self.s3_client.Object(self.Storage.bucket_name, self.filename)
        assert self.file_content == object.get()["Body"].read()
        # local checks
        assert (upload_dir / self.filename).exists()
        assert (upload_dir / self.filename).read_bytes() == self.file_content
        assert contents == self.file_content
        assert s3_file_path == "s3://" + self.Storage.bucket_name + "/" + self.filename
        with pytest.raises(ValueError):
            self.Storage.upload_file(self.file_bytesio_empty, self.filename)

    def test_get_file(self, monkeypatch, tmp_path):
        upload_dir = mock_upload_dir(monkeypatch, tmp_path)
        self.s3_client.create_bucket(Bucket=self.Storage.bucket_name)
        contents, s3_file_path = self.Storage.upload_file(
            io.BytesIO(self.file_content), self.filename
        )
        file_path = self.Storage.get_file(s3_file_path)
        assert file_path == str(upload_dir / self.filename)
        assert (upload_dir / self.filename).exists()

    def test_delete_file(self, monkeypatch, tmp_path):
        upload_dir = mock_upload_dir(monkeypatch, tmp_path)
        self.s3_client.create_bucket(Bucket=self.Storage.bucket_name)
        contents, s3_file_path = self.Storage.upload_file(
            io.BytesIO(self.file_content), self.filename
        )
        assert (upload_dir / self.filename).exists()
        self.Storage.delete_file(s3_file_path)
        assert not (upload_dir / self.filename).exists()
        with pytest.raises(ClientError) as exc:
            self.s3_client.Object(self.Storage.bucket_name, self.filename).load()
        error = exc.value.response["Error"]
        assert error["Code"] == "404"
        assert error["Message"] == "Not Found"

    def test_delete_all_files(self, monkeypatch, tmp_path):
        upload_dir = mock_upload_dir(monkeypatch, tmp_path)
        # create 2 files
        self.s3_client.create_bucket(Bucket=self.Storage.bucket_name)
        self.Storage.upload_file(io.BytesIO(self.file_content), self.filename)
        object = self.s3_client.Object(self.Storage.bucket_name, self.filename)
        assert self.file_content == object.get()["Body"].read()
        assert (upload_dir / self.filename).exists()
        assert (upload_dir / self.filename).read_bytes() == self.file_content
        self.Storage.upload_file(io.BytesIO(self.file_content), self.filename_extra)
        object = self.s3_client.Object(self.Storage.bucket_name, self.filename_extra)
        assert self.file_content == object.get()["Body"].read()
        assert (upload_dir / self.filename).exists()
        assert (upload_dir / self.filename).read_bytes() == self.file_content

        self.Storage.delete_all_files()

        assert not (upload_dir / self.filename).exists()
        with pytest.raises(ClientError) as exc:
            self.s3_client.Object(self.Storage.bucket_name, self.filename).load()
        error = exc.value.response["Error"]
        assert error["Code"] == "404"
        assert error["Message"] == "Not Found"
        assert not (upload_dir / self.filename_extra).exists()
        with pytest.raises(ClientError) as exc:
            self.s3_client.Object(self.Storage.bucket_name, self.filename_extra).load()
        error = exc.value.response["Error"]
        assert error["Code"] == "404"
        assert error["Message"] == "Not Found"

        self.Storage.delete_all_files()
        assert not (upload_dir / self.filename).exists()
        assert not (upload_dir / self.filename_extra).exists()
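A detail worth noting in mock_upload_dir above: provider.py binds UPLOAD_DIR at import time (from open_webui.config import UPLOAD_DIR), so the tests patch the copy held by the provider module rather than the original in open_webui.config. A stripped-down illustration of that pattern:

from open_webui.storage import provider


def test_upload_dir_is_patched(monkeypatch, tmp_path):
    # Patch the name inside the provider module; patching
    # open_webui.config.UPLOAD_DIR would have no effect, since
    # provider.py already holds its own reference to the value.
    monkeypatch.setattr(provider, "UPLOAD_DIR", str(tmp_path))
    assert provider.UPLOAD_DIR == str(tmp_path)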

pyproject.toml

@@ -99,6 +99,7 @@ dependencies = [
     "docker~=7.1.0",
     "pytest~=8.3.2",
     "pytest-docker~=3.1.1",
+    "moto[s3]>=5.0.26",
     "googleapis-common-protos==1.63.2",

uv.lock: 1102 changes (file diff suppressed because it is too large)