Fix wrong mimetype used for any file or folder uploaded to S3 using StorageManager

Add support for boto3's aws_session_token
This commit is contained in:
allegroai 2022-06-06 14:05:59 +03:00
parent fdefa9784c
commit 6c9e02f1b2

View File

@ -9,6 +9,7 @@ import platform
import shutil
import sys
import threading
import mimetypes
from abc import ABCMeta, abstractmethod
from collections import namedtuple
from concurrent.futures import ThreadPoolExecutor
@ -1434,6 +1435,7 @@ class _Boto3Driver(_Driver):
max_concurrency=self._max_multipart_concurrency if container.config.multipart else 1,
num_download_attempts=container.config.retries),
Callback=callback,
ExtraArgs={'ContentType': get_file_mimetype(object_name)}
)
except Exception as ex:
self.get_logger().error('Failed uploading: %s' % ex)
@ -1447,7 +1449,9 @@ class _Boto3Driver(_Driver):
use_threads=container.config.multipart,
max_concurrency=self._max_multipart_concurrency if container.config.multipart else 1,
num_download_attempts=container.config.retries),
Callback=callback)
Callback=callback,
ExtraArgs={'ContentType': get_file_mimetype(file_path)}
)
except Exception as ex:
self.get_logger().error('Failed uploading: %s' % ex)
return False
@ -1539,13 +1543,12 @@ class _Boto3Driver(_Driver):
'time': datetime.utcnow().isoformat()
}
boto_session = boto3.Session(conf.key, conf.secret)
boto_session = boto3.Session(conf.key, conf.secret, aws_session_token=conf.token)
boto_resource = boto_session.resource('s3', conf.region)
bucket = boto_resource.Bucket(bucket_name)
bucket.put_object(Key=filename, Body=six.b(json.dumps(data)))
region = cls._get_bucket_region(conf=conf, log=log, report_info=True)
if region and ((conf.region and region != conf.region) or (not conf.region and region != 'us-east-1')):
msg = "incorrect region specified for bucket %s (detected region %s)" % (conf.bucket, region)
else:
@ -1594,7 +1597,7 @@ class _Boto3Driver(_Driver):
cls._bucket_location_failure_reported.add(conf.get_bucket_host())
try:
boto_session = boto3.Session(conf.key, conf.secret)
boto_session = boto3.Session(conf.key, conf.secret, aws_session_token=conf.token)
boto_resource = boto_session.resource('s3')
return boto_resource.meta.client.get_bucket_location(Bucket=conf.bucket)["LocationConstraint"]
@ -1803,7 +1806,8 @@ class _AzureBlobServiceStorageDriver(_Driver):
)
def create_blob_from_data(
self, container_name, object_name, blob_name, data, max_connections=2, progress_callback=None
self, container_name, object_name, blob_name, data, max_connections=2,
progress_callback=None, content_settings=None
):
if self.__legacy:
self.__blob_service.create_blob_from_bytes(
@ -1815,7 +1819,11 @@ class _AzureBlobServiceStorageDriver(_Driver):
)
else:
client = self.__blob_service.get_blob_client(container_name, blob_name)
client.upload_blob(data, overwrite=True, max_concurrency=max_connections)
client.upload_blob(
data, overwrite=True,
max_concurrency=max_connections,
content_settings=content_settings,
)
def create_blob_from_path(
self, container_name, blob_name, path, max_connections=2, content_settings=None, progress_callback=None
@ -1831,7 +1839,9 @@ class _AzureBlobServiceStorageDriver(_Driver):
)
else:
self.create_blob_from_data(
container_name, None, blob_name, open(path, "rb"), max_connections=max_connections
container_name, None, blob_name, open(path, "rb"),
max_connections=max_connections,
content_settings=content_settings,
)
def delete_blob(self, container_name, blob_name):
@ -1949,14 +1959,13 @@ class _AzureBlobServiceStorageDriver(_Driver):
stream = None
try:
from azure.storage.blob import ContentSettings # noqa
from mimetypes import guess_type
container.create_blob_from_path(
container.name,
blob_name,
file_path,
max_connections=2,
content_settings=ContentSettings(content_type=guess_type(file_path)),
content_settings=ContentSettings(content_type=get_file_mimetype(file_path)),
progress_callback=callback,
)
return True
@ -2617,6 +2626,27 @@ class _FileStorageDriver(_Driver):
return True
def get_file_mimetype(file_path):
    """
    Guess the MIME type of a file from its name.

    The guess is made from the file name (extension) using the standard ``mimetypes`` tables.

    :param file_path: Path of the local file
    :type file_path: str
    :return: The guessed MIME type, 'binary/octet-stream' when the type cannot be guessed
        from the name, or None if the path could not be processed at all
    :rtype: str
    """
    # noinspection PyBroadException
    try:
        resolved_path = Path(file_path).resolve()
        # mimetypes.guess_type() only accepts path-like objects starting with Python 3.8;
        # on older interpreters a Path argument raises and every lookup would end up as None.
        # Always hand it a plain string (posix form, so separators are what it expects).
        mimetype, _ = mimetypes.guess_type(resolved_path.as_posix())
    except Exception:
        # Best effort: report failure with None instead of propagating the exception
        return None
    # Unknown extension (or no extension): fall back to a generic binary type
    return mimetype or 'binary/octet-stream'
driver_schemes = set(
filter(
None,