From c75c83c21da3af4629957ecb872354cb77bfff5d Mon Sep 17 00:00:00 2001
From: allegroai <>
Date: Thu, 29 Dec 2022 13:18:39 +0200
Subject: [PATCH] Fix casting None to int fails uploads and permission checks

---
 clearml/storage/helper.py | 48 +++++++++++++++++++++++----------------
 1 file changed, 29 insertions(+), 19 deletions(-)

diff --git a/clearml/storage/helper.py b/clearml/storage/helper.py
index c2a802d3..cd129406 100644
--- a/clearml/storage/helper.py
+++ b/clearml/storage/helper.py
@@ -602,7 +602,7 @@ class StorageHelper(object):
         return self._get_object_size_bytes(obj)

     def _get_object_size_bytes(self, obj):
-        # type: (object, bool) -> [int, None]
+        # type: (object) -> [int, None]
         """
         Auxiliary function for `get_object_size_bytes`.
         Get size of the remote object in bytes.
@@ -1606,6 +1606,7 @@ class _Boto3Driver(_Driver):
     def upload_object_via_stream(self, iterator, container, object_name, callback=None, extra=None, **kwargs):
         import boto3.s3.transfer
         stream = _Stream(iterator)
+        extra_args = {}
         try:
             extra_args = {
                 'ContentType': get_file_mimetype(object_name)
@@ -1630,7 +1631,7 @@ class _Boto3Driver(_Driver):
                     num_download_attempts=container.config.retries,
                 ),
                 Callback=callback,
-                ExtraArgs=extra_args,
+                ExtraArgs=extra_args
             )
         except Exception as ex:
             self.get_logger().error("Failed uploading: %s" % ex)
@@ -1642,6 +1643,7 @@ class _Boto3Driver(_Driver):

     def upload_object(self, file_path, container, object_name, callback=None, extra=None, **kwargs):
         import boto3.s3.transfer
+        extra_args = {}
         try:
             extra_args = {
                 'ContentType': get_file_mimetype(object_name or file_path)
@@ -1665,7 +1667,7 @@ class _Boto3Driver(_Driver):
                     use_threads=False, num_download_attempts=container.config.retries
                 ),
                 Callback=callback,
-                ExtraArgs=extra_args,
+                ExtraArgs=extra_args
            )
         except Exception as ex:
             self.get_logger().error("Failed uploading: %s" % ex)
@@ -2006,7 +2008,7 @@ class _AzureBlobServiceStorageDriver(_Driver):
     scheme = "azure"

     _containers = {}
-    _max_connections = deferred_config("azure.storage.max_connections", None)
+    _max_connections = deferred_config("azure.storage.max_connections", 0)

     class _Container(object):
         def __init__(self, name, config, account_url):
@@ -2047,45 +2049,52 @@ class _AzureBlobServiceStorageDriver(_Driver):
                 max_single_put_size=self.MAX_SINGLE_PUT_SIZE,
             )

+        @staticmethod
+        def _get_max_connections_dict(max_connections=None, key="max_connections"):
+            # must cast for deferred resolving
+            try:
+                max_connections = max_connections or int(_AzureBlobServiceStorageDriver._max_connections)
+            except (AttributeError, TypeError):
+                return {}
+            return {key: int(max_connections)} if max_connections else {}
+
         def create_blob_from_data(
             self, container_name, object_name, blob_name, data, max_connections=None,
             progress_callback=None, content_settings=None
         ):
-            max_connections = max_connections or int(_AzureBlobServiceStorageDriver._max_connections)
             if self.__legacy:
                 self.__blob_service.create_blob_from_bytes(
                     container_name,
                     object_name,
                     data,
-                    max_connections=max_connections,
                     progress_callback=progress_callback,
+                    **self._get_max_connections_dict(max_connections)
                 )
             else:
                 client = self.__blob_service.get_blob_client(container_name, blob_name)
                 client.upload_blob(
                     data, overwrite=True,
-                    max_concurrency=max_connections,
                     content_settings=content_settings,
+                    **self._get_max_connections_dict(max_connections)
                 )

         def create_blob_from_path(
             self, container_name, blob_name, path, max_connections=None, content_settings=None, progress_callback=None
         ):
-            max_connections = max_connections or int(_AzureBlobServiceStorageDriver._max_connections)
             if self.__legacy:
                 self.__blob_service.create_blob_from_path(
                     container_name,
                     blob_name,
                     path,
-                    max_connections=max_connections,
                     content_settings=content_settings,
                     progress_callback=progress_callback,
+                    **self._get_max_connections_dict(max_connections)
                 )
             else:
                 self.create_blob_from_data(
                     container_name, None, blob_name, open(path, "rb"),
-                    max_connections=max_connections,
                     content_settings=content_settings,
+                    **self._get_max_connections_dict(max_connections)
                 )

         def delete_blob(self, container_name, blob_name):
@@ -2131,19 +2140,20 @@ class _AzureBlobServiceStorageDriver(_Driver):
                 return client.download_blob().content_as_bytes()

         def get_blob_to_path(self, container_name, blob_name, path, max_connections=None, progress_callback=None):
-            max_connections = max_connections or int(_AzureBlobServiceStorageDriver._max_connections)
             if self.__legacy:
                 return self.__blob_service.get_blob_to_path(
                     container_name,
                     blob_name,
                     path,
-                    max_connections=max_connections,
                     progress_callback=progress_callback,
+                    **self._get_max_connections_dict(max_connections)
                 )
             else:
                 client = self.__blob_service.get_blob_client(container_name, blob_name)
                 with open(path, "wb") as file:
-                    return client.download_blob(max_concurrency=max_connections).download_to_stream(file)
+                    return client.download_blob(
+                        **self._get_max_connections_dict(max_connections, "max_concurrency")
+                    ).download_to_stream(file)

         def is_legacy(self):
             return self.__legacy
@@ -2193,7 +2203,9 @@ class _AzureBlobServiceStorageDriver(_Driver):
             self.get_logger().error("Failed uploading: %s" % ex)
         return False

-    def upload_object(self, file_path, container, object_name, callback=None, extra=None, max_connections=None, **kwargs):
+    def upload_object(
+        self, file_path, container, object_name, callback=None, extra=None, max_connections=None, **kwargs
+    ):
         try:
             from azure.common import AzureHttpError  # noqa
         except ImportError:
@@ -2202,7 +2214,6 @@ class _AzureBlobServiceStorageDriver(_Driver):
             from azure.core.exceptions import HttpResponseError  # noqa
             AzureHttpError = HttpResponseError  # noqa

         blob_name = self._blob_name_from_object_path(object_name, container.name)
-        stream = None
         try:
             from azure.storage.blob import ContentSettings  # noqa
@@ -2219,9 +2230,6 @@ class _AzureBlobServiceStorageDriver(_Driver):
             self.get_logger().error('Failed uploading (Azure error): %s' % ex)
         except Exception as ex:
             self.get_logger().error('Failed uploading: %s' % ex)
-        finally:
-            if stream:
-                stream.close()

     def list_container_objects(self, container, ex_prefix=None, **kwargs):
         return list(container.list_blobs(container_name=container.name, prefix=ex_prefix))
@@ -2264,7 +2272,9 @@ class _AzureBlobServiceStorageDriver(_Driver):
         else:
             return blob

-    def download_object(self, obj, local_path, overwrite_existing=True, delete_on_failure=True, callback=None, max_connections=None, **_):
+    def download_object(
+        self, obj, local_path, overwrite_existing=True, delete_on_failure=True, callback=None, max_connections=None, **_
+    ):
         p = Path(local_path)
         if not overwrite_existing and p.is_file():
             self.get_logger().warning("Failed saving after download: overwrite=False and file exists (%s)" % str(p))
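
A minimal standalone sketch (not part of the patch) of the failure mode this commit addresses: with the previous default of None for azure.storage.max_connections, the removed lines applied int() to a value that resolves to None, which raises TypeError and aborted Azure uploads/downloads and the permission checks that exercise them. The new _get_max_connections_dict helper casts the deferred config value and falls back to an empty dict, so the **kwargs expansion leaves the Azure SDK's default concurrency in place. The function below is a hypothetical free-function mirror that takes the configured value as an argument instead of reading the class-level deferred_config attribute:

    def get_max_connections_dict(configured, max_connections=None, key="max_connections"):
        # Mirror of the patched helper: resolve the configured value, cast it,
        # and degrade to {} so callers can safely expand it with **kwargs.
        try:
            max_connections = max_connections or int(configured)
        except (AttributeError, TypeError):
            return {}
        return {key: int(max_connections)} if max_connections else {}

    print(get_max_connections_dict(None))                        # {} -> SDK default concurrency
    print(get_max_connections_dict(10))                          # {'max_connections': 10}
    print(get_max_connections_dict(None, 5, "max_concurrency"))  # {'max_concurrency': 5}
    # Pre-patch behaviour for comparison: int(None) raises TypeError, which is
    # what failed every upload when the configuration value was left unset.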