mirror of
https://github.com/open-webui/open-webui
synced 2025-04-24 08:16:02 +00:00
use key_prefix in rest of S3StorageProvider
This commit is contained in:
parent
5ca6afc0fc
commit
7f82476926
@ -94,16 +94,17 @@ class S3StorageProvider(StorageProvider):
|
|||||||
aws_secret_access_key=S3_SECRET_ACCESS_KEY,
|
aws_secret_access_key=S3_SECRET_ACCESS_KEY,
|
||||||
)
|
)
|
||||||
self.bucket_name = S3_BUCKET_NAME
|
self.bucket_name = S3_BUCKET_NAME
|
||||||
|
self.key_prefix = S3_KEY_PREFIX
|
||||||
|
|
||||||
def upload_file(self, file: BinaryIO, filename: str) -> Tuple[bytes, str]:
|
def upload_file(self, file: BinaryIO, filename: str) -> Tuple[bytes, str]:
|
||||||
"""Handles uploading of the file to S3 storage."""
|
"""Handles uploading of the file to S3 storage."""
|
||||||
_, file_path = LocalStorageProvider.upload_file(file, filename)
|
_, file_path = LocalStorageProvider.upload_file(file, filename)
|
||||||
try:
|
try:
|
||||||
s3_key = os.path.join(S3_KEY_PREFIX, filename)
|
s3_key = os.path.join(self.key_prefix, filename)
|
||||||
self.s3_client.upload_file(file_path, self.bucket_name, s3_key)
|
self.s3_client.upload_file(file_path, self.bucket_name, s3_key)
|
||||||
return (
|
return (
|
||||||
open(file_path, "rb").read(),
|
open(file_path, "rb").read(),
|
||||||
"s3://" + self.bucket_name + "/" + filename,
|
"s3://" + self.bucket_name + "/" + s3_key,
|
||||||
)
|
)
|
||||||
except ClientError as e:
|
except ClientError as e:
|
||||||
raise RuntimeError(f"Error uploading file to S3: {e}")
|
raise RuntimeError(f"Error uploading file to S3: {e}")
|
||||||
@ -111,18 +112,18 @@ class S3StorageProvider(StorageProvider):
|
|||||||
def get_file(self, file_path: str) -> str:
|
def get_file(self, file_path: str) -> str:
|
||||||
"""Handles downloading of the file from S3 storage."""
|
"""Handles downloading of the file from S3 storage."""
|
||||||
try:
|
try:
|
||||||
bucket_name, key = file_path.split("//")[1].split("/")
|
s3_key = self._extract_s3_key(file_path)
|
||||||
local_file_path = f"{UPLOAD_DIR}/{key}"
|
local_file_path = self._get_local_file_path(s3_key)
|
||||||
self.s3_client.download_file(bucket_name, key, local_file_path)
|
self.s3_client.download_file(self.bucket_name, s3_key, local_file_path)
|
||||||
return local_file_path
|
return local_file_path
|
||||||
except ClientError as e:
|
except ClientError as e:
|
||||||
raise RuntimeError(f"Error downloading file from S3: {e}")
|
raise RuntimeError(f"Error downloading file from S3: {e}")
|
||||||
|
|
||||||
def delete_file(self, file_path: str) -> None:
|
def delete_file(self, file_path: str) -> None:
|
||||||
"""Handles deletion of the file from S3 storage."""
|
"""Handles deletion of the file from S3 storage."""
|
||||||
filename = file_path.split("/")[-1]
|
|
||||||
try:
|
try:
|
||||||
self.s3_client.delete_object(Bucket=self.bucket_name, Key=filename)
|
s3_key = self._extract_s3_key(file_path)
|
||||||
|
self.s3_client.delete_object(Bucket=self.bucket_name, Key=s3_key)
|
||||||
except ClientError as e:
|
except ClientError as e:
|
||||||
raise RuntimeError(f"Error deleting file from S3: {e}")
|
raise RuntimeError(f"Error deleting file from S3: {e}")
|
||||||
|
|
||||||
@ -135,6 +136,9 @@ class S3StorageProvider(StorageProvider):
|
|||||||
response = self.s3_client.list_objects_v2(Bucket=self.bucket_name)
|
response = self.s3_client.list_objects_v2(Bucket=self.bucket_name)
|
||||||
if "Contents" in response:
|
if "Contents" in response:
|
||||||
for content in response["Contents"]:
|
for content in response["Contents"]:
|
||||||
|
# Skip objects that were not uploaded from open-webui in the first place
|
||||||
|
if not content["Key"].startswith(self.key_prefix): continue
|
||||||
|
|
||||||
self.s3_client.delete_object(
|
self.s3_client.delete_object(
|
||||||
Bucket=self.bucket_name, Key=content["Key"]
|
Bucket=self.bucket_name, Key=content["Key"]
|
||||||
)
|
)
|
||||||
@ -144,6 +148,12 @@ class S3StorageProvider(StorageProvider):
|
|||||||
# Always delete from local storage
|
# Always delete from local storage
|
||||||
LocalStorageProvider.delete_all_files()
|
LocalStorageProvider.delete_all_files()
|
||||||
|
|
||||||
|
# The s3 key is the name assigned to an object. It excludes the bucket name, but includes the internal path and the file name.
|
||||||
|
def _extract_s3_key(self, full_file_path: str) -> str:
|
||||||
|
return ''.join(full_file_path.split("//")[1].split("/")[1:])
|
||||||
|
|
||||||
|
def _get_local_file_path(self, s3_key: str) -> str:
|
||||||
|
return f"{UPLOAD_DIR}/{s3_key.split('/')[-1]}"
|
||||||
|
|
||||||
class GCSStorageProvider(StorageProvider):
|
class GCSStorageProvider(StorageProvider):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
Loading…
Reference in New Issue
Block a user