Mirror of https://github.com/clearml/clearml, synced 2025-06-26 18:16:07 +00:00
Fix casting None to int, which fails uploads and permission checks
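For context, a minimal sketch of the failure mode this commit addresses (illustrative values only, not repository code): when azure.storage.max_connections is left unset, the deferred config resolves to None, and the previous unconditional int(...) cast raises a TypeError that aborts the transfer before any data moves.

# Hypothetical stand-ins for the configured value and the caller argument.
configured_max_connections = None   # azure.storage.max_connections was never set
max_connections = None              # caller did not pass a value

try:
    # The old code path cast the configured value to int unconditionally.
    max_connections = max_connections or int(configured_max_connections)  # int(None) -> TypeError
except TypeError as ex:
    print("Failed uploading: %s" % ex)

The diff below avoids this in two ways: the Azure driver builds its concurrency keyword through a small guard helper, and the boto3 driver pre-initializes extra_args so a failure while building it cannot leave the name undefined.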
This commit is contained in:
parent a3d44aa81f
commit c75c83c21d
@@ -602,7 +602,7 @@ class StorageHelper(object):
         return self._get_object_size_bytes(obj)
 
     def _get_object_size_bytes(self, obj):
-        # type: (object, bool) -> [int, None]
+        # type: (object) -> [int, None]
         """
         Auxiliary function for `get_object_size_bytes`.
         Get size of the remote object in bytes.
@@ -1606,6 +1606,7 @@ class _Boto3Driver(_Driver):
     def upload_object_via_stream(self, iterator, container, object_name, callback=None, extra=None, **kwargs):
         import boto3.s3.transfer
         stream = _Stream(iterator)
+        extra_args = {}
         try:
             extra_args = {
                 'ContentType': get_file_mimetype(object_name)
@@ -1630,7 +1631,7 @@ class _Boto3Driver(_Driver):
                         num_download_attempts=container.config.retries,
                     ),
                     Callback=callback,
-                    ExtraArgs=extra_args,
+                    ExtraArgs=extra_args
                 )
             except Exception as ex:
                 self.get_logger().error("Failed uploading: %s" % ex)
@@ -1642,6 +1643,7 @@ class _Boto3Driver(_Driver):
 
     def upload_object(self, file_path, container, object_name, callback=None, extra=None, **kwargs):
         import boto3.s3.transfer
+        extra_args = {}
         try:
             extra_args = {
                 'ContentType': get_file_mimetype(object_name or file_path)
@@ -1665,7 +1667,7 @@ class _Boto3Driver(_Driver):
                         use_threads=False, num_download_attempts=container.config.retries
                     ),
                     Callback=callback,
-                    ExtraArgs=extra_args,
+                    ExtraArgs=extra_args
                 )
             except Exception as ex:
                 self.get_logger().error("Failed uploading: %s" % ex)
@@ -2006,7 +2008,7 @@ class _AzureBlobServiceStorageDriver(_Driver):
     scheme = "azure"
 
     _containers = {}
-    _max_connections = deferred_config("azure.storage.max_connections", None)
+    _max_connections = deferred_config("azure.storage.max_connections", 0)
 
     class _Container(object):
         def __init__(self, name, config, account_url):
@@ -2047,45 +2049,52 @@ class _AzureBlobServiceStorageDriver(_Driver):
                     max_single_put_size=self.MAX_SINGLE_PUT_SIZE,
                 )
 
+        @staticmethod
+        def _get_max_connections_dict(max_connections=None, key="max_connections"):
+            # must cast for deferred resolving
+            try:
+                max_connections = max_connections or int(_AzureBlobServiceStorageDriver._max_connections)
+            except (AttributeError, TypeError):
+                return {}
+            return {key: int(max_connections)} if max_connections else {}
+
         def create_blob_from_data(
             self, container_name, object_name, blob_name, data, max_connections=None,
                 progress_callback=None, content_settings=None
         ):
-            max_connections = max_connections or int(_AzureBlobServiceStorageDriver._max_connections)
             if self.__legacy:
                 self.__blob_service.create_blob_from_bytes(
                     container_name,
                     object_name,
                     data,
-                    max_connections=max_connections,
                     progress_callback=progress_callback,
+                    **self._get_max_connections_dict(max_connections)
                 )
             else:
                 client = self.__blob_service.get_blob_client(container_name, blob_name)
                 client.upload_blob(
                     data, overwrite=True,
-                    max_concurrency=max_connections,
                     content_settings=content_settings,
+                    **self._get_max_connections_dict(max_connections)
                 )
 
         def create_blob_from_path(
             self, container_name, blob_name, path, max_connections=None, content_settings=None, progress_callback=None
         ):
-            max_connections = max_connections or int(_AzureBlobServiceStorageDriver._max_connections)
             if self.__legacy:
                 self.__blob_service.create_blob_from_path(
                     container_name,
                     blob_name,
                     path,
-                    max_connections=max_connections,
                     content_settings=content_settings,
                     progress_callback=progress_callback,
+                    **self._get_max_connections_dict(max_connections)
                 )
             else:
                 self.create_blob_from_data(
                     container_name, None, blob_name, open(path, "rb"),
-                    max_connections=max_connections,
                     content_settings=content_settings,
+                    **self._get_max_connections_dict(max_connections)
                 )
 
         def delete_blob(self, container_name, blob_name):
@@ -2131,19 +2140,20 @@ class _AzureBlobServiceStorageDriver(_Driver):
                 return client.download_blob().content_as_bytes()
 
         def get_blob_to_path(self, container_name, blob_name, path, max_connections=None, progress_callback=None):
-            max_connections = max_connections or int(_AzureBlobServiceStorageDriver._max_connections)
             if self.__legacy:
                 return self.__blob_service.get_blob_to_path(
                     container_name,
                     blob_name,
                     path,
-                    max_connections=max_connections,
                     progress_callback=progress_callback,
+                    **self._get_max_connections_dict(max_connections)
                 )
             else:
                 client = self.__blob_service.get_blob_client(container_name, blob_name)
                 with open(path, "wb") as file:
-                    return client.download_blob(max_concurrency=max_connections).download_to_stream(file)
+                    return client.download_blob(
+                        **self._get_max_connections_dict(max_connections, "max_concurrency")
+                    ).download_to_stream(file)
 
         def is_legacy(self):
             return self.__legacy
@@ -2193,7 +2203,9 @@ class _AzureBlobServiceStorageDriver(_Driver):
             self.get_logger().error("Failed uploading: %s" % ex)
         return False
 
-    def upload_object(self, file_path, container, object_name, callback=None, extra=None, max_connections=None, **kwargs):
+    def upload_object(
+        self, file_path, container, object_name, callback=None, extra=None, max_connections=None, **kwargs
+    ):
         try:
             from azure.common import AzureHttpError  # noqa
         except ImportError:
@@ -2202,7 +2214,6 @@ class _AzureBlobServiceStorageDriver(_Driver):
             AzureHttpError = HttpResponseError  # noqa
 
         blob_name = self._blob_name_from_object_path(object_name, container.name)
-        stream = None
         try:
             from azure.storage.blob import ContentSettings  # noqa
 
@@ -2219,9 +2230,6 @@ class _AzureBlobServiceStorageDriver(_Driver):
             self.get_logger().error('Failed uploading (Azure error): %s' % ex)
         except Exception as ex:
             self.get_logger().error('Failed uploading: %s' % ex)
-        finally:
-            if stream:
-                stream.close()
 
     def list_container_objects(self, container, ex_prefix=None, **kwargs):
         return list(container.list_blobs(container_name=container.name, prefix=ex_prefix))
@@ -2264,7 +2272,9 @@ class _AzureBlobServiceStorageDriver(_Driver):
         else:
             return blob
 
-    def download_object(self, obj, local_path, overwrite_existing=True, delete_on_failure=True, callback=None, max_connections=None, **_):
+    def download_object(
+        self, obj, local_path, overwrite_existing=True, delete_on_failure=True, callback=None, max_connections=None, **_
+    ):
         p = Path(local_path)
         if not overwrite_existing and p.is_file():
             self.get_logger().warning("Failed saving after download: overwrite=False and file exists (%s)" % str(p))
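To make the new guard's behavior concrete, here is a self-contained sketch that mirrors _get_max_connections_dict outside the class; the module-level _max_connections stand-in is an assumption for illustration, taking the place of the deferred_config attribute.

# Simplified sketch of the guard pattern from the commit; not the repository code.
_max_connections = None  # stand-in for deferred_config("azure.storage.max_connections", 0)


def _get_max_connections_dict(max_connections=None, key="max_connections"):
    # Cast defensively: the configured value may be None or an unresolved deferred value.
    try:
        max_connections = max_connections or int(_max_connections)
    except (AttributeError, TypeError):
        return {}
    return {key: int(max_connections)} if max_connections else {}


print(_get_max_connections_dict())                      # {}  -> the SDK default concurrency is used
print(_get_max_connections_dict(4, "max_concurrency"))  # {'max_concurrency': 4}

Because the helper returns an empty dict whenever no usable value is available, the Azure SDK calls fall back to their default concurrency instead of failing on int(None), and the same helper serves both the legacy max_connections keyword and the newer max_concurrency keyword.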