diff --git a/clearml/datasets/dataset.py b/clearml/datasets/dataset.py
index 65841b29..5318bd5d 100644
--- a/clearml/datasets/dataset.py
+++ b/clearml/datasets/dataset.py
@@ -1,4 +1,5 @@
 import calendar
+import itertools
 import json
 import os
 import shutil
@@ -416,7 +417,7 @@ class Dataset(object):
         self,
         source_url,  # type: Union[str, Sequence[str]]
         wildcard=None,  # type: Optional[Union[str, Sequence[str]]]
-        dataset_path=None,  # type: Optional[str]
+        dataset_path=None,  # type: Optional[Union[str, Sequence[str]]]
         recursive=True,  # type: bool
         verbose=False,  # type: bool
         max_workers=None  # type: Optional[int]
@@ -458,14 +459,21 @@ class Dataset(object):
         source_url_list = source_url if not isinstance(source_url, str) else [source_url]
         max_workers = max_workers or psutil.cpu_count()
         futures_ = []
+        if isinstance(dataset_path, str) or dataset_path is None:
+            dataset_paths = itertools.repeat(dataset_path)
+        else:
+            if len(dataset_path) != len(source_url_list):
+                raise ValueError("dataset_path must be a string or a list of strings with the same length as source_url"
+                                 f" (received {len(dataset_path)} paths for {len(source_url_list)} source urls)")
+            dataset_paths = dataset_path
         with ThreadPoolExecutor(max_workers=max_workers) as tp:
-            for source_url_ in source_url_list:
+            for source_url_, dataset_path_ in zip(source_url_list, dataset_paths):
                 futures_.append(
                     tp.submit(
                         self._add_external_files,
                         source_url_,
                         wildcard=wildcard,
-                        dataset_path=dataset_path,
+                        dataset_path=dataset_path_,
                         recursive=recursive,
                         verbose=verbose,
                     )
@@ -2194,12 +2202,12 @@ class Dataset(object):
             max_workers=max_workers
         )
         self._download_external_files(
-            target_folder=target_folder, lock_target_folder=lock_target_folder
+            target_folder=target_folder, lock_target_folder=lock_target_folder, max_workers=max_workers
         )
         return local_folder
 
     def _download_external_files(
-        self, target_folder=None, lock_target_folder=False
+        self, target_folder=None, lock_target_folder=False, max_workers=None
     ):
-        # (Union(Path, str), bool) -> None
+        # (Union(Path, str), bool, Optional[int]) -> None
         """
@@ -2211,6 +2219,7 @@ class Dataset(object):
         :param target_folder: If provided use the specified target folder, default, auto generate from Dataset ID.
         :param lock_target_folder: If True, local the target folder so the next cleanup will not delete
             Notice you should unlock it manually, or wait for the process to finish for auto unlocking.
+        :param max_workers: Number of threads to be spawned when getting dataset files. Defaults to no multi-threading.
         """
         target_folder = (
             Path(target_folder)
@@ -2227,15 +2236,14 @@ class Dataset(object):
             ds = Dataset.get(dependency)
             links.update(ds._dataset_link_entries)
         links.update(self._dataset_link_entries)
-        for relative_path, link in links.items():
-            target_path = os.path.join(target_folder, relative_path)
+        def _download_link(link, target_path):
             if os.path.exists(target_path):
                 LoggerRoot.get_base_logger().info(
                     "{} already exists. Skipping downloading {}".format(
                         target_path, link
                     )
                 )
-                continue
+                return
             ok = False
             error = None
             try:
@@ -2257,6 +2265,15 @@ class Dataset(object):
                     LoggerRoot.get_base_logger().info(log_string)
                 else:
                     link.size = Path(target_path).stat().st_size
+        if not max_workers:
+            for relative_path, link in links.items():
+                target_path = os.path.join(target_folder, relative_path)
+                _download_link(link, target_path)
+        else:
+            with ThreadPoolExecutor(max_workers=max_workers) as pool:
+                for relative_path, link in links.items():
+                    target_path = os.path.join(target_folder, relative_path)
+                    pool.submit(_download_link, link, target_path)
 
     def _extract_dataset_archive(
         self,
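
For reviewers, a minimal usage sketch of what this patch enables. The dataset/project names and S3 URLs below are placeholders, and exposing max_workers on the public get_local_copy() is an assumption inferred from the updated internal call site; the patch itself only changes the private _download_external_files() path.

    from clearml import Dataset

    # Map each external source to its own folder inside the dataset:
    # dataset_path may now be a sequence matching source_url one-to-one
    # (a single string or None still applies to all sources, as before).
    ds = Dataset.create(dataset_name="external-links-demo", dataset_project="examples")
    ds.add_external_files(
        source_url=["s3://bucket/images/train", "s3://bucket/images/val"],
        dataset_path=["train", "val"],
    )
    ds.upload()
    ds.finalize()

    # External links are downloaded with a thread pool when max_workers is set;
    # omitting it keeps the previous single-threaded behaviour.
    # NOTE: max_workers forwarding from get_local_copy() is assumed here.
    local_copy = Dataset.get(dataset_id=ds.id).get_local_copy(max_workers=8)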