Allow controlling the number of threads used by StorageManager.download_folder() using the max_workers argument

This commit is contained in:
allegroai 2024-03-17 15:01:05 +02:00
parent 57614e68a5
commit 3531071fd5

View File

@ -363,8 +363,9 @@ class StorageManager(object):
overwrite=False, overwrite=False,
skip_zero_size_check=False, skip_zero_size_check=False,
silence_errors=False, silence_errors=False,
max_workers=None
): ):
# type: (str, Optional[str], Optional[str], bool, bool, bool) -> Optional[str] # type: (str, Optional[str], Optional[str], bool, bool, bool, Optional[int]) -> Optional[str]
""" """
Download remote folder recursively to the local machine, maintaining the sub folder structure Download remote folder recursively to the local machine, maintaining the sub folder structure
from the remote storage. from the remote storage.
@ -387,6 +388,11 @@ class StorageManager(object):
:param bool skip_zero_size_check: If True, no error will be raised for files with zero bytes size. :param bool skip_zero_size_check: If True, no error will be raised for files with zero bytes size.
:param bool silence_errors: If True, silence errors that might pop up when trying to download :param bool silence_errors: If True, silence errors that might pop up when trying to download
files stored remotely. Default False files stored remotely. Default False
:param int max_workers: If value is set to a number,
it will spawn the specified number of worker threads
to download the contents of the folder in parallel. Otherwise, if set to None, it will
internally use as many threads as there are
logical CPU cores in the system (this is default Python behavior). Default None
:return: Target local folder :return: Target local folder
""" """
@ -405,7 +411,7 @@ class StorageManager(object):
helper = StorageHelper.get(remote_url) helper = StorageHelper.get(remote_url)
results = [] results = []
with ThreadPool() as pool: with ThreadPool(processes=max_workers) as pool:
for path in helper.list(prefix=remote_url): for path in helper.list(prefix=remote_url):
remote_path = ( remote_path = (
str(Path(helper.base_url) / Path(path)) str(Path(helper.base_url) / Path(path))