External files: Improve performance (#962)

* External files: Improve performance
Add an option for faster .add_external_files() calls
Add an option for parallel external file download

* Fix bug in add_external_files

* Remove max_workers_external_files parameter
Integrate PR feedback

* Integrate feedback

* Fix PEP suggestion
john-zielke-snkeos 2023-04-24 15:21:08 +02:00 committed by GitHub
parent 71e1608ad8
commit a07f396f17
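Before the diff, a rough usage sketch of what the commit enables. This is an illustration only, not ClearML documentation: the project/dataset names and S3 URLs are placeholders, and the assumption that get_local_copy() forwards a max_workers argument down to _download_external_files() is inferred from the hunk around line 2194 below.

    from clearml import Dataset

    # Placeholder names and URLs -- illustration only.
    ds = Dataset.create(dataset_name="external-demo", dataset_project="examples")

    # dataset_path may now be a sequence with one destination per source URL,
    # and registration of the external links runs on a thread pool.
    ds.add_external_files(
        source_url=["s3://my-bucket/images/train/", "s3://my-bucket/images/val/"],
        dataset_path=["train", "val"],  # one dataset path per source URL
        max_workers=8,
    )
    ds.upload()
    ds.finalize()

    # Assumption: max_workers here reaches _download_external_files(), so
    # external links are fetched in parallel instead of one by one.
    local_copy = Dataset.get(dataset_id=ds.id).get_local_copy(max_workers=8)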


@@ -1,4 +1,5 @@
 import calendar
+import itertools
 import json
 import os
 import shutil
@@ -416,7 +417,7 @@ class Dataset(object):
         self,
         source_url,  # type: Union[str, Sequence[str]]
         wildcard=None,  # type: Optional[Union[str, Sequence[str]]]
-        dataset_path=None,  # type: Optional[str]
+        dataset_path=None,  # type: Optional[Union[str, Sequence[str]]]
         recursive=True,  # type: bool
         verbose=False,  # type: bool
         max_workers=None  # type: Optional[int]
@@ -458,14 +459,21 @@ class Dataset(object):
         source_url_list = source_url if not isinstance(source_url, str) else [source_url]
         max_workers = max_workers or psutil.cpu_count()
         futures_ = []
+        if isinstance(dataset_path, str) or dataset_path is None:
+            dataset_paths = itertools.repeat(dataset_path)
+        else:
+            if len(dataset_path) != len(source_url):
+                raise ValueError("dataset_path must be a string or a list of strings with the same length as source_url"
+                                 f" (received {len(dataset_path)} paths for {len(source_url)} source urls))")
+            dataset_paths = dataset_path
         with ThreadPoolExecutor(max_workers=max_workers) as tp:
-            for source_url_ in source_url_list:
+            for source_url_, dataset_path_ in zip(source_url_list, dataset_paths):
                 futures_.append(
                     tp.submit(
                         self._add_external_files,
                         source_url_,
                         wildcard=wildcard,
-                        dataset_path=dataset_path,
+                        dataset_path=dataset_path_,
                         recursive=recursive,
                         verbose=verbose,
                     )
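The itertools.repeat()/zip() pairing added above is a compact way to accept either a single dataset_path or one path per source URL. Below is a minimal, self-contained sketch of the same pattern; the helper name pair_paths is made up for illustration and is not part of ClearML.

    import itertools

    def pair_paths(source_urls, dataset_path=None):
        """Pair every source URL with a dataset path.

        A scalar (or None) dataset_path is repeated for every URL via
        itertools.repeat(); a sequence must match source_urls in length.
        """
        if isinstance(dataset_path, str) or dataset_path is None:
            dataset_paths = itertools.repeat(dataset_path)
        elif len(dataset_path) != len(source_urls):
            raise ValueError("expected {} dataset paths, got {}".format(len(source_urls), len(dataset_path)))
        else:
            dataset_paths = dataset_path
        # zip() stops at the end of the finite source_urls list, so the
        # infinite repeat() iterator is safe to pair with it.
        return list(zip(source_urls, dataset_paths))

    print(pair_paths(["s3://bucket/a", "s3://bucket/b"], "data"))
    # -> [('s3://bucket/a', 'data'), ('s3://bucket/b', 'data')]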
@@ -2194,12 +2202,12 @@ class Dataset(object):
             max_workers=max_workers
         )
         self._download_external_files(
-            target_folder=target_folder, lock_target_folder=lock_target_folder
+            target_folder=target_folder, lock_target_folder=lock_target_folder, max_workers=max_workers
         )
         return local_folder
 
     def _download_external_files(
-        self, target_folder=None, lock_target_folder=False
+        self, target_folder=None, lock_target_folder=False, max_workers=None
     ):
         # (Union(Path, str), bool) -> None
         """
@@ -2211,6 +2219,7 @@ class Dataset(object):
         :param target_folder: If provided use the specified target folder, default, auto generate from Dataset ID.
         :param lock_target_folder: If True, local the target folder so the next cleanup will not delete
             Notice you should unlock it manually, or wait for the process to finish for auto unlocking.
+        :param max_workers: Number of threads to be spawned when getting dataset files. Defaults to no multi-threading.
         """
         target_folder = (
             Path(target_folder)
@@ -2227,15 +2236,14 @@ class Dataset(object):
             ds = Dataset.get(dependency)
             links.update(ds._dataset_link_entries)
         links.update(self._dataset_link_entries)
-        for relative_path, link in links.items():
-            target_path = os.path.join(target_folder, relative_path)
+        def _download_link(link, target_path):
             if os.path.exists(target_path):
                 LoggerRoot.get_base_logger().info(
                     "{} already exists. Skipping downloading {}".format(
                         target_path, link
                     )
                 )
-                continue
+                return
             ok = False
             error = None
             try:
@@ -2257,6 +2265,19 @@ class Dataset(object):
                 LoggerRoot.get_base_logger().info(log_string)
             else:
                 link.size = Path(target_path).stat().st_size
+        if not max_workers:
+            for relative_path, link in links.items():
+                target_path = os.path.join(target_folder, relative_path)
+                _download_link(link, target_path)
+        else:
+            with ThreadPoolExecutor(max_workers=max_workers) as pool:
+                for relative_path, link in links.items():
+                    target_path = os.path.join(target_folder, relative_path)
+                    pool.submit(_download_link, link, target_path)
 
     def _extract_dataset_archive(
         self,
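The block added at the end of _download_external_files() keeps the original sequential loop when max_workers is falsy and otherwise hands each link to a ThreadPoolExecutor; leaving the with-block waits for every submitted download. A reduced sketch of that dispatch pattern, using a stand-in worker function rather than ClearML code:

    from concurrent.futures import ThreadPoolExecutor

    def process_all(items, worker, max_workers=None):
        """Run worker(item) sequentially, or on a thread pool when max_workers is set.

        Exiting the with-block calls pool.shutdown(wait=True), so every
        submitted task finishes before process_all() returns.
        """
        if not max_workers:
            for item in items:
                worker(item)
        else:
            with ThreadPoolExecutor(max_workers=max_workers) as pool:
                for item in items:
                    pool.submit(worker, item)

    process_all(["train/a.jpg", "val/b.jpg"], print, max_workers=2)

Submit-and-forget means exceptions raised inside the pool only surface through the returned futures, which appears acceptable here because _download_link() already catches and logs its own download errors.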