diff --git a/clearml/storage/helper.py b/clearml/storage/helper.py index 8e0ecb6e..30b0cb6a 100644 --- a/clearml/storage/helper.py +++ b/clearml/storage/helper.py @@ -301,7 +301,7 @@ class StorageHelper(object): ) self._driver = _AzureBlobServiceStorageDriver() - self._container = self._driver.get_container(config=self._conf) + self._container = self._driver.get_container(config=self._conf, account_url=parsed.netloc) elif self._scheme == _Boto3Driver.scheme: self._conf = copy(self._s3_configurations.get_config_by_uri(url)) @@ -1694,22 +1694,143 @@ class _AzureBlobServiceStorageDriver(_Driver): _containers = {} class _Container(object): - def __init__(self, name, config): - try: - from azure.common import AzureHttpError # noqa: F401 - from azure.storage.blob import BlockBlobService - except ImportError: - raise UsageError( - 'Azure blob storage driver not found. ' - 'Please install driver using: pip install \"azure.storage.blob<=2.1.0\"' - ) - + def __init__(self, name, config, account_url): + self.MAX_SINGLE_PUT_SIZE = 16 * 1024 * 1024 + self.SOCKET_TIMEOUT = (300, 2000) self.name = name self.config = config - self.blob_service = BlockBlobService( - account_name=config.account_name, - account_key=config.account_key, - ) + self.account_url = account_url + try: + from azure.storage.blob import BlobServiceClient # noqa + + self.__legacy = False + except ImportError: + try: + from azure.storage.blob import BlockBlobService # noqa + from azure.common import AzureHttpError # noqa: F401 + + self.__legacy = True + except ImportError: + raise UsageError( + "Azure blob storage driver not found. " + "Please install driver using: 'pip install clearml[azure]' or " + "pip install '\"azure.storage.blob>=12.0.0\"'" + ) + + if self.__legacy: + self.__blob_service = BlockBlobService( + account_name=self.config.account_name, + account_key=self.config.account_key, + ) + self.__blob_service.MAX_SINGLE_PUT_SIZE = self.MAX_SINGLE_PUT_SIZE + self.__blob_service.socket_timeout = self.SOCKET_TIMEOUT + else: + credential = {"account_name": self.config.account_name, "account_key": self.config.account_key} + self.__blob_service = BlobServiceClient( + account_url=account_url, + credential=credential, + max_single_put_size=self.MAX_SINGLE_PUT_SIZE, + ) + + def create_blob_from_data( + self, container_name, object_name, blob_name, data, max_connections=2, progress_callback=None + ): + if self.__legacy: + self.__blob_service.create_blob_from_bytes( + container_name, + object_name, + data, + max_connections=max_connections, + progress_callback=progress_callback, + ) + else: + client = self.__blob_service.get_blob_client(container_name, blob_name) + client.upload_blob(data, overwrite=True, max_concurrency=max_connections) + + def create_blob_from_path( + self, container_name, blob_name, path, max_connections=2, content_settings=None, progress_callback=None + ): + if self.__legacy: + self.__blob_service.create_blob_from_path( + container_name, + blob_name, + path, + max_connections=max_connections, + content_settings=content_settings, + progress_callback=progress_callback, + ) + else: + client = self.__blob_service.get_blob_client(container_name, blob_name) + with open(path, "rb") as file: + first_chunk = True + for chunk in iter((lambda: file.read(self.MAX_SINGLE_PUT_SIZE)), b""): + if first_chunk: + client.upload_blob(chunk, overwrite=True, max_concurrency=max_connections) + first_chunk = False + else: + from azure.storage.blob import BlockType # noqa + + client.upload_blob(chunk, BlockType.AppendBlob) + + def delete_blob(self, container_name, blob_name): + if self.__legacy: + self.__blob_service.delete_blob( + container_name, + blob_name, + ) + else: + client = self.__blob_service.get_blob_client(container_name, blob_name) + client.delete_blob() + + def exists(self, container_name, blob_name): + if self.__legacy: + return not self.__blob_service.exists(container_name, blob_name) + else: + client = self.__blob_service.get_blob_client(container_name, blob_name) + return client.exists() + + def list_blobs(self, container_name, prefix=None): + if self.__legacy: + return self.__blob_service.list_blobs(container_name=container_name, prefix=prefix) + else: + client = self.__blob_service.get_container_client(container_name) + return client.list_blobs(name_starts_with=prefix) + + def get_blob_properties(self, container_name, blob_name): + if self.__legacy: + return self.__blob_service.get_blob_properties(container_name, blob_name) + else: + client = self.__blob_service.get_blob_client(container_name, blob_name) + return client.get_blob_properties() + + def get_blob_to_bytes(self, container_name, blob_name, progress_callback=None): + if self.__legacy: + return self.__blob_service.get_blob_to_bytes( + container_name, + blob_name, + progress_callback=progress_callback, + ) + else: + client = self.__blob_service.get_blob_client(container_name, blob_name) + return client.download_blob().content_as_bytes() + + def get_blob_to_path(self, container_name, blob_name, path, max_connections=10, progress_callback=None): + if self.__legacy: + return self.__blob_service.get_blob_to_path( + container_name, + blob_name, + path, + max_connections=max_connections, + progress_callback=progress_callback, + ) + else: + client = self.__blob_service.get_blob_client(container_name, blob_name, max_concurrency=max_connections) + with open(path, "wb") as file: + return client.download_blob().download_to_stream(file) + + def is_legacy(self): + return self.__legacy + @attrs class _Object(object): @@ -1717,28 +1838,33 @@ class _AzureBlobServiceStorageDriver(_Driver): blob_name = attrib() content_length = attrib() - def get_container(self, container_name=None, config=None, **kwargs): + def get_container(self, container_name=None, config=None, account_url=None, **kwargs): container_name = container_name or config.container_name if container_name not in self._containers: - self._containers[container_name] = self._Container(name=container_name, config=config) + self._containers[container_name] = self._Container( + name=container_name, config=config, account_url=account_url + ) # self._containers[container_name].config.retries = kwargs.get('retries', 5) return self._containers[container_name] def upload_object_via_stream(self, iterator, container, object_name, callback=None, extra=None, **kwargs): - from azure.common import AzureHttpError # noqa + try: + from azure.common import AzureHttpError # noqa + except ImportError: + from azure.core.exceptions import HttpResponseError # noqa + + AzureHttpError = HttpResponseError # noqa blob_name = self._blob_name_from_object_path(object_name, container.name) # noqa: F841 try: - container.blob_service.MAX_SINGLE_PUT_SIZE = 16 * 1024 * 1024 - container.blob_service.socket_timeout = (300, 2000) - container.blob_service.create_blob_from_bytes( + container.create_blob_from_data( container.name, object_name, + blob_name, iterator.read() if hasattr(iterator, "read") else bytes(iterator), - # timeout=300, max_connections=2, progress_callback=callback, - ) + ) return True except AzureHttpError as ex: self.get_logger().error('Failed uploading (Azure error): %s' % ex) @@ -1747,20 +1873,23 @@ class _AzureBlobServiceStorageDriver(_Driver): return False def upload_object(self, file_path, container, object_name, callback=None, extra=None, **kwargs): - from azure.common import AzureHttpError # noqa + try: + from azure.common import AzureHttpError # noqa + except ImportError: + from azure.core.exceptions import HttpResponseError # noqa + + AzureHttpError = HttpResponseError # noqa blob_name = self._blob_name_from_object_path(object_name, container.name) stream = None try: from azure.storage.blob import ContentSettings # noqa from mimetypes import guess_type - container.blob_service.MAX_SINGLE_PUT_SIZE = 16 * 1024 * 1024 - container.blob_service.socket_timeout = (300, 2000) - container.blob_service.create_blob_from_path( + + container.create_blob_from_path( container.name, blob_name, file_path, - # timeout=300, max_connections=2, content_settings=ContentSettings(content_type=guess_type(file_path)), progress_callback=callback, @@ -1775,15 +1904,15 @@ class _AzureBlobServiceStorageDriver(_Driver): stream.close() def list_container_objects(self, container, ex_prefix=None, **kwargs): - return list(container.blob_service.list_blobs(container_name=container.name, prefix=ex_prefix)) + return list(container.list_blobs(container_name=container.name, prefix=ex_prefix)) def delete_object(self, object, **kwargs): container = object.container - container.blob_service.delete_blob( + container.delete_blob( container.name, object.blob_name, ) - return not object.container.blob_service.exists(container.name, object.blob_name) + return not object.container.exists(container.name, object.blob_name) def get_object(self, container_name, object_name, *args, **kwargs): container = self._containers.get(container_name) @@ -1791,21 +1920,21 @@ class _AzureBlobServiceStorageDriver(_Driver): raise StorageError("Container `{}` not found for object {}".format(container_name, object_name)) # blob_name = self._blob_name_from_object_path(object_name, container_name) - blob = container.blob_service.get_blob_properties(container.name, object_name) + blob = container.get_blob_properties(container.name, object_name) - return self._Object(container=container, blob_name=blob.name, content_length=blob.properties.content_length) + if container.is_legacy(): + return self._Object(container=container, blob_name=blob.name, content_length=blob.properties.content_length) + else: + return self._Object(container=container, blob_name=blob.name, content_length=blob.size) def download_object_as_stream(self, obj, verbose, *_, **__): container = obj.container total_size_mb = obj.content_length / (1024. * 1024.) remote_path = os.path.join( - "{}://".format(self.scheme), - container.config.account_name, - container.name, - obj.blob_name + "{}://".format(self.scheme), container.config.account_name, container.name, obj.blob_name ) cb = DownloadProgressReport(total_size_mb, verbose, remote_path, self.get_logger()) - blob = container.blob_service.get_blob_to_bytes( + blob = container.get_blob_to_bytes( container.name, obj.blob_name, progress_callback=cb, @@ -1815,12 +1944,13 @@ class _AzureBlobServiceStorageDriver(_Driver): def download_object(self, obj, local_path, overwrite_existing=True, delete_on_failure=True, callback=None, **_): p = Path(local_path) if not overwrite_existing and p.is_file(): - self.get_logger().warning('failed saving after download: overwrite=False and file exists (%s)' % str(p)) + self.get_logger().warning("failed saving after download: overwrite=False and file exists (%s)" % str(p)) return download_done = threading.Event() download_done.counter = 0 + def callback_func(current, total): if callback: chunk = current - download_done.counter @@ -1829,9 +1959,10 @@ class _AzureBlobServiceStorageDriver(_Driver): if current >= total: download_done.set() + container = obj.container container.blob_service.MAX_SINGLE_GET_SIZE = 5 * 1024 * 1024 - _ = container.blob_service.get_blob_to_path( + _ = container.get_blob_to_path( container.name, obj.blob_name, local_path, diff --git a/setup.py b/setup.py index 858b8058..dba5eaec 100644 --- a/setup.py +++ b/setup.py @@ -71,7 +71,7 @@ setup( 'boto3>=1.9', ], 'azure': [ - 'azure-storage-blob>=2.0.1,<=2.1', + 'azure-storage-blob>=12.0.0', ], 'gs': [ 'google-cloud-storage>=1.13.2',