Fix uploads and permission checks failing when casting None to int

This commit is contained in:
allegroai 2022-12-29 13:18:39 +02:00
parent a3d44aa81f
commit c75c83c21d

View File

@ -602,7 +602,7 @@ class StorageHelper(object):
return self._get_object_size_bytes(obj) return self._get_object_size_bytes(obj)
def _get_object_size_bytes(self, obj): def _get_object_size_bytes(self, obj):
# type: (object, bool) -> [int, None] # type: (object) -> [int, None]
""" """
Auxiliary function for `get_object_size_bytes`. Auxiliary function for `get_object_size_bytes`.
Get size of the remote object in bytes. Get size of the remote object in bytes.
@ -1606,6 +1606,7 @@ class _Boto3Driver(_Driver):
def upload_object_via_stream(self, iterator, container, object_name, callback=None, extra=None, **kwargs): def upload_object_via_stream(self, iterator, container, object_name, callback=None, extra=None, **kwargs):
import boto3.s3.transfer import boto3.s3.transfer
stream = _Stream(iterator) stream = _Stream(iterator)
extra_args = {}
try: try:
extra_args = { extra_args = {
'ContentType': get_file_mimetype(object_name) 'ContentType': get_file_mimetype(object_name)
@ -1630,7 +1631,7 @@ class _Boto3Driver(_Driver):
num_download_attempts=container.config.retries, num_download_attempts=container.config.retries,
), ),
Callback=callback, Callback=callback,
ExtraArgs=extra_args, ExtraArgs=extra_args
) )
except Exception as ex: except Exception as ex:
self.get_logger().error("Failed uploading: %s" % ex) self.get_logger().error("Failed uploading: %s" % ex)
@ -1642,6 +1643,7 @@ class _Boto3Driver(_Driver):
def upload_object(self, file_path, container, object_name, callback=None, extra=None, **kwargs): def upload_object(self, file_path, container, object_name, callback=None, extra=None, **kwargs):
import boto3.s3.transfer import boto3.s3.transfer
extra_args = {}
try: try:
extra_args = { extra_args = {
'ContentType': get_file_mimetype(object_name or file_path) 'ContentType': get_file_mimetype(object_name or file_path)
@ -1665,7 +1667,7 @@ class _Boto3Driver(_Driver):
use_threads=False, num_download_attempts=container.config.retries use_threads=False, num_download_attempts=container.config.retries
), ),
Callback=callback, Callback=callback,
ExtraArgs=extra_args, ExtraArgs=extra_args
) )
except Exception as ex: except Exception as ex:
self.get_logger().error("Failed uploading: %s" % ex) self.get_logger().error("Failed uploading: %s" % ex)
@ -2006,7 +2008,7 @@ class _AzureBlobServiceStorageDriver(_Driver):
scheme = "azure" scheme = "azure"
_containers = {} _containers = {}
_max_connections = deferred_config("azure.storage.max_connections", None) _max_connections = deferred_config("azure.storage.max_connections", 0)
class _Container(object): class _Container(object):
def __init__(self, name, config, account_url): def __init__(self, name, config, account_url):
@ -2047,45 +2049,52 @@ class _AzureBlobServiceStorageDriver(_Driver):
max_single_put_size=self.MAX_SINGLE_PUT_SIZE, max_single_put_size=self.MAX_SINGLE_PUT_SIZE,
) )
@staticmethod
def _get_max_connections_dict(max_connections=None, key="max_connections"):
# must cast for deferred resolving
try:
max_connections = max_connections or int(_AzureBlobServiceStorageDriver._max_connections)
except (AttributeError, TypeError):
return {}
return {key: int(max_connections)} if max_connections else {}
def create_blob_from_data( def create_blob_from_data(
self, container_name, object_name, blob_name, data, max_connections=None, self, container_name, object_name, blob_name, data, max_connections=None,
progress_callback=None, content_settings=None progress_callback=None, content_settings=None
): ):
max_connections = max_connections or int(_AzureBlobServiceStorageDriver._max_connections)
if self.__legacy: if self.__legacy:
self.__blob_service.create_blob_from_bytes( self.__blob_service.create_blob_from_bytes(
container_name, container_name,
object_name, object_name,
data, data,
max_connections=max_connections,
progress_callback=progress_callback, progress_callback=progress_callback,
**self._get_max_connections_dict(max_connections)
) )
else: else:
client = self.__blob_service.get_blob_client(container_name, blob_name) client = self.__blob_service.get_blob_client(container_name, blob_name)
client.upload_blob( client.upload_blob(
data, overwrite=True, data, overwrite=True,
max_concurrency=max_connections,
content_settings=content_settings, content_settings=content_settings,
**self._get_max_connections_dict(max_connections)
) )
def create_blob_from_path( def create_blob_from_path(
self, container_name, blob_name, path, max_connections=None, content_settings=None, progress_callback=None self, container_name, blob_name, path, max_connections=None, content_settings=None, progress_callback=None
): ):
max_connections = max_connections or int(_AzureBlobServiceStorageDriver._max_connections)
if self.__legacy: if self.__legacy:
self.__blob_service.create_blob_from_path( self.__blob_service.create_blob_from_path(
container_name, container_name,
blob_name, blob_name,
path, path,
max_connections=max_connections,
content_settings=content_settings, content_settings=content_settings,
progress_callback=progress_callback, progress_callback=progress_callback,
**self._get_max_connections_dict(max_connections)
) )
else: else:
self.create_blob_from_data( self.create_blob_from_data(
container_name, None, blob_name, open(path, "rb"), container_name, None, blob_name, open(path, "rb"),
max_connections=max_connections,
content_settings=content_settings, content_settings=content_settings,
**self._get_max_connections_dict(max_connections)
) )
def delete_blob(self, container_name, blob_name): def delete_blob(self, container_name, blob_name):
@ -2131,19 +2140,20 @@ class _AzureBlobServiceStorageDriver(_Driver):
return client.download_blob().content_as_bytes() return client.download_blob().content_as_bytes()
def get_blob_to_path(self, container_name, blob_name, path, max_connections=None, progress_callback=None): def get_blob_to_path(self, container_name, blob_name, path, max_connections=None, progress_callback=None):
max_connections = max_connections or int(_AzureBlobServiceStorageDriver._max_connections)
if self.__legacy: if self.__legacy:
return self.__blob_service.get_blob_to_path( return self.__blob_service.get_blob_to_path(
container_name, container_name,
blob_name, blob_name,
path, path,
max_connections=max_connections,
progress_callback=progress_callback, progress_callback=progress_callback,
**self._get_max_connections_dict(max_connections)
) )
else: else:
client = self.__blob_service.get_blob_client(container_name, blob_name) client = self.__blob_service.get_blob_client(container_name, blob_name)
with open(path, "wb") as file: with open(path, "wb") as file:
return client.download_blob(max_concurrency=max_connections).download_to_stream(file) return client.download_blob(
**self._get_max_connections_dict(max_connections, "max_concurrency")
).download_to_stream(file)
def is_legacy(self): def is_legacy(self):
return self.__legacy return self.__legacy
@ -2193,7 +2203,9 @@ class _AzureBlobServiceStorageDriver(_Driver):
self.get_logger().error("Failed uploading: %s" % ex) self.get_logger().error("Failed uploading: %s" % ex)
return False return False
def upload_object(self, file_path, container, object_name, callback=None, extra=None, max_connections=None, **kwargs): def upload_object(
self, file_path, container, object_name, callback=None, extra=None, max_connections=None, **kwargs
):
try: try:
from azure.common import AzureHttpError # noqa from azure.common import AzureHttpError # noqa
except ImportError: except ImportError:
@ -2202,7 +2214,6 @@ class _AzureBlobServiceStorageDriver(_Driver):
AzureHttpError = HttpResponseError # noqa AzureHttpError = HttpResponseError # noqa
blob_name = self._blob_name_from_object_path(object_name, container.name) blob_name = self._blob_name_from_object_path(object_name, container.name)
stream = None
try: try:
from azure.storage.blob import ContentSettings # noqa from azure.storage.blob import ContentSettings # noqa
@ -2219,9 +2230,6 @@ class _AzureBlobServiceStorageDriver(_Driver):
self.get_logger().error('Failed uploading (Azure error): %s' % ex) self.get_logger().error('Failed uploading (Azure error): %s' % ex)
except Exception as ex: except Exception as ex:
self.get_logger().error('Failed uploading: %s' % ex) self.get_logger().error('Failed uploading: %s' % ex)
finally:
if stream:
stream.close()
def list_container_objects(self, container, ex_prefix=None, **kwargs): def list_container_objects(self, container, ex_prefix=None, **kwargs):
return list(container.list_blobs(container_name=container.name, prefix=ex_prefix)) return list(container.list_blobs(container_name=container.name, prefix=ex_prefix))
@ -2264,7 +2272,9 @@ class _AzureBlobServiceStorageDriver(_Driver):
else: else:
return blob return blob
def download_object(self, obj, local_path, overwrite_existing=True, delete_on_failure=True, callback=None, max_connections=None, **_): def download_object(
self, obj, local_path, overwrite_existing=True, delete_on_failure=True, callback=None, max_connections=None, **_
):
p = Path(local_path) p = Path(local_path)
if not overwrite_existing and p.is_file(): if not overwrite_existing and p.is_file():
self.get_logger().warning("Failed saving after download: overwrite=False and file exists (%s)" % str(p)) self.get_logger().warning("Failed saving after download: overwrite=False and file exists (%s)" % str(p))