Fix: downloading datasets with multiple parents might not work (#1398)

commit 28deda0f3d
parent d9949d4c3e
Author: clearml
Date: 2025-04-18 16:09:13 +03:00

clearml/datasets/dataset.py

@@ -12,6 +12,7 @@ from multiprocessing.pool import ThreadPool
 from tempfile import mkdtemp
 from typing import Union, Optional, Sequence, List, Dict, Mapping, Tuple, TYPE_CHECKING, Any
 from zipfile import ZIP_DEFLATED
+from collections import deque
 
 import numpy
 import psutil
@@ -1029,6 +1030,7 @@ class Dataset(object):
 
         :return: A base folder for the entire dataset
         """
+        self._fix_dataset_files_parents()
         assert self._id
         if Dataset.is_offline():
             raise ValueError("Cannot get dataset local copy in offline mode.")
@@ -1443,17 +1445,21 @@ class Dataset(object):
         """
         Needed when someone removes and adds the same file -> parent data will be lost
         """
-        datasets = self._dependency_graph[self._id]
-        for ds_id in datasets:
-            dataset = self.get(dataset_id=ds_id)
-            for (
-                parent_file_key,
-                parent_file_value,
-            ) in dataset._dataset_file_entries.items():
-                if parent_file_key not in self._dataset_file_entries:
-                    continue
-                if parent_file_value.hash == self._dataset_file_entries[parent_file_key].hash:
-                    self._dataset_file_entries[parent_file_key].parent_dataset_id = ds_id
+        self._repair_dependency_graph()
+        # use deque to avoid synchronized objects
+        bfs_queue = deque()
+        for parent in self._dependency_graph.get(self._id, []):
+            bfs_queue.append(parent)
+        while len(bfs_queue) > 0:
+            current_parent = Dataset.get(dataset_id=bfs_queue.popleft(), silence_alias_warnings=True)
+            for file_key, file_value in current_parent._dataset_file_entries.items():
+                if (
+                    file_key in self._dataset_file_entries
+                    and file_value.hash == self._dataset_file_entries[file_key].hash
+                ):
+                    self._dataset_file_entries[file_key].parent_dataset_id = current_parent.id
+            for next_parent in self._dependency_graph.get(current_parent.id, []):
+                bfs_queue.append(next_parent)
 
     def _get_total_size_compressed_parents(self) -> int:
         """
@@ -1742,6 +1748,7 @@ class Dataset(object):
         alias: Optional[str] = None,
         overridable: bool = False,
         shallow_search: bool = False,
+        silence_alias_warnings: bool = False,
         **kwargs: Any,
     ) -> "Dataset":
         """
@@ -1779,7 +1786,7 @@ class Dataset(object):
         if not any([dataset_id, dataset_project, dataset_name, dataset_tags]):
             raise ValueError("Dataset selection criteria not met. Didn't provide id/name/project/tags correctly.")
         current_task = Task.current_task()
-        if not alias and current_task:
+        if not alias and current_task and not silence_alias_warnings:
             LoggerRoot.get_base_logger().info(
                 "Dataset.get() did not specify alias. Dataset information "
                 "will not be automatically logged in ClearML Server."