mirror of
https://github.com/clearml/clearml
synced 2025-03-03 10:42:00 +00:00
Fix wrong mimetype used for any file or folder uploaded to S3 using StorageManager
Add support for boto3's aws_session_token
This commit is contained in:
parent
fdefa9784c
commit
6c9e02f1b2
@ -9,6 +9,7 @@ import platform
|
||||
import shutil
|
||||
import sys
|
||||
import threading
|
||||
import mimetypes
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from collections import namedtuple
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
@ -1434,6 +1435,7 @@ class _Boto3Driver(_Driver):
|
||||
max_concurrency=self._max_multipart_concurrency if container.config.multipart else 1,
|
||||
num_download_attempts=container.config.retries),
|
||||
Callback=callback,
|
||||
ExtraArgs={'ContentType': get_file_mimetype(object_name)}
|
||||
)
|
||||
except Exception as ex:
|
||||
self.get_logger().error('Failed uploading: %s' % ex)
|
||||
@ -1447,7 +1449,9 @@ class _Boto3Driver(_Driver):
|
||||
use_threads=container.config.multipart,
|
||||
max_concurrency=self._max_multipart_concurrency if container.config.multipart else 1,
|
||||
num_download_attempts=container.config.retries),
|
||||
Callback=callback)
|
||||
Callback=callback,
|
||||
ExtraArgs={'ContentType': get_file_mimetype(file_path)}
|
||||
)
|
||||
except Exception as ex:
|
||||
self.get_logger().error('Failed uploading: %s' % ex)
|
||||
return False
|
||||
@ -1539,13 +1543,12 @@ class _Boto3Driver(_Driver):
|
||||
'time': datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
boto_session = boto3.Session(conf.key, conf.secret)
|
||||
boto_session = boto3.Session(conf.key, conf.secret, aws_session_token=conf.token)
|
||||
boto_resource = boto_session.resource('s3', conf.region)
|
||||
bucket = boto_resource.Bucket(bucket_name)
|
||||
bucket.put_object(Key=filename, Body=six.b(json.dumps(data)))
|
||||
|
||||
region = cls._get_bucket_region(conf=conf, log=log, report_info=True)
|
||||
|
||||
if region and ((conf.region and region != conf.region) or (not conf.region and region != 'us-east-1')):
|
||||
msg = "incorrect region specified for bucket %s (detected region %s)" % (conf.bucket, region)
|
||||
else:
|
||||
@ -1594,7 +1597,7 @@ class _Boto3Driver(_Driver):
|
||||
cls._bucket_location_failure_reported.add(conf.get_bucket_host())
|
||||
|
||||
try:
|
||||
boto_session = boto3.Session(conf.key, conf.secret)
|
||||
boto_session = boto3.Session(conf.key, conf.secret, aws_session_token=conf.token)
|
||||
boto_resource = boto_session.resource('s3')
|
||||
return boto_resource.meta.client.get_bucket_location(Bucket=conf.bucket)["LocationConstraint"]
|
||||
|
||||
@ -1803,7 +1806,8 @@ class _AzureBlobServiceStorageDriver(_Driver):
|
||||
)
|
||||
|
||||
def create_blob_from_data(
|
||||
self, container_name, object_name, blob_name, data, max_connections=2, progress_callback=None
|
||||
self, container_name, object_name, blob_name, data, max_connections=2,
|
||||
progress_callback=None, content_settings=None
|
||||
):
|
||||
if self.__legacy:
|
||||
self.__blob_service.create_blob_from_bytes(
|
||||
@ -1815,7 +1819,11 @@ class _AzureBlobServiceStorageDriver(_Driver):
|
||||
)
|
||||
else:
|
||||
client = self.__blob_service.get_blob_client(container_name, blob_name)
|
||||
client.upload_blob(data, overwrite=True, max_concurrency=max_connections)
|
||||
client.upload_blob(
|
||||
data, overwrite=True,
|
||||
max_concurrency=max_connections,
|
||||
content_settings=content_settings,
|
||||
)
|
||||
|
||||
def create_blob_from_path(
|
||||
self, container_name, blob_name, path, max_connections=2, content_settings=None, progress_callback=None
|
||||
@ -1831,7 +1839,9 @@ class _AzureBlobServiceStorageDriver(_Driver):
|
||||
)
|
||||
else:
|
||||
self.create_blob_from_data(
|
||||
container_name, None, blob_name, open(path, "rb"), max_connections=max_connections
|
||||
container_name, None, blob_name, open(path, "rb"),
|
||||
max_connections=max_connections,
|
||||
content_settings=content_settings,
|
||||
)
|
||||
|
||||
def delete_blob(self, container_name, blob_name):
|
||||
@ -1949,14 +1959,13 @@ class _AzureBlobServiceStorageDriver(_Driver):
|
||||
stream = None
|
||||
try:
|
||||
from azure.storage.blob import ContentSettings # noqa
|
||||
from mimetypes import guess_type
|
||||
|
||||
container.create_blob_from_path(
|
||||
container.name,
|
||||
blob_name,
|
||||
file_path,
|
||||
max_connections=2,
|
||||
content_settings=ContentSettings(content_type=guess_type(file_path)),
|
||||
content_settings=ContentSettings(content_type=get_file_mimetype(file_path)),
|
||||
progress_callback=callback,
|
||||
)
|
||||
return True
|
||||
@ -2617,6 +2626,27 @@ class _FileStorageDriver(_Driver):
|
||||
return True
|
||||
|
||||
|
||||
def get_file_mimetype(file_path):
|
||||
"""
|
||||
Get MIME types of a file
|
||||
|
||||
:param file_path: Path of the local file
|
||||
:type file_path: str
|
||||
|
||||
:return: File MIME type. Return None if failed to get it
|
||||
:rtype: str
|
||||
"""
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
file_path = Path(file_path).resolve()
|
||||
mimetype, _ = mimetypes.guess_type(file_path)
|
||||
if mimetype:
|
||||
return mimetype
|
||||
except Exception:
|
||||
return None
|
||||
return 'binary/octet-stream'
|
||||
|
||||
|
||||
driver_schemes = set(
|
||||
filter(
|
||||
None,
|
||||
|
Loading…
Reference in New Issue
Block a user