diff --git a/clearml/datasets/dataset.py b/clearml/datasets/dataset.py
index 4f7cf0db..ddf2e94e 100644
--- a/clearml/datasets/dataset.py
+++ b/clearml/datasets/dataset.py
@@ -123,6 +123,7 @@ class Dataset(object):
     __hyperparams_section = "Datasets"
     __datasets_runtime_prop = "datasets"
     __orig_datasets_runtime_prop_prefix = "orig_datasets"
+    __dataset_struct = "Dataset Struct"
     __preview_media_max_file_size = deferred_config("dataset.preview.media.max_file_size", 5 * 1024 * 1024, transform=int)
     __preview_tabular_table_count = deferred_config("dataset.preview.tabular.table_count", 10, transform=int)
     __preview_tabular_row_count = deferred_config("dataset.preview.tabular.row_count", 10, transform=int)
@@ -2081,13 +2082,35 @@ class Dataset(object):
         self.update_changed_files(num_files_added=count - modified_count, num_files_modified=modified_count)
         return count - modified_count, modified_count
 
+    def _repair_dependency_graph(self):
+        """
+        Repair the dependency graph via the Dataset Struct configuration object.
+        The graph might be broken for datasets with external files created by old clearml versions
+        """
+        try:
+            dataset_struct = self._task.get_configuration_object_as_dict(Dataset.__dataset_struct)
+            new_dependency_graph = {}
+            for dataset in dataset_struct.values():
+                new_dependency_graph[dataset["job_id"]] = [dataset_struct[p]["job_id"] for p in dataset["parents"]]
+            self._dependency_graph = new_dependency_graph
+        except Exception as e:
+            LoggerRoot.get_base_logger().warning("Could not repair dependency graph. Error is: {}".format(e))
+
     def _update_dependency_graph(self):
         """
-        Update the dependency graph based on the current self._dataset_file_entries state
+        Update the dependency graph based on the current self._dataset_file_entries
+        and self._dataset_link_entries states
         :return:
         """
         # collect all dataset versions
-        used_dataset_versions = set(f.parent_dataset_id for f in self._dataset_file_entries.values())
+        used_dataset_versions = set(f.parent_dataset_id for f in self._dataset_file_entries.values()) | set(
+            f.parent_dataset_id for f in self._dataset_link_entries.values()
+        )
+        for dataset_id in used_dataset_versions:
+            if dataset_id not in self._dependency_graph and dataset_id != self._id:
+                self._repair_dependency_graph()
+                break
+
         used_dataset_versions.add(self._id)
         current_parents = self._dependency_graph.get(self._id) or []
         # remove parent versions we no longer need from the main version list
@@ -2296,29 +2319,8 @@
             Notice you should unlock it manually, or wait for the process to finish for auto unlocking.
         :param max_workers: Number of threads to be spawned when getting dataset files. Defaults to no multi-threading.
         """
-        target_folder = (
-            Path(target_folder)
-            if target_folder
-            else self._create_ds_target_folder(
-                lock_target_folder=lock_target_folder
-            )[0]
-        ).as_posix()
-        dependencies = self._get_dependencies_by_order(
-            include_unused=False, include_current=True
-        )
-        links = {}
-        for dependency in dependencies:
-            ds = Dataset.get(dependency)
-            links.update(ds._dataset_link_entries)
-        links.update(self._dataset_link_entries)
-
         def _download_link(link, target_path):
             if os.path.exists(target_path):
-                LoggerRoot.get_base_logger().info(
-                    "{} already exists. Skipping downloading {}".format(
-                        target_path, link
-                    )
-                )
                 return
             ok = False
             error = None
@@ -2341,27 +2343,40 @@
                     LoggerRoot.get_base_logger().info(log_string)
                 else:
                     link.size = Path(target_path).stat().st_size
-        if not max_workers:
-            for relative_path, link in links.items():
-                if not is_path_traversal(target_folder, relative_path):
-                    target_path = os.path.join(target_folder, relative_path)
-                else:
-                    LoggerRoot.get_base_logger().warning(
-                        "Ignoring relative path `{}`: it must not traverse directories".format(relative_path)
-                    )
-                    target_path = os.path.join(target_folder, os.path.basename(relative_path))
+
+        def _get_target_path(relative_path, target_folder):
+            if not is_path_traversal(target_folder, relative_path):
+                return os.path.join(target_folder, relative_path)
+            else:
+                LoggerRoot.get_base_logger().warning(
+                    "Ignoring relative path `{}`: it must not traverse directories".format(relative_path)
+                )
+                return os.path.join(target_folder, os.path.basename(relative_path))
+
+        def _submit_download_link(relative_path, link, target_folder, pool=None):
+            if link.parent_dataset_id != self.id:
+                return
+            target_path = _get_target_path(relative_path, target_folder)
+            if pool is None:
                 _download_link(link, target_path)
+            else:
+                pool.submit(_download_link, link, target_path)
+
+        target_folder = (
+            Path(target_folder)
+            if target_folder
+            else self._create_ds_target_folder(
+                lock_target_folder=lock_target_folder
+            )[0]
+        ).as_posix()
+
+        if not max_workers:
+            for relative_path, link in self._dataset_link_entries.items():
+                _submit_download_link(relative_path, link, target_folder)
         else:
             with ThreadPoolExecutor(max_workers=max_workers) as pool:
-                for relative_path, link in links.items():
-                    if not is_path_traversal(target_folder, relative_path):
-                        target_path = os.path.join(target_folder, relative_path)
-                    else:
-                        LoggerRoot.get_base_logger().warning(
-                            "Ignoring relative path `{}`: it must not traverse directories".format(relative_path)
-                        )
-                        target_path = os.path.join(target_folder, os.path.basename(relative_path))
-                    pool.submit(_download_link, link, target_path)
+                for relative_path, link in self._dataset_link_entries.items():
+                    _submit_download_link(relative_path, link, target_folder, pool=pool)
 
     def _extract_dataset_archive(
         self,
@@ -2586,6 +2601,7 @@ class Dataset(object):
         :param include_current: If True include the current dataset ID as the last ID in the list
         :return: list of str representing the datasets id
         """
+        self._update_dependency_graph()
         roots = [self._id]
         dependencies = []
         # noinspection DuplicatedCode
@@ -3027,7 +3043,7 @@ class Dataset(object):
         # fetch the parents of this version (task) based on what we have on the Task itself.
         # noinspection PyBroadException
         try:
-            dataset_version_node = task.get_configuration_object_as_dict("Dataset Struct")
+            dataset_version_node = task.get_configuration_object_as_dict(Dataset.__dataset_struct)
             # fine the one that is us
             for node in dataset_version_node.values():
                 if node["job_id"] != id_:
@@ -3056,7 +3072,7 @@
             dataset_struct[indices[id_]]["parents"] = [indices[p] for p in parents]
         # noinspection PyProtectedMember
         self._task._set_configuration(
-            name="Dataset Struct",
+            name=Dataset.__dataset_struct,
             description="Structure of the dataset",
             config_type="json",
             config_text=json.dumps(dataset_struct, indent=2),
@@ -3234,7 +3250,8 @@
 
                 return None
 
-        errors = pool.map(copy_file, self._dataset_file_entries.values())
+        errors = list(pool.map(copy_file, self._dataset_file_entries.values()))
+        errors.extend(list(pool.map(copy_file, self._dataset_link_entries.values())))
 
         CacheManager.get_cache_manager(cache_context=self.__cache_context).unlock_cache_folder(
             ds_base_folder.as_posix())
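
Reviewer note: below is a minimal standalone sketch of the repair performed by the new _repair_dependency_graph, runnable without clearml. The struct payload is hypothetical (made-up ids and indices, chosen only for illustration); the job_id/parents shape and the index-to-id remapping are taken from the patch itself:

    # "Dataset Struct" maps positional indices to nodes; each node stores its
    # parents as indices, so the repair remaps every index back to the concrete
    # dataset (task) id and rebuilds the id-keyed dependency graph.
    dataset_struct = {  # hypothetical configuration payload
        "0": {"job_id": "dataset_a", "parents": []},
        "1": {"job_id": "dataset_b", "parents": ["0"]},
        "2": {"job_id": "dataset_c", "parents": ["0", "1"]},
    }

    new_dependency_graph = {
        node["job_id"]: [dataset_struct[p]["job_id"] for p in node["parents"]]
        for node in dataset_struct.values()
    }

    assert new_dependency_graph == {
        "dataset_a": [],
        "dataset_b": ["dataset_a"],
        "dataset_c": ["dataset_a", "dataset_b"],
    }

A related design point in the download path: _submit_download_link skips any link whose parent_dataset_id differs from the current dataset id, so each version fetches only the external links it owns, instead of aggregating a links dict across every dependency as the removed code did.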