From 6c9e02f1b2961df75163fdfd778668efa579910c Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Mon, 6 Jun 2022 14:05:59 +0300 Subject: [PATCH] Fix wrong mimetype used for any file or folder uploaded to S3 using StorageManager Add support for boto3's aws_session_token --- clearml/storage/helper.py | 48 +++++++++++++++++++++++++++++++-------- 1 file changed, 39 insertions(+), 9 deletions(-) diff --git a/clearml/storage/helper.py b/clearml/storage/helper.py index bf86a207..8f55aba6 100644 --- a/clearml/storage/helper.py +++ b/clearml/storage/helper.py @@ -9,6 +9,7 @@ import platform import shutil import sys import threading +import mimetypes from abc import ABCMeta, abstractmethod from collections import namedtuple from concurrent.futures import ThreadPoolExecutor @@ -1434,6 +1435,7 @@ class _Boto3Driver(_Driver): max_concurrency=self._max_multipart_concurrency if container.config.multipart else 1, num_download_attempts=container.config.retries), Callback=callback, + ExtraArgs={'ContentType': get_file_mimetype(object_name)} ) except Exception as ex: self.get_logger().error('Failed uploading: %s' % ex) @@ -1447,7 +1449,9 @@ class _Boto3Driver(_Driver): use_threads=container.config.multipart, max_concurrency=self._max_multipart_concurrency if container.config.multipart else 1, num_download_attempts=container.config.retries), - Callback=callback) + Callback=callback, + ExtraArgs={'ContentType': get_file_mimetype(file_path)} + ) except Exception as ex: self.get_logger().error('Failed uploading: %s' % ex) return False @@ -1539,13 +1543,12 @@ class _Boto3Driver(_Driver): 'time': datetime.utcnow().isoformat() } - boto_session = boto3.Session(conf.key, conf.secret) + boto_session = boto3.Session(conf.key, conf.secret, aws_session_token=conf.token) boto_resource = boto_session.resource('s3', conf.region) bucket = boto_resource.Bucket(bucket_name) bucket.put_object(Key=filename, Body=six.b(json.dumps(data))) region = cls._get_bucket_region(conf=conf, log=log, report_info=True) - if region and ((conf.region and region != conf.region) or (not conf.region and region != 'us-east-1')): msg = "incorrect region specified for bucket %s (detected region %s)" % (conf.bucket, region) else: @@ -1594,7 +1597,7 @@ class _Boto3Driver(_Driver): cls._bucket_location_failure_reported.add(conf.get_bucket_host()) try: - boto_session = boto3.Session(conf.key, conf.secret) + boto_session = boto3.Session(conf.key, conf.secret, aws_session_token=conf.token) boto_resource = boto_session.resource('s3') return boto_resource.meta.client.get_bucket_location(Bucket=conf.bucket)["LocationConstraint"] @@ -1803,7 +1806,8 @@ class _AzureBlobServiceStorageDriver(_Driver): ) def create_blob_from_data( - self, container_name, object_name, blob_name, data, max_connections=2, progress_callback=None + self, container_name, object_name, blob_name, data, max_connections=2, + progress_callback=None, content_settings=None ): if self.__legacy: self.__blob_service.create_blob_from_bytes( @@ -1815,7 +1819,11 @@ class _AzureBlobServiceStorageDriver(_Driver): ) else: client = self.__blob_service.get_blob_client(container_name, blob_name) - client.upload_blob(data, overwrite=True, max_concurrency=max_connections) + client.upload_blob( + data, overwrite=True, + max_concurrency=max_connections, + content_settings=content_settings, + ) def create_blob_from_path( self, container_name, blob_name, path, max_connections=2, content_settings=None, progress_callback=None @@ -1831,7 +1839,9 @@ class _AzureBlobServiceStorageDriver(_Driver): ) else: self.create_blob_from_data( - container_name, None, blob_name, open(path, "rb"), max_connections=max_connections + container_name, None, blob_name, open(path, "rb"), + max_connections=max_connections, + content_settings=content_settings, ) def delete_blob(self, container_name, blob_name): @@ -1949,14 +1959,13 @@ class _AzureBlobServiceStorageDriver(_Driver): stream = None try: from azure.storage.blob import ContentSettings # noqa - from mimetypes import guess_type container.create_blob_from_path( container.name, blob_name, file_path, max_connections=2, - content_settings=ContentSettings(content_type=guess_type(file_path)), + content_settings=ContentSettings(content_type=get_file_mimetype(file_path)), progress_callback=callback, ) return True @@ -2617,6 +2626,27 @@ class _FileStorageDriver(_Driver): return True +def get_file_mimetype(file_path): + """ + Get MIME types of a file + + :param file_path: Path of the local file + :type file_path: str + + :return: File MIME type. Return None if failed to get it + :rtype: str + """ + # noinspection PyBroadException + try: + file_path = Path(file_path).resolve() + mimetype, _ = mimetypes.guess_type(file_path) + if mimetype: + return mimetype + except Exception: + return None + return 'binary/octet-stream' + + driver_schemes = set( filter( None,