Fix dataset with external links not reusing downloaded data from parents

allegroai 2024-07-29 17:36:02 +03:00
parent d826e9806e
commit 2b66ee663e
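
For context, the scenario this commit addresses is a dataset that inherits external links from a parent: fetching the child used to re-download the parent's linked files instead of reusing the parent's already-downloaded copy. A minimal reproduction sketch using the public Dataset API (project and dataset names and the bucket URL are illustrative; a configured ClearML server is assumed):

from clearml import Dataset

# Parent dataset that only tracks external links (illustrative bucket URL)
parent = Dataset.create(dataset_name="parent-links", dataset_project="examples")
parent.add_external_files(source_url="s3://my-bucket/data/")
parent.upload()
parent.finalize()

# Child dataset that inherits the parent's external links
child = Dataset.create(dataset_name="child-links", dataset_project="examples", parent_datasets=[parent.id])
child.upload()
child.finalize()

# Before this commit, fetching the child re-downloaded the parent's external links;
# with the fix, each link is downloaded by the dataset that owns it, so the parent's
# cached copy is reused.
Dataset.get(dataset_id=parent.id).get_local_copy()
Dataset.get(dataset_id=child.id).get_local_copy()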


@@ -123,6 +123,7 @@ class Dataset(object):
     __hyperparams_section = "Datasets"
     __datasets_runtime_prop = "datasets"
     __orig_datasets_runtime_prop_prefix = "orig_datasets"
+    __dataset_struct = "Dataset Struct"
     __preview_media_max_file_size = deferred_config("dataset.preview.media.max_file_size", 5 * 1024 * 1024, transform=int)
     __preview_tabular_table_count = deferred_config("dataset.preview.tabular.table_count", 10, transform=int)
     __preview_tabular_row_count = deferred_config("dataset.preview.tabular.row_count", 10, transform=int)
@@ -2081,13 +2082,35 @@ class Dataset(object):
         self.update_changed_files(num_files_added=count - modified_count, num_files_modified=modified_count)
         return count - modified_count, modified_count

+    def _repair_dependency_graph(self):
+        """
+        Repair dependency graph via the Dataset Struct configuration object.
+        Might happen for datasets with external files in old clearml versions
+        """
+        try:
+            dataset_struct = self._task.get_configuration_object_as_dict(Dataset.__dataset_struct)
+            new_dependency_graph = {}
+            for dataset in dataset_struct.values():
+                new_dependency_graph[dataset["job_id"]] = [dataset_struct[p]["job_id"] for p in dataset["parents"]]
+            self._dependency_graph = new_dependency_graph
+        except Exception as e:
+            LoggerRoot.get_base_logger().warning("Could not repair dependency graph. Error is: {}".format(e))
+
     def _update_dependency_graph(self):
         """
-        Update the dependency graph based on the current self._dataset_file_entries state
+        Update the dependency graph based on the current self._dataset_file_entries
+        and self._dataset_link_entries states
         :return:
         """
         # collect all dataset versions
-        used_dataset_versions = set(f.parent_dataset_id for f in self._dataset_file_entries.values())
+        used_dataset_versions = set(f.parent_dataset_id for f in self._dataset_file_entries.values()) | set(
+            f.parent_dataset_id for f in self._dataset_link_entries.values()
+        )
+        for dataset_id in used_dataset_versions:
+            if dataset_id not in self._dependency_graph and dataset_id != self._id:
+                self._repair_dependency_graph()
+                break
         used_dataset_versions.add(self._id)
         current_parents = self._dependency_graph.get(self._id) or []
         # remove parent versions we no longer need from the main version list
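
The repair path added above reads the "Dataset Struct" configuration object and rebuilds self._dependency_graph from it. Judging only from how this hunk indexes that object, it maps node indices to entries carrying at least "job_id" and "parents" (parents given as indices); a small standalone sketch of the rebuild, with invented ids and a possibly incomplete field set:

# Shape inferred from the diff above; real entries may carry more fields
dataset_struct = {
    "0": {"job_id": "aaaa0000", "parents": []},
    "1": {"job_id": "bbbb1111", "parents": ["0"]},
}

new_dependency_graph = {}
for node in dataset_struct.values():
    new_dependency_graph[node["job_id"]] = [dataset_struct[p]["job_id"] for p in node["parents"]]

assert new_dependency_graph == {"aaaa0000": [], "bbbb1111": ["aaaa0000"]}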
@@ -2296,29 +2319,8 @@ class Dataset(object):
         Notice you should unlock it manually, or wait for the process to finish for auto unlocking.
         :param max_workers: Number of threads to be spawned when getting dataset files. Defaults to no multi-threading.
         """
-        target_folder = (
-            Path(target_folder)
-            if target_folder
-            else self._create_ds_target_folder(
-                lock_target_folder=lock_target_folder
-            )[0]
-        ).as_posix()
-        dependencies = self._get_dependencies_by_order(
-            include_unused=False, include_current=True
-        )
-        links = {}
-        for dependency in dependencies:
-            ds = Dataset.get(dependency)
-            links.update(ds._dataset_link_entries)
-        links.update(self._dataset_link_entries)
-
         def _download_link(link, target_path):
             if os.path.exists(target_path):
-                LoggerRoot.get_base_logger().info(
-                    "{} already exists. Skipping downloading {}".format(
-                        target_path, link
-                    )
-                )
                 return
             ok = False
             error = None
@@ -2341,27 +2343,40 @@ class Dataset(object):
                 LoggerRoot.get_base_logger().info(log_string)
             else:
                 link.size = Path(target_path).stat().st_size
-        if not max_workers:
-            for relative_path, link in links.items():
-                if not is_path_traversal(target_folder, relative_path):
-                    target_path = os.path.join(target_folder, relative_path)
-                else:
-                    LoggerRoot.get_base_logger().warning(
-                        "Ignoring relative path `{}`: it must not traverse directories".format(relative_path)
-                    )
-                    target_path = os.path.join(target_folder, os.path.basename(relative_path))
-                _download_link(link, target_path)
-        else:
-            with ThreadPoolExecutor(max_workers=max_workers) as pool:
-                for relative_path, link in links.items():
-                    if not is_path_traversal(target_folder, relative_path):
-                        target_path = os.path.join(target_folder, relative_path)
-                    else:
-                        LoggerRoot.get_base_logger().warning(
-                            "Ignoring relative path `{}`: it must not traverse directories".format(relative_path)
-                        )
-                        target_path = os.path.join(target_folder, os.path.basename(relative_path))
-                    pool.submit(_download_link, link, target_path)
+
+        def _get_target_path(relative_path, target_folder):
+            if not is_path_traversal(target_folder, relative_path):
+                return os.path.join(target_folder, relative_path)
+            else:
+                LoggerRoot.get_base_logger().warning(
+                    "Ignoring relative path `{}`: it must not traverse directories".format(relative_path)
+                )
+                return os.path.join(target_folder, os.path.basename(relative_path))
+
+        def _submit_download_link(relative_path, link, target_folder, pool=None):
+            if link.parent_dataset_id != self.id:
+                return
+            target_path = _get_target_path(relative_path, target_folder)
+            if pool is None:
+                _download_link(link, target_path)
+            else:
+                pool.submit(_download_link, link, target_path)
+
+        target_folder = (
+            Path(target_folder)
+            if target_folder
+            else self._create_ds_target_folder(
+                lock_target_folder=lock_target_folder
+            )[0]
+        ).as_posix()
+
+        if not max_workers:
+            for relative_path, link in self._dataset_link_entries.items():
+                _submit_download_link(relative_path, link, target_folder)
+        else:
+            with ThreadPoolExecutor(max_workers=max_workers) as pool:
+                for relative_path, link in self._dataset_link_entries.items():
+                    _submit_download_link(relative_path, link, target_folder, pool=pool)

     def _extract_dataset_archive(
         self,
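
The key behavioural change in this hunk is the parent_dataset_id guard in _submit_download_link: each dataset now downloads only the external links it owns, so links inherited from a parent are fetched (and cached) through the parent rather than re-downloaded by every child. A simplified stand-in for that filter, using plain dicts instead of the clearml link-entry objects:

links = {
    "a.txt": {"parent_dataset_id": "parent-id", "link": "s3://bucket/a.txt"},
    "b.txt": {"parent_dataset_id": "child-id", "link": "s3://bucket/b.txt"},
}
current_dataset_id = "child-id"

# Only links owned by the current dataset are scheduled for download here;
# "a.txt" is left to the parent dataset, whose local copy can be reused.
to_download = {
    rel_path: link
    for rel_path, link in links.items()
    if link["parent_dataset_id"] == current_dataset_id
}
assert list(to_download) == ["b.txt"]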
@@ -2586,6 +2601,7 @@ class Dataset(object):
         :param include_current: If True include the current dataset ID as the last ID in the list
         :return: list of str representing the datasets id
         """
+        self._update_dependency_graph()
         roots = [self._id]
         dependencies = []
         # noinspection DuplicatedCode
@@ -3027,7 +3043,7 @@ class Dataset(object):
         # fetch the parents of this version (task) based on what we have on the Task itself.
         # noinspection PyBroadException
         try:
-            dataset_version_node = task.get_configuration_object_as_dict("Dataset Struct")
+            dataset_version_node = task.get_configuration_object_as_dict(Dataset.__dataset_struct)
             # fine the one that is us
             for node in dataset_version_node.values():
                 if node["job_id"] != id_:
@@ -3056,7 +3072,7 @@ class Dataset(object):
         dataset_struct[indices[id_]]["parents"] = [indices[p] for p in parents]
         # noinspection PyProtectedMember
         self._task._set_configuration(
-            name="Dataset Struct",
+            name=Dataset.__dataset_struct,
             description="Structure of the dataset",
             config_type="json",
             config_text=json.dumps(dataset_struct, indent=2),
@@ -3234,7 +3250,8 @@ class Dataset(object):
                 return None

-            errors = pool.map(copy_file, self._dataset_file_entries.values())
+            errors = list(pool.map(copy_file, self._dataset_file_entries.values()))
+            errors.extend(list(pool.map(copy_file, self._dataset_link_entries.values())))

             CacheManager.get_cache_manager(cache_context=self.__cache_context).unlock_cache_folder(
                 ds_base_folder.as_posix())
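
The last hunk extends the parent-content copy loop so externally linked files are copied alongside regular dataset files. Assuming this path is reached through the public API (the exact call path, names, and target folder below are illustrative, not taken from the commit), a typical call would be:

from clearml import Dataset

ds = Dataset.get(dataset_project="examples", dataset_name="child-links")
# Externally linked files are now also copied into the mutable target folder
local_path = ds.get_mutable_local_copy(target_folder="/tmp/child_links_copy", overwrite=True)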