Mirror of https://github.com/clearml/clearml
Fix wrong mimetype used for any file or folder uploaded to S3 using StorageManager
Add support for boto3's aws_session_token
This commit is contained in:
parent fdefa9784c
commit 6c9e02f1b2
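Taken together, the two changes mean that files sent to S3 through StorageManager now carry a Content-Type guessed from the file name instead of the S3 default (binary/octet-stream), and that temporary AWS credentials (key/secret plus session token) are passed through to boto3. A minimal sketch of the user-facing call path, assuming the documented clearml StorageManager API and a hypothetical bucket and path:

    from clearml import StorageManager

    # Upload a local file; after this commit the stored object's Content-Type
    # is guessed from the file name (e.g. 'application/json' for report.json)
    # instead of falling back to the S3 default.
    StorageManager.upload_file(
        local_file="report.json",
        remote_url="s3://my-bucket/reports/report.json",  # hypothetical bucket/path
    )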
@@ -9,6 +9,7 @@ import platform
 import shutil
 import sys
 import threading
+import mimetypes
 from abc import ABCMeta, abstractmethod
 from collections import namedtuple
 from concurrent.futures import ThreadPoolExecutor
@@ -1434,6 +1435,7 @@ class _Boto3Driver(_Driver):
                 max_concurrency=self._max_multipart_concurrency if container.config.multipart else 1,
                 num_download_attempts=container.config.retries),
                 Callback=callback,
+                ExtraArgs={'ContentType': get_file_mimetype(object_name)}
             )
         except Exception as ex:
             self.get_logger().error('Failed uploading: %s' % ex)
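For context, ExtraArgs is boto3's standard hook for per-object upload parameters on managed transfers; setting ContentType there is what makes the stored object carry the guessed MIME type. A standalone sketch of the same idea, using only documented boto3 calls and a hypothetical bucket and key:

    import boto3
    import mimetypes

    s3 = boto3.resource("s3")
    bucket = s3.Bucket("my-bucket")  # hypothetical bucket name

    content_type, _ = mimetypes.guess_type("model.json")
    with open("model.json", "rb") as f:
        # Without ExtraArgs the object is stored as binary/octet-stream;
        # with it, S3 records Content-Type: application/json.
        bucket.upload_fileobj(
            f, "models/model.json",
            ExtraArgs={"ContentType": content_type or "binary/octet-stream"},
        )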
@@ -1447,7 +1449,9 @@ class _Boto3Driver(_Driver):
                 use_threads=container.config.multipart,
                 max_concurrency=self._max_multipart_concurrency if container.config.multipart else 1,
                 num_download_attempts=container.config.retries),
-                Callback=callback)
+                Callback=callback,
+                ExtraArgs={'ContentType': get_file_mimetype(file_path)}
+            )
         except Exception as ex:
             self.get_logger().error('Failed uploading: %s' % ex)
             return False
@@ -1539,13 +1543,12 @@ class _Boto3Driver(_Driver):
                 'time': datetime.utcnow().isoformat()
             }
-
-            boto_session = boto3.Session(conf.key, conf.secret)
+            boto_session = boto3.Session(conf.key, conf.secret, aws_session_token=conf.token)
             boto_resource = boto_session.resource('s3', conf.region)
             bucket = boto_resource.Bucket(bucket_name)
             bucket.put_object(Key=filename, Body=six.b(json.dumps(data)))

             region = cls._get_bucket_region(conf=conf, log=log, report_info=True)

             if region and ((conf.region and region != conf.region) or (not conf.region and region != 'us-east-1')):
                 msg = "incorrect region specified for bucket %s (detected region %s)" % (conf.bucket, region)
             else:
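An aws_session_token is the third element of the temporary credentials issued by STS (for example when assuming a role); without it, boto3 cannot authenticate with a temporary key/secret pair, which is what this hunk and the next one enable. A short sketch of the same pattern outside clearml, using documented boto3/STS calls and a hypothetical role ARN:

    import boto3

    # Temporary credentials: access key + secret key + session token.
    sts = boto3.client("sts")
    creds = sts.assume_role(
        RoleArn="arn:aws:iam::123456789012:role/uploader",  # hypothetical role
        RoleSessionName="clearml-upload",
    )["Credentials"]

    session = boto3.Session(
        aws_access_key_id=creds["AccessKeyId"],
        aws_secret_access_key=creds["SecretAccessKey"],
        aws_session_token=creds["SessionToken"],  # the piece this commit threads through
    )
    print(session.client("s3").list_buckets()["Buckets"])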
@@ -1594,7 +1597,7 @@ class _Boto3Driver(_Driver):
             cls._bucket_location_failure_reported.add(conf.get_bucket_host())

         try:
-            boto_session = boto3.Session(conf.key, conf.secret)
+            boto_session = boto3.Session(conf.key, conf.secret, aws_session_token=conf.token)
             boto_resource = boto_session.resource('s3')
             return boto_resource.meta.client.get_bucket_location(Bucket=conf.bucket)["LocationConstraint"]

@@ -1803,7 +1806,8 @@ class _AzureBlobServiceStorageDriver(_Driver):
             )

         def create_blob_from_data(
-            self, container_name, object_name, blob_name, data, max_connections=2, progress_callback=None
+            self, container_name, object_name, blob_name, data, max_connections=2,
+            progress_callback=None, content_settings=None
         ):
             if self.__legacy:
                 self.__blob_service.create_blob_from_bytes(
@@ -1815,7 +1819,11 @@ class _AzureBlobServiceStorageDriver(_Driver):
                 )
             else:
                 client = self.__blob_service.get_blob_client(container_name, blob_name)
-                client.upload_blob(data, overwrite=True, max_concurrency=max_connections)
+                client.upload_blob(
+                    data, overwrite=True,
+                    max_concurrency=max_connections,
+                    content_settings=content_settings,
+                )

         def create_blob_from_path(
             self, container_name, blob_name, path, max_connections=2, content_settings=None, progress_callback=None
@@ -1831,7 +1839,9 @@ class _AzureBlobServiceStorageDriver(_Driver):
                 )
             else:
                 self.create_blob_from_data(
-                    container_name, None, blob_name, open(path, "rb"),
-                    max_connections=max_connections
+                    container_name, None, blob_name, open(path, "rb"),
+                    max_connections=max_connections,
+                    content_settings=content_settings,
                 )

         def delete_blob(self, container_name, blob_name):
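On the Azure side the counterpart of S3's ContentType is a ContentSettings object passed to upload_blob, which the two hunks above now thread through from the caller. A minimal sketch against the azure-storage-blob v12 client, with a placeholder connection string and hypothetical container/blob names:

    import mimetypes
    from azure.storage.blob import BlobServiceClient, ContentSettings

    service = BlobServiceClient.from_connection_string("<connection-string>")  # placeholder
    blob = service.get_blob_client("my-container", "reports/report.json")      # hypothetical names

    content_type, _ = mimetypes.guess_type("report.json")
    with open("report.json", "rb") as f:
        blob.upload_blob(
            f,
            overwrite=True,
            content_settings=ContentSettings(content_type=content_type),  # a string, not a tuple
        )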
@@ -1949,14 +1959,13 @@ class _AzureBlobServiceStorageDriver(_Driver):
         stream = None
         try:
             from azure.storage.blob import ContentSettings  # noqa
-            from mimetypes import guess_type

             container.create_blob_from_path(
                 container.name,
                 blob_name,
                 file_path,
                 max_connections=2,
-                content_settings=ContentSettings(content_type=guess_type(file_path)),
+                content_settings=ContentSettings(content_type=get_file_mimetype(file_path)),
                 progress_callback=callback,
             )
             return True
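Note why the old line was wrong: mimetypes.guess_type returns a (type, encoding) tuple, so the Azure path was handing a tuple to ContentSettings(content_type=...) where a plain string is expected, while the S3 path previously set no ContentType at all. The new get_file_mimetype helper returns just the type string (with a fallback). A two-line illustration of the difference:

    from mimetypes import guess_type

    print(guess_type("report.json"))     # ('application/json', None)  <- tuple, the old buggy value
    print(guess_type("report.json")[0])  # 'application/json'          <- what content_type expects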
@@ -2617,6 +2626,27 @@ class _FileStorageDriver(_Driver):
         return True


+def get_file_mimetype(file_path):
+    """
+    Get MIME types of a file
+
+    :param file_path: Path of the local file
+    :type file_path: str
+
+    :return: File MIME type. Return None if failed to get it
+    :rtype: str
+    """
+    # noinspection PyBroadException
+    try:
+        file_path = Path(file_path).resolve()
+        mimetype, _ = mimetypes.guess_type(file_path)
+        if mimetype:
+            return mimetype
+    except Exception:
+        return None
+    return 'binary/octet-stream'
+
+
 driver_schemes = set(
     filter(
         None,
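A quick sketch of how the new module-level helper behaves. It is re-declared here only so the snippet runs on its own; the real one is the function added in the hunk above, next to the storage drivers:

    import mimetypes
    from pathlib import Path

    def get_file_mimetype(file_path):
        # Guess the MIME type from the file name; fall back to a generic type.
        try:
            mimetype, _ = mimetypes.guess_type(str(Path(file_path).resolve()))
            if mimetype:
                return mimetype
        except Exception:
            return None
        return 'binary/octet-stream'

    print(get_file_mimetype("weights.json"))         # application/json
    print(get_file_mimetype("model.no_such_ext"))    # binary/octet-stream (unregistered extension)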