Fix casting None to int fails uploads and permission checks

This commit is contained in:
allegroai 2022-12-29 13:18:39 +02:00
parent a3d44aa81f
commit c75c83c21d

View File

@ -602,7 +602,7 @@ class StorageHelper(object):
return self._get_object_size_bytes(obj)
def _get_object_size_bytes(self, obj):
# type: (object, bool) -> [int, None]
# type: (object) -> [int, None]
"""
Auxiliary function for `get_object_size_bytes`.
Get size of the remote object in bytes.
@ -1606,6 +1606,7 @@ class _Boto3Driver(_Driver):
def upload_object_via_stream(self, iterator, container, object_name, callback=None, extra=None, **kwargs):
import boto3.s3.transfer
stream = _Stream(iterator)
extra_args = {}
try:
extra_args = {
'ContentType': get_file_mimetype(object_name)
@ -1630,7 +1631,7 @@ class _Boto3Driver(_Driver):
num_download_attempts=container.config.retries,
),
Callback=callback,
ExtraArgs=extra_args,
ExtraArgs=extra_args
)
except Exception as ex:
self.get_logger().error("Failed uploading: %s" % ex)
@ -1642,6 +1643,7 @@ class _Boto3Driver(_Driver):
def upload_object(self, file_path, container, object_name, callback=None, extra=None, **kwargs):
import boto3.s3.transfer
extra_args = {}
try:
extra_args = {
'ContentType': get_file_mimetype(object_name or file_path)
@ -1665,7 +1667,7 @@ class _Boto3Driver(_Driver):
use_threads=False, num_download_attempts=container.config.retries
),
Callback=callback,
ExtraArgs=extra_args,
ExtraArgs=extra_args
)
except Exception as ex:
self.get_logger().error("Failed uploading: %s" % ex)
@ -2006,7 +2008,7 @@ class _AzureBlobServiceStorageDriver(_Driver):
scheme = "azure"
_containers = {}
_max_connections = deferred_config("azure.storage.max_connections", None)
_max_connections = deferred_config("azure.storage.max_connections", 0)
class _Container(object):
def __init__(self, name, config, account_url):
@ -2047,45 +2049,52 @@ class _AzureBlobServiceStorageDriver(_Driver):
max_single_put_size=self.MAX_SINGLE_PUT_SIZE,
)
@staticmethod
def _get_max_connections_dict(max_connections=None, key="max_connections"):
# must cast for deferred resolving
try:
max_connections = max_connections or int(_AzureBlobServiceStorageDriver._max_connections)
except (AttributeError, TypeError):
return {}
return {key: int(max_connections)} if max_connections else {}
def create_blob_from_data(
self, container_name, object_name, blob_name, data, max_connections=None,
progress_callback=None, content_settings=None
):
max_connections = max_connections or int(_AzureBlobServiceStorageDriver._max_connections)
if self.__legacy:
self.__blob_service.create_blob_from_bytes(
container_name,
object_name,
data,
max_connections=max_connections,
progress_callback=progress_callback,
**self._get_max_connections_dict(max_connections)
)
else:
client = self.__blob_service.get_blob_client(container_name, blob_name)
client.upload_blob(
data, overwrite=True,
max_concurrency=max_connections,
content_settings=content_settings,
**self._get_max_connections_dict(max_connections)
)
def create_blob_from_path(
self, container_name, blob_name, path, max_connections=None, content_settings=None, progress_callback=None
):
max_connections = max_connections or int(_AzureBlobServiceStorageDriver._max_connections)
if self.__legacy:
self.__blob_service.create_blob_from_path(
container_name,
blob_name,
path,
max_connections=max_connections,
content_settings=content_settings,
progress_callback=progress_callback,
**self._get_max_connections_dict(max_connections)
)
else:
self.create_blob_from_data(
container_name, None, blob_name, open(path, "rb"),
max_connections=max_connections,
content_settings=content_settings,
**self._get_max_connections_dict(max_connections)
)
def delete_blob(self, container_name, blob_name):
@ -2131,19 +2140,20 @@ class _AzureBlobServiceStorageDriver(_Driver):
return client.download_blob().content_as_bytes()
def get_blob_to_path(self, container_name, blob_name, path, max_connections=None, progress_callback=None):
max_connections = max_connections or int(_AzureBlobServiceStorageDriver._max_connections)
if self.__legacy:
return self.__blob_service.get_blob_to_path(
container_name,
blob_name,
path,
max_connections=max_connections,
progress_callback=progress_callback,
**self._get_max_connections_dict(max_connections)
)
else:
client = self.__blob_service.get_blob_client(container_name, blob_name)
with open(path, "wb") as file:
return client.download_blob(max_concurrency=max_connections).download_to_stream(file)
return client.download_blob(
**self._get_max_connections_dict(max_connections, "max_concurrency")
).download_to_stream(file)
def is_legacy(self):
return self.__legacy
@ -2193,7 +2203,9 @@ class _AzureBlobServiceStorageDriver(_Driver):
self.get_logger().error("Failed uploading: %s" % ex)
return False
def upload_object(self, file_path, container, object_name, callback=None, extra=None, max_connections=None, **kwargs):
def upload_object(
self, file_path, container, object_name, callback=None, extra=None, max_connections=None, **kwargs
):
try:
from azure.common import AzureHttpError # noqa
except ImportError:
@ -2202,7 +2214,6 @@ class _AzureBlobServiceStorageDriver(_Driver):
AzureHttpError = HttpResponseError # noqa
blob_name = self._blob_name_from_object_path(object_name, container.name)
stream = None
try:
from azure.storage.blob import ContentSettings # noqa
@ -2219,9 +2230,6 @@ class _AzureBlobServiceStorageDriver(_Driver):
self.get_logger().error('Failed uploading (Azure error): %s' % ex)
except Exception as ex:
self.get_logger().error('Failed uploading: %s' % ex)
finally:
if stream:
stream.close()
def list_container_objects(self, container, ex_prefix=None, **kwargs):
return list(container.list_blobs(container_name=container.name, prefix=ex_prefix))
@ -2264,7 +2272,9 @@ class _AzureBlobServiceStorageDriver(_Driver):
else:
return blob
def download_object(self, obj, local_path, overwrite_existing=True, delete_on_failure=True, callback=None, max_connections=None, **_):
def download_object(
self, obj, local_path, overwrite_existing=True, delete_on_failure=True, callback=None, max_connections=None, **_
):
p = Path(local_path)
if not overwrite_existing and p.is_file():
self.get_logger().warning("Failed saving after download: overwrite=False and file exists (%s)" % str(p))