Fix dataset with external links not reusing downloaded data from parents

allegroai 2024-07-29 17:36:02 +03:00
parent d826e9806e
commit 2b66ee663e

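For context, the scenario this commit addresses, as a minimal sketch using the public clearml Dataset API (the project, dataset names, and URIs below are illustrative, not taken from the commit):

from clearml import Dataset

# Parent version registers external links; the files stay at their remote URI.
parent = Dataset.create(dataset_name="images", dataset_project="demo")
parent.add_external_files(source_url="s3://bucket/images/")  # illustrative URI
parent.upload()
parent.finalize()

# Child version inherits the parent's link entries and adds its own.
child = Dataset.create(
    dataset_name="images", dataset_project="demo", parent_datasets=[parent.id]
)
child.add_external_files(source_url="s3://bucket/more-images/")  # illustrative URI
child.upload()
child.finalize()

# Before this fix, fetching the child could re-download links already
# downloaded for the parent instead of reusing the parent's local copy.
local_path = Dataset.get(dataset_id=child.id).get_local_copy()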

@@ -123,6 +123,7 @@ class Dataset(object):
__hyperparams_section = "Datasets"
__datasets_runtime_prop = "datasets"
__orig_datasets_runtime_prop_prefix = "orig_datasets"
__dataset_struct = "Dataset Struct"
__preview_media_max_file_size = deferred_config("dataset.preview.media.max_file_size", 5 * 1024 * 1024, transform=int)
__preview_tabular_table_count = deferred_config("dataset.preview.tabular.table_count", 10, transform=int)
__preview_tabular_row_count = deferred_config("dataset.preview.tabular.row_count", 10, transform=int)
@@ -2081,13 +2082,35 @@ class Dataset(object):
self.update_changed_files(num_files_added=count - modified_count, num_files_modified=modified_count)
return count - modified_count, modified_count
def _repair_dependency_graph(self):
"""
Repair the dependency graph using the Dataset Struct configuration object.
A broken graph can occur for datasets with external files created by older clearml versions
"""
try:
dataset_struct = self._task.get_configuration_object_as_dict(Dataset.__dataset_struct)
new_dependency_graph = {}
for dataset in dataset_struct.values():
new_dependency_graph[dataset["job_id"]] = [dataset_struct[p]["job_id"] for p in dataset["parents"]]
self._dependency_graph = new_dependency_graph
except Exception as e:
LoggerRoot.get_base_logger().warning("Could not repair dependency graph. Error is: {}".format(e))
def _update_dependency_graph(self):
"""
Update the dependency graph based on the current self._dataset_file_entries state
Update the dependency graph based on the current self._dataset_file_entries
and self._dataset_link_entries states
:return:
"""
# collect all dataset versions
used_dataset_versions = set(f.parent_dataset_id for f in self._dataset_file_entries.values())
used_dataset_versions = set(f.parent_dataset_id for f in self._dataset_file_entries.values()) | set(
f.parent_dataset_id for f in self._dataset_link_entries.values()
)
for dataset_id in used_dataset_versions:
if dataset_id not in self._dependency_graph and dataset_id != self._id:
self._repair_dependency_graph()
break
used_dataset_versions.add(self._id)
current_parents = self._dependency_graph.get(self._id) or []
# remove parent versions we no longer need from the main version list
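The key change above is that link entries now also contribute their parent dataset ids, so a child version holding only inherited external links no longer loses its parent from the graph. A toy illustration of the union and the repair trigger (the entry objects are stand-ins, not clearml classes):

from types import SimpleNamespace

current_id = "child"
file_entries = {}  # a link-only child version may have no file entries at all
link_entries = {"img.jpg": SimpleNamespace(parent_dataset_id="parent")}

used = {e.parent_dataset_id for e in file_entries.values()} | {
    e.parent_dataset_id for e in link_entries.values()
}
dependency_graph = {"child": []}  # stale graph that lost the parent entry

# A referenced dataset that is neither this version nor present in the graph
# means the graph is stale, which triggers _repair_dependency_graph() above.
assert any(d not in dependency_graph and d != current_id for d in used)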
@@ -2296,29 +2319,8 @@ class Dataset(object):
Note that you should unlock it manually, or wait for the process to finish for it to be unlocked automatically.
:param max_workers: Number of threads to be spawned when getting dataset files. Defaults to no multi-threading.
"""
target_folder = (
Path(target_folder)
if target_folder
else self._create_ds_target_folder(
lock_target_folder=lock_target_folder
)[0]
).as_posix()
dependencies = self._get_dependencies_by_order(
include_unused=False, include_current=True
)
links = {}
for dependency in dependencies:
ds = Dataset.get(dependency)
links.update(ds._dataset_link_entries)
links.update(self._dataset_link_entries)
def _download_link(link, target_path):
if os.path.exists(target_path):
LoggerRoot.get_base_logger().info(
"{} already exists. Skipping downloading {}".format(
target_path, link
)
)
return
ok = False
error = None
@@ -2341,28 +2343,41 @@ class Dataset(object):
LoggerRoot.get_base_logger().info(log_string)
else:
link.size = Path(target_path).stat().st_size
if not max_workers:
for relative_path, link in links.items():
def _get_target_path(relative_path, target_folder):
if not is_path_traversal(target_folder, relative_path):
target_path = os.path.join(target_folder, relative_path)
return os.path.join(target_folder, relative_path)
else:
LoggerRoot.get_base_logger().warning(
"Ignoring relative path `{}`: it must not traverse directories".format(relative_path)
)
target_path = os.path.join(target_folder, os.path.basename(relative_path))
return os.path.join(target_folder, os.path.basename(relative_path))
def _submit_download_link(relative_path, link, target_folder, pool=None):
if link.parent_dataset_id != self.id:
return
target_path = _get_target_path(relative_path, target_folder)
if pool is None:
_download_link(link, target_path)
else:
with ThreadPoolExecutor(max_workers=max_workers) as pool:
for relative_path, link in links.items():
if not is_path_traversal(target_folder, relative_path):
target_path = os.path.join(target_folder, relative_path)
else:
LoggerRoot.get_base_logger().warning(
"Ignoring relative path `{}`: it must not traverse directories".format(relative_path)
)
target_path = os.path.join(target_folder, os.path.basename(relative_path))
pool.submit(_download_link, link, target_path)
target_folder = (
Path(target_folder)
if target_folder
else self._create_ds_target_folder(
lock_target_folder=lock_target_folder
)[0]
).as_posix()
if not max_workers:
for relative_path, link in self._dataset_link_entries.items():
_submit_download_link(relative_path, link, target_folder)
else:
with ThreadPoolExecutor(max_workers=max_workers) as pool:
for relative_path, link in self._dataset_link_entries.items():
_submit_download_link(relative_path, link, target_folder, pool=pool)
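The `parent_dataset_id` check above is what makes the reuse work: each version now only downloads the links it introduced, while links owned by an ancestor are resolved through that ancestor's (possibly cached) copy. A toy illustration with plain dicts in place of clearml link objects:

self_id = "child"
links = {
    "a.jpg": {"parent_dataset_id": "parent", "url": "s3://bucket/a.jpg"},
    "b.jpg": {"parent_dataset_id": "child", "url": "s3://bucket/b.jpg"},
}

# Only links introduced by this version are downloaded here.
to_download = {
    rel: link for rel, link in links.items() if link["parent_dataset_id"] == self_id
}
assert list(to_download) == ["b.jpg"]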
def _extract_dataset_archive(
self,
force=False,
@@ -2586,6 +2601,7 @@ class Dataset(object):
:param include_current: If True include the current dataset ID as the last ID in the list
:return: list of str representing the datasets id
"""
self._update_dependency_graph()
roots = [self._id]
dependencies = []
# noinspection DuplicatedCode
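Refreshing the graph before the traversal matters because the ordering is derived from the `{dataset_id: [parent_ids]}` mapping. As a rough, generic sketch of a parents-before-children ordering over such a mapping (not the exact traversal clearml uses):

def parents_first(dependency_graph, root):
    seen, order = set(), []
    stack = [root]
    while stack:
        node = stack.pop()
        if node in seen:
            continue
        seen.add(node)
        order.append(node)
        stack.extend(dependency_graph.get(node) or [])
    # Reverse so ancestors come before the versions that depend on them.
    return list(reversed(order))

graph = {"c": ["b"], "b": ["a"], "a": []}
assert parents_first(graph, "c") == ["a", "b", "c"]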
@@ -3027,7 +3043,7 @@ class Dataset(object):
# fetch the parents of this version (task) based on what we have on the Task itself.
# noinspection PyBroadException
try:
dataset_version_node = task.get_configuration_object_as_dict("Dataset Struct")
dataset_version_node = task.get_configuration_object_as_dict(Dataset.__dataset_struct)
# find the one that is us
for node in dataset_version_node.values():
if node["job_id"] != id_:
@@ -3056,7 +3072,7 @@ class Dataset(object):
dataset_struct[indices[id_]]["parents"] = [indices[p] for p in parents]
# noinspection PyProtectedMember
self._task._set_configuration(
name="Dataset Struct",
name=Dataset.__dataset_struct,
description="Structure of the dataset",
config_type="json",
config_text=json.dumps(dataset_struct, indent=2),
@@ -3234,7 +3250,8 @@ class Dataset(object):
return None
errors = pool.map(copy_file, self._dataset_file_entries.values())
errors = list(pool.map(copy_file, self._dataset_file_entries.values()))
errors.extend(list(pool.map(copy_file, self._dataset_link_entries.values())))
CacheManager.get_cache_manager(cache_context=self.__cache_context).unlock_cache_folder(
ds_base_folder.as_posix())
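The copy step now covers link entries as well as file entries; a compressed illustration of the two-pass error collection pattern, with a stand-in copy_file that returns None on success:

from concurrent.futures import ThreadPoolExecutor

def copy_file(entry):  # stand-in: None on success, the entry itself on failure
    return None if entry.get("ok", True) else entry

file_entries = [{"name": "a.bin"}, {"name": "b.bin", "ok": True}]
link_entries = [{"name": "img.jpg"}]

with ThreadPoolExecutor(max_workers=4) as pool:
    errors = list(pool.map(copy_file, file_entries))
    errors.extend(list(pool.map(copy_file, link_entries)))

# Any non-None result marks a failed copy across both kinds of entries.
assert not any(errors)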