Support newer azure storage python version (#548)

Add support for azure-storage-blob >= 12.0.0. Backward compatibility is retained for >= 2.0.1, <= 2.1.
eugen-ajechiloae-clearml 2022-02-12 17:32:43 +02:00 committed by GitHub
parent c01e2e1166
commit 5be1c8b2b0
2 changed files with 172 additions and 41 deletions
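
As context for the diff below: the new _Container wrapper picks an SDK at import time and routes every blob operation through a legacy/non-legacy branch. A minimal standalone sketch of that detection pattern (the LEGACY_SDK name is illustrative, not part of this commit):

    # Probe for the v12 SDK first; fall back to the legacy 2.x package layout.
    try:
        from azure.storage.blob import BlobServiceClient  # azure-storage-blob >= 12
        LEGACY_SDK = False
    except ImportError:
        from azure.storage.blob import BlockBlobService  # azure-storage-blob >= 2.0.1, <= 2.1
        LEGACY_SDK = True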

clearml/storage/helper.py

@@ -301,7 +301,7 @@ class StorageHelper(object):
                 )
 
             self._driver = _AzureBlobServiceStorageDriver()
-            self._container = self._driver.get_container(config=self._conf)
+            self._container = self._driver.get_container(config=self._conf, account_url=parsed.netloc)
 
         elif self._scheme == _Boto3Driver.scheme:
             self._conf = copy(self._s3_configurations.get_config_by_uri(url))
@@ -1694,22 +1694,143 @@ class _AzureBlobServiceStorageDriver(_Driver):
     _containers = {}
 
     class _Container(object):
-        def __init__(self, name, config):
-            try:
-                from azure.common import AzureHttpError  # noqa: F401
-                from azure.storage.blob import BlockBlobService
-            except ImportError:
-                raise UsageError(
-                    'Azure blob storage driver not found. '
-                    'Please install driver using: pip install \"azure.storage.blob<=2.1.0\"'
-                )
-
+        def __init__(self, name, config, account_url):
+            self.MAX_SINGLE_PUT_SIZE = 16 * 1024 * 1024
+            self.SOCKET_TIMEOUT = (300, 2000)
             self.name = name
             self.config = config
-            self.blob_service = BlockBlobService(
-                account_name=config.account_name,
-                account_key=config.account_key,
-            )
+            self.account_url = account_url
+            try:
+                from azure.storage.blob import BlobServiceClient  # noqa
+
+                self.__legacy = False
+            except ImportError:
+                try:
+                    from azure.storage.blob import BlockBlobService  # noqa
+                    from azure.common import AzureHttpError  # noqa: F401
+
+                    self.__legacy = True
+                except ImportError:
+                    raise UsageError(
+                        "Azure blob storage driver not found. "
+                        "Please install driver using: 'pip install clearml[azure]' or "
+                        "'pip install \"azure.storage.blob>=12.0.0\"'"
+                    )
+            if self.__legacy:
+                self.__blob_service = BlockBlobService(
+                    account_name=self.config.account_name,
+                    account_key=self.config.account_key,
+                )
+                self.__blob_service.MAX_SINGLE_PUT_SIZE = self.MAX_SINGLE_PUT_SIZE
+                self.__blob_service.socket_timeout = self.SOCKET_TIMEOUT
+            else:
+                credential = {"account_name": self.config.account_name, "account_key": self.config.account_key}
+                self.__blob_service = BlobServiceClient(
+                    account_url=account_url,
+                    credential=credential,
+                    max_single_put_size=self.MAX_SINGLE_PUT_SIZE,
+                )
+
+        def create_blob_from_data(
+            self, container_name, object_name, blob_name, data, max_connections=2, progress_callback=None
+        ):
+            if self.__legacy:
+                self.__blob_service.create_blob_from_bytes(
+                    container_name,
+                    object_name,
+                    data,
+                    max_connections=max_connections,
+                    progress_callback=progress_callback,
+                )
+            else:
+                client = self.__blob_service.get_blob_client(container_name, blob_name)
+                client.upload_blob(data, overwrite=True, max_concurrency=max_connections)
+
+        def create_blob_from_path(
+            self, container_name, blob_name, path, max_connections=2, content_settings=None, progress_callback=None
+        ):
+            if self.__legacy:
+                self.__blob_service.create_blob_from_path(
+                    container_name,
+                    blob_name,
+                    path,
+                    max_connections=max_connections,
+                    content_settings=content_settings,
+                    progress_callback=progress_callback,
+                )
+            else:
+                client = self.__blob_service.get_blob_client(container_name, blob_name)
+                with open(path, "rb") as file:
+                    first_chunk = True
+                    for chunk in iter((lambda: file.read(self.MAX_SINGLE_PUT_SIZE)), b""):
+                        if first_chunk:
+                            client.upload_blob(chunk, overwrite=True, max_concurrency=max_connections)
+                            first_chunk = False
+                        else:
+                            from azure.storage.blob import BlobType  # noqa
+
+                            client.upload_blob(chunk, BlobType.AppendBlob)
+
+        def delete_blob(self, container_name, blob_name):
+            if self.__legacy:
+                self.__blob_service.delete_blob(
+                    container_name,
+                    blob_name,
+                )
+            else:
+                client = self.__blob_service.get_blob_client(container_name, blob_name)
+                client.delete_blob()
+
+        def exists(self, container_name, blob_name):
+            if self.__legacy:
+                return self.__blob_service.exists(container_name, blob_name)
+            else:
+                client = self.__blob_service.get_blob_client(container_name, blob_name)
+                return client.exists()
+
+        def list_blobs(self, container_name, prefix=None):
+            if self.__legacy:
+                return self.__blob_service.list_blobs(container_name=container_name, prefix=prefix)
+            else:
+                client = self.__blob_service.get_container_client(container_name)
+                return client.list_blobs(name_starts_with=prefix)
+
+        def get_blob_properties(self, container_name, blob_name):
+            if self.__legacy:
+                return self.__blob_service.get_blob_properties(container_name, blob_name)
+            else:
+                client = self.__blob_service.get_blob_client(container_name, blob_name)
+                return client.get_blob_properties()
+
+        def get_blob_to_bytes(self, container_name, blob_name, progress_callback=None):
+            if self.__legacy:
+                return self.__blob_service.get_blob_to_bytes(
+                    container_name,
+                    blob_name,
+                    progress_callback=progress_callback,
+                )
+            else:
+                client = self.__blob_service.get_blob_client(container_name, blob_name)
+                return client.download_blob().content_as_bytes()
+
+        def get_blob_to_path(self, container_name, blob_name, path, max_connections=10, progress_callback=None):
+            if self.__legacy:
+                return self.__blob_service.get_blob_to_path(
+                    container_name,
+                    blob_name,
+                    path,
+                    max_connections=max_connections,
+                    progress_callback=progress_callback,
+                )
+            else:
+                client = self.__blob_service.get_blob_client(container_name, blob_name)
+                with open(path, "wb") as file:
+                    return client.download_blob(max_concurrency=max_connections).download_to_stream(file)
+
+        def is_legacy(self):
+            return self.__legacy
+
         @attrs
         class _Object(object):
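
For reference, the non-legacy branches above build on the v12 client object model (service client -> blob client). A minimal usage sketch with placeholder account values, not taken from this commit:

    from azure.storage.blob import BlobServiceClient

    # Shared-key credential passed as a dict, mirroring _Container.__init__ above.
    service = BlobServiceClient(
        account_url="https://myaccount.blob.core.windows.net",
        credential={"account_name": "myaccount", "account_key": "<account-key>"},
    )
    blob = service.get_blob_client("mycontainer", "path/to/blob.bin")
    blob.upload_blob(b"hello", overwrite=True)          # upload bytes
    data = blob.download_blob().content_as_bytes()      # download bytes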
@@ -1717,28 +1838,33 @@ class _AzureBlobServiceStorageDriver(_Driver):
         blob_name = attrib()
         content_length = attrib()
 
-    def get_container(self, container_name=None, config=None, **kwargs):
+    def get_container(self, container_name=None, config=None, account_url=None, **kwargs):
         container_name = container_name or config.container_name
         if container_name not in self._containers:
-            self._containers[container_name] = self._Container(name=container_name, config=config)
+            self._containers[container_name] = self._Container(
+                name=container_name, config=config, account_url=account_url
+            )
         # self._containers[container_name].config.retries = kwargs.get('retries', 5)
         return self._containers[container_name]
     def upload_object_via_stream(self, iterator, container, object_name, callback=None, extra=None, **kwargs):
-        from azure.common import AzureHttpError  # noqa
+        try:
+            from azure.common import AzureHttpError  # noqa
+        except ImportError:
+            from azure.core.exceptions import HttpResponseError  # noqa
+
+            AzureHttpError = HttpResponseError  # noqa
 
         blob_name = self._blob_name_from_object_path(object_name, container.name)  # noqa: F841
         try:
-            container.blob_service.MAX_SINGLE_PUT_SIZE = 16 * 1024 * 1024
-            container.blob_service.socket_timeout = (300, 2000)
-            container.blob_service.create_blob_from_bytes(
+            container.create_blob_from_data(
                 container.name,
                 object_name,
+                blob_name,
                 iterator.read() if hasattr(iterator, "read") else bytes(iterator),
                 # timeout=300,
                 max_connections=2,
                 progress_callback=callback,
             )
             return True
         except AzureHttpError as ex:
             self.get_logger().error('Failed uploading (Azure error): %s' % ex)
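
The import shim above lets a single `except AzureHttpError` clause cover both SDK generations. As a standalone pattern (a sketch equivalent to the aliasing in the diff):

    try:
        from azure.common import AzureHttpError  # raised by the legacy 2.x SDK
    except ImportError:
        # The v12 SDK raises HttpResponseError instead; alias it so existing
        # `except AzureHttpError` handlers keep working unchanged.
        from azure.core.exceptions import HttpResponseError as AzureHttpError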
@@ -1747,20 +1873,23 @@ class _AzureBlobServiceStorageDriver(_Driver):
             return False
 
     def upload_object(self, file_path, container, object_name, callback=None, extra=None, **kwargs):
-        from azure.common import AzureHttpError  # noqa
+        try:
+            from azure.common import AzureHttpError  # noqa
+        except ImportError:
+            from azure.core.exceptions import HttpResponseError  # noqa
+
+            AzureHttpError = HttpResponseError  # noqa
 
         blob_name = self._blob_name_from_object_path(object_name, container.name)
         stream = None
         try:
             from azure.storage.blob import ContentSettings  # noqa
             from mimetypes import guess_type
-            container.blob_service.MAX_SINGLE_PUT_SIZE = 16 * 1024 * 1024
-            container.blob_service.socket_timeout = (300, 2000)
-            container.blob_service.create_blob_from_path(
+
+            container.create_blob_from_path(
                 container.name,
                 blob_name,
                 file_path,
                 # timeout=300,
                 max_connections=2,
                 content_settings=ContentSettings(content_type=guess_type(file_path)),
                 progress_callback=callback,
@@ -1775,15 +1904,15 @@ class _AzureBlobServiceStorageDriver(_Driver):
                 stream.close()
 
     def list_container_objects(self, container, ex_prefix=None, **kwargs):
-        return list(container.blob_service.list_blobs(container_name=container.name, prefix=ex_prefix))
+        return list(container.list_blobs(container_name=container.name, prefix=ex_prefix))
 
     def delete_object(self, object, **kwargs):
         container = object.container
-        container.blob_service.delete_blob(
+        container.delete_blob(
             container.name,
             object.blob_name,
         )
-        return not object.container.blob_service.exists(container.name, object.blob_name)
+        return not object.container.exists(container.name, object.blob_name)
 
     def get_object(self, container_name, object_name, *args, **kwargs):
         container = self._containers.get(container_name)
@@ -1791,21 +1920,21 @@ class _AzureBlobServiceStorageDriver(_Driver):
             raise StorageError("Container `{}` not found for object {}".format(container_name, object_name))
 
         # blob_name = self._blob_name_from_object_path(object_name, container_name)
-        blob = container.blob_service.get_blob_properties(container.name, object_name)
+        blob = container.get_blob_properties(container.name, object_name)
 
-        return self._Object(container=container, blob_name=blob.name, content_length=blob.properties.content_length)
+        if container.is_legacy():
+            return self._Object(container=container, blob_name=blob.name, content_length=blob.properties.content_length)
+        else:
+            return self._Object(container=container, blob_name=blob.name, content_length=blob.size)
 
     def download_object_as_stream(self, obj, verbose, *_, **__):
         container = obj.container
         total_size_mb = obj.content_length / (1024. * 1024.)
         remote_path = os.path.join(
-            "{}://".format(self.scheme),
-            container.config.account_name,
-            container.name,
-            obj.blob_name
+            "{}://".format(self.scheme), container.config.account_name, container.name, obj.blob_name
         )
         cb = DownloadProgressReport(total_size_mb, verbose, remote_path, self.get_logger())
-        blob = container.blob_service.get_blob_to_bytes(
+        blob = container.get_blob_to_bytes(
             container.name,
             obj.blob_name,
             progress_callback=cb,
@@ -1815,12 +1944,13 @@ class _AzureBlobServiceStorageDriver(_Driver):
     def download_object(self, obj, local_path, overwrite_existing=True, delete_on_failure=True, callback=None, **_):
         p = Path(local_path)
         if not overwrite_existing and p.is_file():
-            self.get_logger().warning('failed saving after download: overwrite=False and file exists (%s)' % str(p))
+            self.get_logger().warning("failed saving after download: overwrite=False and file exists (%s)" % str(p))
             return
 
         download_done = threading.Event()
         download_done.counter = 0
+
         def callback_func(current, total):
             if callback:
                 chunk = current - download_done.counter
@@ -1829,9 +1959,10 @@ class _AzureBlobServiceStorageDriver(_Driver):
             if current >= total:
                 download_done.set()
+
         container = obj.container
-        container.blob_service.MAX_SINGLE_GET_SIZE = 5 * 1024 * 1024
-        _ = container.blob_service.get_blob_to_path(
+        _ = container.get_blob_to_path(
             container.name,
             obj.blob_name,
             local_path,
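
The `callback_func` in this hunk adapts Azure's cumulative progress reports to the per-chunk callbacks ClearML expects; a self-contained sketch of that adapter (`report_chunk` is a hypothetical per-chunk callback, not from this commit):

    import threading

    def make_progress_adapter(report_chunk):
        # Azure reports cumulative bytes transferred; convert to per-chunk deltas.
        done = threading.Event()
        done.counter = 0

        def callback_func(current, total):
            chunk = current - done.counter
            done.counter = current
            report_chunk(chunk)
            if current >= total:
                done.set()

        return callback_func, done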

setup.py

@@ -71,7 +71,7 @@ setup(
             'boto3>=1.9',
         ],
         'azure': [
-            'azure-storage-blob>=2.0.1,<=2.1',
+            'azure-storage-blob>=12.0.0',
         ],
         'gs': [
             'google-cloud-storage>=1.13.2',