diff --git a/clearml/datasets/dataset.py b/clearml/datasets/dataset.py
index 091243aa..e229268f 100644
--- a/clearml/datasets/dataset.py
+++ b/clearml/datasets/dataset.py
@@ -55,8 +55,8 @@ class Dataset(object):
     __preview_max_file_entries = 15000
     __preview_max_size = 5 * 1024 * 1024
 
-    def __init__(self, _private, task=None, dataset_project=None, dataset_name=None):
-        # type: (int, Optional[Task], Optional[str], Optional[str]) -> ()
+    def __init__(self, _private, task=None, dataset_project=None, dataset_name=None, dataset_tags=None):
+        # type: (int, Optional[Task], Optional[str], Optional[str], Optional[Sequence[str]]) -> ()
         """
         Do not use directly! Use Dataset.create(...) or Dataset.get(...) instead.
         """
@@ -73,11 +73,15 @@ class Dataset(object):
                     str(task.data.status) in ('created', 'in_progress'):
                 task.set_task_type(task_type=Task.TaskTypes.data_processing)
                 task.set_system_tags((task.get_system_tags() or []) + [self.__tag])
+            if dataset_tags:
+                task.set_tags((task.get_tags() or []) + list(dataset_tags))
         else:
             self._created_task = True
             task = Task.create(
                 project_name=dataset_project, task_name=dataset_name, task_type=Task.TaskTypes.data_processing)
             task.set_system_tags((task.get_system_tags() or []) + [self.__tag])
+            if dataset_tags:
+                task.set_tags((task.get_tags() or []) + list(dataset_tags))
             task.mark_started()
             # generate the script section
             script = \
@@ -104,6 +108,7 @@ class Dataset(object):
         self._local_base_folder = None  # type: Optional[Path]
         # dirty flag, set True by any function call changing the dataset (regardless of weather it did anything)
         self._dirty = False
+        self._using_current_task = False
 
     @property
     def id(self):
@@ -392,6 +397,10 @@ class Dataset(object):
                 raise ValueError("Cannot finalize dataset, pending uploads. Call Dataset.upload(...)")
             return False
 
+        status = self._task.get_status()
+        if status not in ('in_progress', 'created'):
+            raise ValueError("Cannot finalize dataset, status '{}' is not valid".format(status))
+
         self._task.get_logger().report_text('Finalizing dataset', print_console=False)
 
         # make sure we have no redundant parent versions
@@ -402,8 +411,11 @@ class Dataset(object):
         self._report_dataset_genealogy()
         hashed_nodes = [self._get_dataset_id_hash(k) for k in self._dependency_graph.keys()]
         self._task.comment = 'Dependencies: {}\n'.format(hashed_nodes)
-        self._task.close()
-        self._task.completed()
+        if self._using_current_task:
+            self._task.flush(wait_for_uploads=True)
+        else:
+            self._task.close()
+            self._task.completed()
 
         if self._task_pinger:
             self._task_pinger.unregister()
@@ -640,8 +652,15 @@ class Dataset(object):
         return self._task.output_uri or self._task.get_logger().get_default_upload_destination()
 
     @classmethod
-    def create(cls, dataset_name, dataset_project=None, parent_datasets=None, use_current_task=False):
-        # type: (str, Optional[str], Optional[Sequence[Union[str, Dataset]]], bool) -> Dataset
+    def create(
+        cls,
+        dataset_name=None,  # type: Optional[str]
+        dataset_project=None,  # type: Optional[str]
+        dataset_tags=None,  # type: Optional[Sequence[str]]
+        parent_datasets=None,  # type: Optional[Sequence[Union[str, Dataset]]]
+        use_current_task=False  # type: bool
+    ):
+        # type: (...) -> Dataset
         """
         Create a new dataset. Multiple dataset parents are supported.
         Merging of parent datasets is done based on the order,
@@ -650,6 +669,7 @@ class Dataset(object):
         :param dataset_name: Naming the new dataset
         :param dataset_project: Project containing the dataset.
             If not specified, infer project name form parent datasets
+        :param dataset_tags: Optional, list of tags (strings) to attach to the newly created Dataset
         :param parent_datasets: Expand a parent dataset by adding/removing files
         :param use_current_task: False (default), a new Dataset task is created.
             If True, the dataset is created on the current Task.
@@ -660,7 +680,7 @@ class Dataset(object):
                 raise ValueError("Cannot inherit from a parent that was not finalized/closed")
 
         # get project name
-        if not dataset_project:
+        if not dataset_project and not use_current_task:
             if not parent_datasets:
                 raise ValueError("Missing dataset project name. Could not infer project name from parent dataset.")
             # get project name from parent dataset
@@ -675,7 +695,9 @@ class Dataset(object):
         instance = cls(_private=cls.__private_magic,
                        dataset_project=dataset_project,
                        dataset_name=dataset_name,
+                       dataset_tags=dataset_tags,
                        task=Task.current_task() if use_current_task else None)
+        instance._using_current_task = use_current_task
         instance._task.get_logger().report_text('Dataset created', print_console=False)
         instance._dataset_file_entries = dataset_file_entries
         instance._dependency_graph = dependency_graph
@@ -716,6 +738,7 @@ class Dataset(object):
 
         # check if someone is using the datasets
         if not force:
+            # todo: use Task runtime_properties
             # noinspection PyProtectedMember
             dependencies = Task._query_tasks(
                 system_tags=[cls.__tag],
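
For reference, a minimal usage sketch of the API changed by this patch; the project, dataset, tag names and the local folder path are placeholders and are not part of the patch itself:

```python
from clearml import Dataset

# Create a dataset task and attach user tags via the new `dataset_tags` argument.
# All names and the folder path below are placeholders.
ds = Dataset.create(
    dataset_name="raw_images",
    dataset_project="data",
    dataset_tags=["images", "v1"],
)
ds.add_files(path="/path/to/local/folder")
ds.upload()
# finalize() now validates the backing task status first and raises ValueError
# unless the task is still 'created' or 'in_progress'. With use_current_task=True
# the dataset reuses the current Task and finalize() flushes it instead of closing it.
ds.finalize()
```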