Mirror of https://github.com/clearml/clearml (synced 2025-03-03 02:32:11 +00:00).
Commit: Support newer azure storage python version (#548) —
adds support for azure.storage.blob >= 12.0.0, while remaining backwards compatible with azure.storage.blob >= 2.0.1, <= 2.1.
This commit is contained in:
This commit is contained in:
parent
c01e2e1166
commit
5be1c8b2b0
@ -301,7 +301,7 @@ class StorageHelper(object):
|
||||
)
|
||||
|
||||
self._driver = _AzureBlobServiceStorageDriver()
|
||||
self._container = self._driver.get_container(config=self._conf)
|
||||
self._container = self._driver.get_container(config=self._conf, account_url=parsed.netloc)
|
||||
|
||||
elif self._scheme == _Boto3Driver.scheme:
|
||||
self._conf = copy(self._s3_configurations.get_config_by_uri(url))
|
||||
@ -1694,22 +1694,143 @@ class _AzureBlobServiceStorageDriver(_Driver):
|
||||
_containers = {}
|
||||
|
||||
class _Container(object):
    """Wrapper around a single Azure Blob Storage container.

    Supports both the legacy SDK (``azure.storage.blob`` >= 2.0.1, <= 2.1,
    ``BlockBlobService``) and the modern SDK (>= 12.0.0,
    ``BlobServiceClient``). The SDK flavor is detected once at construction
    time by probing which package is importable, and every public method
    dispatches on that flag.
    """

    def __init__(self, name, config, account_url):
        """
        :param name: container name
        :param config: azure storage configuration object providing
            ``account_name`` and ``account_key``
        :param account_url: account endpoint URL (used by the modern SDK only)
        :raises UsageError: if neither SDK flavor is installed
        """
        # Single-shot upload threshold and (connect, read) socket timeouts.
        self.MAX_SINGLE_PUT_SIZE = 16 * 1024 * 1024
        self.SOCKET_TIMEOUT = (300, 2000)
        self.name = name
        self.config = config
        self.account_url = account_url
        try:
            # Modern SDK (>= 12.0.0)
            from azure.storage.blob import BlobServiceClient  # noqa

            self.__legacy = False
        except ImportError:
            try:
                # Legacy SDK (>= 2.0.1, <= 2.1)
                from azure.storage.blob import BlockBlobService  # noqa
                from azure.common import AzureHttpError  # noqa: F401

                self.__legacy = True
            except ImportError:
                raise UsageError(
                    "Azure blob storage driver not found. "
                    "Please install driver using: 'pip install clearml[azure]' or "
                    "pip install '\"azure.storage.blob>=12.0.0\"'"
                )

        if self.__legacy:
            self.__blob_service = BlockBlobService(
                account_name=self.config.account_name,
                account_key=self.config.account_key,
            )
            self.__blob_service.MAX_SINGLE_PUT_SIZE = self.MAX_SINGLE_PUT_SIZE
            self.__blob_service.socket_timeout = self.SOCKET_TIMEOUT
        else:
            credential = {"account_name": self.config.account_name, "account_key": self.config.account_key}
            self.__blob_service = BlobServiceClient(
                account_url=account_url,
                credential=credential,
                max_single_put_size=self.MAX_SINGLE_PUT_SIZE,
            )

    def create_blob_from_data(
        self, container_name, object_name, blob_name, data, max_connections=2, progress_callback=None
    ):
        """Upload an in-memory bytes payload as a blob (overwriting)."""
        if self.__legacy:
            # NOTE(review): the legacy call uses object_name (not blob_name)
            # as the blob identifier — preserved from the original code.
            self.__blob_service.create_blob_from_bytes(
                container_name,
                object_name,
                data,
                max_connections=max_connections,
                progress_callback=progress_callback,
            )
        else:
            client = self.__blob_service.get_blob_client(container_name, blob_name)
            client.upload_blob(data, overwrite=True, max_concurrency=max_connections)

    def create_blob_from_path(
        self, container_name, blob_name, path, max_connections=2, content_settings=None, progress_callback=None
    ):
        """Upload a local file as a blob (overwriting)."""
        if self.__legacy:
            self.__blob_service.create_blob_from_path(
                container_name,
                blob_name,
                path,
                max_connections=max_connections,
                content_settings=content_settings,
                progress_callback=progress_callback,
            )
        else:
            client = self.__blob_service.get_blob_client(container_name, blob_name)
            # Stream the open file object and let the SDK chunk it internally.
            # The previous manual per-chunk loop appended subsequent chunks
            # with BlockType.AppendBlob onto a blob created as a block blob,
            # which the service rejects; it also dropped content_settings.
            with open(path, "rb") as file:
                client.upload_blob(
                    file,
                    overwrite=True,
                    max_concurrency=max_connections,
                    content_settings=content_settings,
                )

    def delete_blob(self, container_name, blob_name):
        """Delete a single blob."""
        if self.__legacy:
            self.__blob_service.delete_blob(
                container_name,
                blob_name,
            )
        else:
            client = self.__blob_service.get_blob_client(container_name, blob_name)
            client.delete_blob()

    def exists(self, container_name, blob_name):
        """Return True if the blob exists."""
        if self.__legacy:
            # Fixed: the original returned `not exists(...)`, inverting the
            # result (and breaking delete-verification in the driver).
            return self.__blob_service.exists(container_name, blob_name)
        else:
            client = self.__blob_service.get_blob_client(container_name, blob_name)
            return client.exists()

    def list_blobs(self, container_name, prefix=None):
        """List blobs in the container, optionally filtered by name prefix."""
        if self.__legacy:
            return self.__blob_service.list_blobs(container_name=container_name, prefix=prefix)
        else:
            client = self.__blob_service.get_container_client(container_name)
            return client.list_blobs(name_starts_with=prefix)

    def get_blob_properties(self, container_name, blob_name):
        """Return the blob's properties object (SDK-flavor specific shape)."""
        if self.__legacy:
            return self.__blob_service.get_blob_properties(container_name, blob_name)
        else:
            client = self.__blob_service.get_blob_client(container_name, blob_name)
            return client.get_blob_properties()

    def get_blob_to_bytes(self, container_name, blob_name, progress_callback=None):
        """Download the blob content into memory."""
        if self.__legacy:
            return self.__blob_service.get_blob_to_bytes(
                container_name,
                blob_name,
                progress_callback=progress_callback,
            )
        else:
            client = self.__blob_service.get_blob_client(container_name, blob_name)
            return client.download_blob().content_as_bytes()

    def get_blob_to_path(self, container_name, blob_name, path, max_connections=10, progress_callback=None):
        """Download the blob content into a local file at *path*."""
        if self.__legacy:
            return self.__blob_service.get_blob_to_path(
                container_name,
                blob_name,
                path,
                max_connections=max_connections,
                progress_callback=progress_callback,
            )
        else:
            # Fixed: max_concurrency is a download_blob() parameter; it was
            # previously (and incorrectly) passed to get_blob_client().
            client = self.__blob_service.get_blob_client(container_name, blob_name)
            with open(path, "wb") as file:
                return client.download_blob(max_concurrency=max_connections).download_to_stream(file)

    def is_legacy(self):
        """Return True when running against the legacy (< 12.0.0) SDK."""
        return self.__legacy
|
||||
|
||||
|
||||
@attrs
|
||||
class _Object(object):
|
||||
@ -1717,28 +1838,33 @@ class _AzureBlobServiceStorageDriver(_Driver):
|
||||
blob_name = attrib()
|
||||
content_length = attrib()
|
||||
|
||||
def get_container(self, container_name=None, config=None, account_url=None, **kwargs):
    """Return (and cache) the ``_Container`` wrapper for a container.

    Falls back to ``config.container_name`` when *container_name* is not
    given. The cache is keyed by name only: a later call with a different
    config or account_url for the same name returns the first instance.

    :param container_name: explicit container name, or None to use the config
    :param config: azure storage configuration object
    :param account_url: account endpoint URL, forwarded to ``_Container``
    """
    container_name = container_name or config.container_name
    if container_name not in self._containers:
        self._containers[container_name] = self._Container(
            name=container_name, config=config, account_url=account_url
        )
    # self._containers[container_name].config.retries = kwargs.get('retries', 5)
    return self._containers[container_name]
|
||||
|
||||
def upload_object_via_stream(self, iterator, container, object_name, callback=None, extra=None, **kwargs):
|
||||
from azure.common import AzureHttpError # noqa
|
||||
try:
|
||||
from azure.common import AzureHttpError # noqa
|
||||
except ImportError:
|
||||
from azure.core.exceptions import HttpResponseError # noqa
|
||||
|
||||
AzureHttpError = HttpResponseError # noqa
|
||||
|
||||
blob_name = self._blob_name_from_object_path(object_name, container.name) # noqa: F841
|
||||
try:
|
||||
container.blob_service.MAX_SINGLE_PUT_SIZE = 16 * 1024 * 1024
|
||||
container.blob_service.socket_timeout = (300, 2000)
|
||||
container.blob_service.create_blob_from_bytes(
|
||||
container.create_blob_from_data(
|
||||
container.name,
|
||||
object_name,
|
||||
blob_name,
|
||||
iterator.read() if hasattr(iterator, "read") else bytes(iterator),
|
||||
# timeout=300,
|
||||
max_connections=2,
|
||||
progress_callback=callback,
|
||||
)
|
||||
)
|
||||
return True
|
||||
except AzureHttpError as ex:
|
||||
self.get_logger().error('Failed uploading (Azure error): %s' % ex)
|
||||
@ -1747,20 +1873,23 @@ class _AzureBlobServiceStorageDriver(_Driver):
|
||||
return False
|
||||
|
||||
def upload_object(self, file_path, container, object_name, callback=None, extra=None, **kwargs):
|
||||
from azure.common import AzureHttpError # noqa
|
||||
try:
|
||||
from azure.common import AzureHttpError # noqa
|
||||
except ImportError:
|
||||
from azure.core.exceptions import HttpResponseError # noqa
|
||||
|
||||
AzureHttpError = HttpResponseError # noqa
|
||||
|
||||
blob_name = self._blob_name_from_object_path(object_name, container.name)
|
||||
stream = None
|
||||
try:
|
||||
from azure.storage.blob import ContentSettings # noqa
|
||||
from mimetypes import guess_type
|
||||
container.blob_service.MAX_SINGLE_PUT_SIZE = 16 * 1024 * 1024
|
||||
container.blob_service.socket_timeout = (300, 2000)
|
||||
container.blob_service.create_blob_from_path(
|
||||
|
||||
container.create_blob_from_path(
|
||||
container.name,
|
||||
blob_name,
|
||||
file_path,
|
||||
# timeout=300,
|
||||
max_connections=2,
|
||||
content_settings=ContentSettings(content_type=guess_type(file_path)),
|
||||
progress_callback=callback,
|
||||
@ -1775,15 +1904,15 @@ class _AzureBlobServiceStorageDriver(_Driver):
|
||||
stream.close()
|
||||
|
||||
def list_container_objects(self, container, ex_prefix=None, **kwargs):
    """Return the blobs in *container* as a list.

    Delegates to the container wrapper's ``list_blobs`` so both legacy and
    modern SDK flavors are handled transparently.

    :param container: a ``_Container`` wrapper instance
    :param ex_prefix: optional blob-name prefix filter
    """
    return list(container.list_blobs(container_name=container.name, prefix=ex_prefix))
|
||||
|
||||
def delete_object(self, object, **kwargs):
    """Delete *object*'s blob and report success.

    Delegates to the container wrapper so both SDK flavors are handled.

    :param object: a ``_Object`` with ``container`` and ``blob_name``
    :return: True if the blob no longer exists after the delete
    """
    container = object.container
    container.delete_blob(
        container.name,
        object.blob_name,
    )
    # Verify by checking the blob is gone.
    return not object.container.exists(container.name, object.blob_name)
|
||||
|
||||
def get_object(self, container_name, object_name, *args, **kwargs):
|
||||
container = self._containers.get(container_name)
|
||||
@ -1791,21 +1920,21 @@ class _AzureBlobServiceStorageDriver(_Driver):
|
||||
raise StorageError("Container `{}` not found for object {}".format(container_name, object_name))
|
||||
|
||||
# blob_name = self._blob_name_from_object_path(object_name, container_name)
|
||||
blob = container.blob_service.get_blob_properties(container.name, object_name)
|
||||
blob = container.get_blob_properties(container.name, object_name)
|
||||
|
||||
return self._Object(container=container, blob_name=blob.name, content_length=blob.properties.content_length)
|
||||
if container.is_legacy():
|
||||
return self._Object(container=container, blob_name=blob.name, content_length=blob.properties.content_length)
|
||||
else:
|
||||
return self._Object(container=container, blob_name=blob.name, content_length=blob.size)
|
||||
|
||||
def download_object_as_stream(self, obj, verbose, *_, **__):
|
||||
container = obj.container
|
||||
total_size_mb = obj.content_length / (1024. * 1024.)
|
||||
remote_path = os.path.join(
|
||||
"{}://".format(self.scheme),
|
||||
container.config.account_name,
|
||||
container.name,
|
||||
obj.blob_name
|
||||
"{}://".format(self.scheme), container.config.account_name, container.name, obj.blob_name
|
||||
)
|
||||
cb = DownloadProgressReport(total_size_mb, verbose, remote_path, self.get_logger())
|
||||
blob = container.blob_service.get_blob_to_bytes(
|
||||
blob = container.get_blob_to_bytes(
|
||||
container.name,
|
||||
obj.blob_name,
|
||||
progress_callback=cb,
|
||||
@ -1815,12 +1944,13 @@ class _AzureBlobServiceStorageDriver(_Driver):
|
||||
def download_object(self, obj, local_path, overwrite_existing=True, delete_on_failure=True, callback=None, **_):
|
||||
p = Path(local_path)
|
||||
if not overwrite_existing and p.is_file():
|
||||
self.get_logger().warning('failed saving after download: overwrite=False and file exists (%s)' % str(p))
|
||||
self.get_logger().warning("failed saving after download: overwrite=False and file exists (%s)" % str(p))
|
||||
return
|
||||
|
||||
download_done = threading.Event()
|
||||
download_done.counter = 0
|
||||
|
||||
|
||||
def callback_func(current, total):
|
||||
if callback:
|
||||
chunk = current - download_done.counter
|
||||
@ -1829,9 +1959,10 @@ class _AzureBlobServiceStorageDriver(_Driver):
|
||||
if current >= total:
|
||||
download_done.set()
|
||||
|
||||
|
||||
container = obj.container
|
||||
container.blob_service.MAX_SINGLE_GET_SIZE = 5 * 1024 * 1024
|
||||
_ = container.blob_service.get_blob_to_path(
|
||||
_ = container.get_blob_to_path(
|
||||
container.name,
|
||||
obj.blob_name,
|
||||
local_path,
|
||||
|
(End of diff excerpt — remaining page content was GitHub UI residue.)