Fix wrong mimetype used for any file or folder uploaded to S3 using StorageManager

Add support for boto3's aws_session_token
This commit is contained in:
allegroai 2022-06-06 14:05:59 +03:00
parent fdefa9784c
commit 6c9e02f1b2

View File

@ -9,6 +9,7 @@ import platform
import shutil import shutil
import sys import sys
import threading import threading
import mimetypes
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
from collections import namedtuple from collections import namedtuple
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
@ -1434,6 +1435,7 @@ class _Boto3Driver(_Driver):
max_concurrency=self._max_multipart_concurrency if container.config.multipart else 1, max_concurrency=self._max_multipart_concurrency if container.config.multipart else 1,
num_download_attempts=container.config.retries), num_download_attempts=container.config.retries),
Callback=callback, Callback=callback,
ExtraArgs={'ContentType': get_file_mimetype(object_name)}
) )
except Exception as ex: except Exception as ex:
self.get_logger().error('Failed uploading: %s' % ex) self.get_logger().error('Failed uploading: %s' % ex)
@ -1447,7 +1449,9 @@ class _Boto3Driver(_Driver):
use_threads=container.config.multipart, use_threads=container.config.multipart,
max_concurrency=self._max_multipart_concurrency if container.config.multipart else 1, max_concurrency=self._max_multipart_concurrency if container.config.multipart else 1,
num_download_attempts=container.config.retries), num_download_attempts=container.config.retries),
Callback=callback) Callback=callback,
ExtraArgs={'ContentType': get_file_mimetype(file_path)}
)
except Exception as ex: except Exception as ex:
self.get_logger().error('Failed uploading: %s' % ex) self.get_logger().error('Failed uploading: %s' % ex)
return False return False
@ -1539,13 +1543,12 @@ class _Boto3Driver(_Driver):
'time': datetime.utcnow().isoformat() 'time': datetime.utcnow().isoformat()
} }
boto_session = boto3.Session(conf.key, conf.secret) boto_session = boto3.Session(conf.key, conf.secret, aws_session_token=conf.token)
boto_resource = boto_session.resource('s3', conf.region) boto_resource = boto_session.resource('s3', conf.region)
bucket = boto_resource.Bucket(bucket_name) bucket = boto_resource.Bucket(bucket_name)
bucket.put_object(Key=filename, Body=six.b(json.dumps(data))) bucket.put_object(Key=filename, Body=six.b(json.dumps(data)))
region = cls._get_bucket_region(conf=conf, log=log, report_info=True) region = cls._get_bucket_region(conf=conf, log=log, report_info=True)
if region and ((conf.region and region != conf.region) or (not conf.region and region != 'us-east-1')): if region and ((conf.region and region != conf.region) or (not conf.region and region != 'us-east-1')):
msg = "incorrect region specified for bucket %s (detected region %s)" % (conf.bucket, region) msg = "incorrect region specified for bucket %s (detected region %s)" % (conf.bucket, region)
else: else:
@ -1594,7 +1597,7 @@ class _Boto3Driver(_Driver):
cls._bucket_location_failure_reported.add(conf.get_bucket_host()) cls._bucket_location_failure_reported.add(conf.get_bucket_host())
try: try:
boto_session = boto3.Session(conf.key, conf.secret) boto_session = boto3.Session(conf.key, conf.secret, aws_session_token=conf.token)
boto_resource = boto_session.resource('s3') boto_resource = boto_session.resource('s3')
return boto_resource.meta.client.get_bucket_location(Bucket=conf.bucket)["LocationConstraint"] return boto_resource.meta.client.get_bucket_location(Bucket=conf.bucket)["LocationConstraint"]
@ -1803,7 +1806,8 @@ class _AzureBlobServiceStorageDriver(_Driver):
) )
def create_blob_from_data( def create_blob_from_data(
self, container_name, object_name, blob_name, data, max_connections=2, progress_callback=None self, container_name, object_name, blob_name, data, max_connections=2,
progress_callback=None, content_settings=None
): ):
if self.__legacy: if self.__legacy:
self.__blob_service.create_blob_from_bytes( self.__blob_service.create_blob_from_bytes(
@ -1815,7 +1819,11 @@ class _AzureBlobServiceStorageDriver(_Driver):
) )
else: else:
client = self.__blob_service.get_blob_client(container_name, blob_name) client = self.__blob_service.get_blob_client(container_name, blob_name)
client.upload_blob(data, overwrite=True, max_concurrency=max_connections) client.upload_blob(
data, overwrite=True,
max_concurrency=max_connections,
content_settings=content_settings,
)
def create_blob_from_path( def create_blob_from_path(
self, container_name, blob_name, path, max_connections=2, content_settings=None, progress_callback=None self, container_name, blob_name, path, max_connections=2, content_settings=None, progress_callback=None
@ -1831,7 +1839,9 @@ class _AzureBlobServiceStorageDriver(_Driver):
) )
else: else:
self.create_blob_from_data( self.create_blob_from_data(
container_name, None, blob_name, open(path, "rb"), max_connections=max_connections container_name, None, blob_name, open(path, "rb"),
max_connections=max_connections,
content_settings=content_settings,
) )
def delete_blob(self, container_name, blob_name): def delete_blob(self, container_name, blob_name):
@ -1949,14 +1959,13 @@ class _AzureBlobServiceStorageDriver(_Driver):
stream = None stream = None
try: try:
from azure.storage.blob import ContentSettings # noqa from azure.storage.blob import ContentSettings # noqa
from mimetypes import guess_type
container.create_blob_from_path( container.create_blob_from_path(
container.name, container.name,
blob_name, blob_name,
file_path, file_path,
max_connections=2, max_connections=2,
content_settings=ContentSettings(content_type=guess_type(file_path)), content_settings=ContentSettings(content_type=get_file_mimetype(file_path)),
progress_callback=callback, progress_callback=callback,
) )
return True return True
@ -2617,6 +2626,27 @@ class _FileStorageDriver(_Driver):
return True return True
def get_file_mimetype(file_path):
    """
    Guess the MIME type of a file from its name.

    The guess is based on the file name only (via ``mimetypes.guess_type``);
    the content of the file is never read.

    :param file_path: Path of the local file
    :type file_path: str
    :return: The guessed MIME type, 'binary/octet-stream' when the name does not map
        to a known type, or None if guessing failed with an error
    :rtype: str
    """
    # noinspection PyBroadException
    try:
        # resolve() makes the path absolute and follows symlinks, so the guess is
        # based on the name of the link target.
        resolved_path = Path(file_path).resolve()
        # Pass a plain string: guess_type() accepts path-like objects only from
        # Python 3.8 onwards. On older interpreters a Path argument raises, which
        # the handler below would turn into a None result for every file.
        mimetype, _ = mimetypes.guess_type(str(resolved_path))
    except Exception:
        return None
    # Unknown type (guess_type returned None): fall back to a generic binary type
    return mimetype or 'binary/octet-stream'
driver_schemes = set( driver_schemes = set(
filter( filter(
None, None,