diff --git a/backend/open_webui/test/apps/webui/storage/test_provider.py b/backend/open_webui/test/apps/webui/storage/test_provider.py index 8d1090541..ad7a5fa2e 100644 --- a/backend/open_webui/test/apps/webui/storage/test_provider.py +++ b/backend/open_webui/test/apps/webui/storage/test_provider.py @@ -1,6 +1,9 @@ import io +import boto3 import pytest +from botocore.exceptions import ClientError +from moto import mock_aws from open_webui.storage import provider @@ -41,7 +44,8 @@ def test_class_instantiation(): provider.S3StorageProvider() -class TestLocalStorageProvider(provider.LocalStorageProvider): +class TestLocalStorageProvider: + Storage = provider.LocalStorageProvider() file_content = b"test content" file_bytesio = io.BytesIO(file_content) filename = "test.txt" @@ -50,18 +54,18 @@ class TestLocalStorageProvider(provider.LocalStorageProvider): def test_upload_file(self, monkeypatch, tmp_path): upload_dir = mock_upload_dir(monkeypatch, tmp_path) - contents, file_path = self.upload_file(self.file_bytesio, self.filename) + contents, file_path = self.Storage.upload_file(self.file_bytesio, self.filename) assert (upload_dir / self.filename).exists() assert (upload_dir / self.filename).read_bytes() == self.file_content assert contents == self.file_content assert file_path == str(upload_dir / self.filename) with pytest.raises(ValueError): - self.upload_file(self.file_bytesio_empty, self.filename) + self.Storage.upload_file(self.file_bytesio_empty, self.filename) def test_get_file(self, monkeypatch, tmp_path): upload_dir = mock_upload_dir(monkeypatch, tmp_path) file_path = str(upload_dir / self.filename) - file_path_return = self.get_file(file_path) + file_path_return = self.Storage.get_file(file_path) assert file_path == file_path_return def test_delete_file(self, monkeypatch, tmp_path): @@ -69,41 +73,101 @@ class TestLocalStorageProvider(provider.LocalStorageProvider): (upload_dir / self.filename).write_bytes(self.file_content) assert (upload_dir / self.filename).exists() file_path = str(upload_dir / self.filename) - self.delete_file(file_path) + self.Storage.delete_file(file_path) assert not (upload_dir / self.filename).exists() def test_delete_all_files(self, monkeypatch, tmp_path): upload_dir = mock_upload_dir(monkeypatch, tmp_path) (upload_dir / self.filename).write_bytes(self.file_content) (upload_dir / self.filename_extra).write_bytes(self.file_content) - self.delete_all_files() + self.Storage.delete_all_files() assert not (upload_dir / self.filename).exists() assert not (upload_dir / self.filename_extra).exists() -class TestS3StorageProvider(provider.S3StorageProvider): +@mock_aws +class TestS3StorageProvider: + Storage = provider.S3StorageProvider() + Storage.bucket_name = "my-bucket" + s3_client = boto3.resource("s3", region_name="us-east-1") file_content = b"test content" - file_bytesio = io.BytesIO(file_content) filename = "test.txt" - filename_extra = "test_extra.txt" + filename_extra = "test_exyta.txt" file_bytesio_empty = io.BytesIO() - bucket_name = "my-bucket" def test_upload_file(self, monkeypatch, tmp_path): upload_dir = mock_upload_dir(monkeypatch, tmp_path) - contents, file_path = self.upload_file(self.file_bytesio, self.filename) + # S3 checks + with pytest.raises(Exception): + self.Storage.upload_file(io.BytesIO(self.file_content), self.filename) + self.s3_client.create_bucket(Bucket=self.Storage.bucket_name) + contents, s3_file_path = self.Storage.upload_file( + io.BytesIO(self.file_content), self.filename + ) + object = self.s3_client.Object(self.Storage.bucket_name, self.filename) + assert self.file_content == object.get()["Body"].read() + # local checks assert (upload_dir / self.filename).exists() assert (upload_dir / self.filename).read_bytes() == self.file_content assert contents == self.file_content - assert file_path == str(upload_dir / self.filename) + assert s3_file_path == "s3://" + self.Storage.bucket_name + "/" + self.filename with pytest.raises(ValueError): - self.upload_file(self.file_bytesio_empty, self.filename) + self.Storage.upload_file(self.file_bytesio_empty, self.filename) - def test_get_file(self): - pass + def test_get_file(self, monkeypatch, tmp_path): + upload_dir = mock_upload_dir(monkeypatch, tmp_path) + self.s3_client.create_bucket(Bucket=self.Storage.bucket_name) + contents, s3_file_path = self.Storage.upload_file( + io.BytesIO(self.file_content), self.filename + ) + file_path = self.Storage.get_file(s3_file_path) + assert file_path == str(upload_dir / self.filename) + assert (upload_dir / self.filename).exists() - def test_delete_file(self): - pass + def test_delete_file(self, monkeypatch, tmp_path): + upload_dir = mock_upload_dir(monkeypatch, tmp_path) + self.s3_client.create_bucket(Bucket=self.Storage.bucket_name) + contents, s3_file_path = self.Storage.upload_file( + io.BytesIO(self.file_content), self.filename + ) + assert (upload_dir / self.filename).exists() + self.Storage.delete_file(s3_file_path) + assert not (upload_dir / self.filename).exists() + with pytest.raises(ClientError) as exc: + self.s3_client.Object(self.Storage.bucket_name, self.filename).load() + error = exc.value.response["Error"] + assert error["Code"] == "404" + assert error["Message"] == "Not Found" - def test_delete_all_files(self): - pass + def test_delete_all_files(self, monkeypatch, tmp_path): + upload_dir = mock_upload_dir(monkeypatch, tmp_path) + # create 2 files + self.s3_client.create_bucket(Bucket=self.Storage.bucket_name) + self.Storage.upload_file(io.BytesIO(self.file_content), self.filename) + object = self.s3_client.Object(self.Storage.bucket_name, self.filename) + assert self.file_content == object.get()["Body"].read() + assert (upload_dir / self.filename).exists() + assert (upload_dir / self.filename).read_bytes() == self.file_content + self.Storage.upload_file(io.BytesIO(self.file_content), self.filename_extra) + object = self.s3_client.Object(self.Storage.bucket_name, self.filename_extra) + assert self.file_content == object.get()["Body"].read() + assert (upload_dir / self.filename).exists() + assert (upload_dir / self.filename).read_bytes() == self.file_content + + self.Storage.delete_all_files() + assert not (upload_dir / self.filename).exists() + with pytest.raises(ClientError) as exc: + self.s3_client.Object(self.Storage.bucket_name, self.filename).load() + error = exc.value.response["Error"] + assert error["Code"] == "404" + assert error["Message"] == "Not Found" + assert not (upload_dir / self.filename_extra).exists() + with pytest.raises(ClientError) as exc: + self.s3_client.Object(self.Storage.bucket_name, self.filename_extra).load() + error = exc.value.response["Error"] + assert error["Code"] == "404" + assert error["Message"] == "Not Found" + + self.Storage.delete_all_files() + assert not (upload_dir / self.filename).exists() + assert not (upload_dir / self.filename_extra).exists()