Add StorageManager.set_report_upload_chunk_size() and StorageManager.set_report_download_chunk_size() to set chunk size for upload and download

This commit is contained in:
allegroai 2022-12-22 21:57:15 +02:00
parent 7d6eff4858
commit ab73447cbd
3 changed files with 47 additions and 11 deletions

View File

@ -6,6 +6,9 @@ from ..config import config
class ProgressReport(object): class ProgressReport(object):
report_upload_chunk_size_mb = None
report_download_chunk_size_mb = None
def __init__(self, verbose, total_size, log, report_chunk_size_mb): def __init__(self, verbose, total_size, log, report_chunk_size_mb):
self.current_status_mb = 0. self.current_status_mb = 0.
self.last_reported = 0. self.last_reported = 0.
@ -34,9 +37,10 @@ class ProgressReport(object):
class UploadProgressReport(ProgressReport): class UploadProgressReport(ProgressReport):
def __init__(self, filename, verbose, total_size, log, report_chunk_size_mb=0): def __init__(self, filename, verbose, total_size, log, report_chunk_size_mb=None):
if not report_chunk_size_mb: report_chunk_size_mb = report_chunk_size_mb if report_chunk_size_mb is not None \
report_chunk_size_mb = int(config.get('storage.log.report_upload_chunk_size_mb', 0) or 5) else ProgressReport.report_upload_chunk_size_mb or \
int(config.get("storage.log.report_upload_chunk_size_mb", 5))
super(UploadProgressReport, self).__init__(verbose, total_size, log, report_chunk_size_mb) super(UploadProgressReport, self).__init__(verbose, total_size, log, report_chunk_size_mb)
self._filename = filename self._filename = filename
@ -74,10 +78,10 @@ class UploadProgressReport(ProgressReport):
class DownloadProgressReport(ProgressReport): class DownloadProgressReport(ProgressReport):
def __init__(self, total_size, verbose, remote_path, log, report_chunk_size_mb=0): def __init__(self, total_size, verbose, remote_path, log, report_chunk_size_mb=None):
if not report_chunk_size_mb: report_chunk_size_mb = report_chunk_size_mb if report_chunk_size_mb is not None \
report_chunk_size_mb = int(config.get('storage.log.report_download_chunk_size_mb', 0) or 5) else ProgressReport.report_download_chunk_size_mb or \
int(config.get("storage.log.report_download_chunk_size_mb", 5))
super(DownloadProgressReport, self).__init__(verbose, total_size, log, report_chunk_size_mb) super(DownloadProgressReport, self).__init__(verbose, total_size, log, report_chunk_size_mb)
self._remote_path = remote_path self._remote_path = remote_path

View File

@ -2051,7 +2051,7 @@ class _AzureBlobServiceStorageDriver(_Driver):
self, container_name, object_name, blob_name, data, max_connections=None, self, container_name, object_name, blob_name, data, max_connections=None,
progress_callback=None, content_settings=None progress_callback=None, content_settings=None
): ):
max_connections = max_connections or _AzureBlobServiceStorageDriver._max_connections max_connections = max_connections or int(_AzureBlobServiceStorageDriver._max_connections)
if self.__legacy: if self.__legacy:
self.__blob_service.create_blob_from_bytes( self.__blob_service.create_blob_from_bytes(
container_name, container_name,
@ -2071,7 +2071,7 @@ class _AzureBlobServiceStorageDriver(_Driver):
def create_blob_from_path( def create_blob_from_path(
self, container_name, blob_name, path, max_connections=None, content_settings=None, progress_callback=None self, container_name, blob_name, path, max_connections=None, content_settings=None, progress_callback=None
): ):
max_connections = max_connections or _AzureBlobServiceStorageDriver._max_connections max_connections = max_connections or int(_AzureBlobServiceStorageDriver._max_connections)
if self.__legacy: if self.__legacy:
self.__blob_service.create_blob_from_path( self.__blob_service.create_blob_from_path(
container_name, container_name,
@ -2131,7 +2131,7 @@ class _AzureBlobServiceStorageDriver(_Driver):
return client.download_blob().content_as_bytes() return client.download_blob().content_as_bytes()
def get_blob_to_path(self, container_name, blob_name, path, max_connections=None, progress_callback=None): def get_blob_to_path(self, container_name, blob_name, path, max_connections=None, progress_callback=None):
max_connections = max_connections or _AzureBlobServiceStorageDriver._max_connections max_connections = max_connections or int(_AzureBlobServiceStorageDriver._max_connections)
if self.__legacy: if self.__legacy:
return self.__blob_service.get_blob_to_path( return self.__blob_service.get_blob_to_path(
container_name, container_name,

View File

@ -12,6 +12,7 @@ from six.moves.urllib.parse import urlparse
from pathlib2 import Path from pathlib2 import Path
from .cache import CacheManager from .cache import CacheManager
from .callbacks import ProgressReport
from .helper import StorageHelper from .helper import StorageHelper
from .util import encode_string_to_filename, safe_extract from .util import encode_string_to_filename, safe_extract
from ..debugging.log import LoggerRoot from ..debugging.log import LoggerRoot
@ -463,7 +464,7 @@ class StorageManager(object):
@classmethod @classmethod
def get_metadata(cls, remote_url, return_full_path=False): def get_metadata(cls, remote_url, return_full_path=False):
# type: (str) -> Optional[dict] # type: (str, bool) -> Optional[dict]
""" """
Get the metadata of a remote object. Get the metadata of a remote object.
The metadata is a dict containing the following keys: `name`, `size`. The metadata is a dict containing the following keys: `name`, `size`.
@ -471,6 +472,7 @@ class StorageManager(object):
:param str remote_url: Source remote storage location, tree structure of `remote_url` will :param str remote_url: Source remote storage location, tree structure of `remote_url` will
be created under the target local_folder. Supports S3/GS/Azure, shared filesystem and http(s). be created under the target local_folder. Supports S3/GS/Azure, shared filesystem and http(s).
Example: 's3://bucket/data/' Example: 's3://bucket/data/'
:param return_full_path: True for returning a full path (with the base url)
:return: A dict containing the metadata of the remote object. In case of an error, `None` is returned :return: A dict containing the metadata of the remote object. In case of an error, `None` is returned
""" """
@ -482,3 +484,33 @@ class StorageManager(object):
if return_full_path and not metadata["name"].startswith(helper.base_url): if return_full_path and not metadata["name"].startswith(helper.base_url):
metadata["name"] = helper.base_url + ("/" if not helper.base_url.endswith("/") else "") + metadata["name"] metadata["name"] = helper.base_url + ("/" if not helper.base_url.endswith("/") else "") + metadata["name"]
return metadata return metadata
@classmethod
def set_report_upload_chunk_size(cls, chunk_size_mb):
    # type: (int) -> None
    """
    Set the upload progress report chunk size (in MB).

    The chunk size determines how often progress reports are logged:
    a report is emitted every time more than `chunk_size_mb` of data
    has been uploaded since the previous report.

    This overrides the `sdk.storage.log.report_upload_chunk_size_mb`
    configuration value for the current process (the configuration file
    itself is not modified).

    :param chunk_size_mb: The chunk size, in megabytes
    """
    # Stored on the ProgressReport class attribute so every
    # UploadProgressReport created afterwards picks it up.
    ProgressReport.report_upload_chunk_size_mb = int(chunk_size_mb)
@classmethod
def set_report_download_chunk_size(cls, chunk_size_mb):
    # type: (int) -> None
    """
    Set the download progress report chunk size (in MB).

    The chunk size determines how often progress reports are logged:
    a report is emitted every time more than `chunk_size_mb` of data
    has been downloaded since the previous report.

    This overrides the `sdk.storage.log.report_download_chunk_size_mb`
    configuration value for the current process (the configuration file
    itself is not modified).

    :param chunk_size_mb: The chunk size, in megabytes
    """
    # Stored on the ProgressReport class attribute so every
    # DownloadProgressReport created afterwards picks it up.
    ProgressReport.report_download_chunk_size_mb = int(chunk_size_mb)