Fix support for Dataset.create() argument use_current_task

Add dataset_tags argument to Dataset.create()
This commit is contained in:
allegroai 2021-05-12 15:41:04 +03:00
parent 4e4aab56a8
commit 228c17e44d

View File

@ -55,8 +55,8 @@ class Dataset(object):
__preview_max_file_entries = 15000
__preview_max_size = 5 * 1024 * 1024
def __init__(self, _private, task=None, dataset_project=None, dataset_name=None):
# type: (int, Optional[Task], Optional[str], Optional[str]) -> ()
def __init__(self, _private, task=None, dataset_project=None, dataset_name=None, dataset_tags=None):
# type: (int, Optional[Task], Optional[str], Optional[str], Optional[Sequence[str]]) -> ()
"""
Do not use directly! Use Dataset.create(...) or Dataset.get(...) instead.
"""
@ -73,11 +73,15 @@ class Dataset(object):
str(task.data.status) in ('created', 'in_progress'):
task.set_task_type(task_type=Task.TaskTypes.data_processing)
task.set_system_tags((task.get_system_tags() or []) + [self.__tag])
if dataset_tags:
task.set_tags((task.get_tags() or []) + list(dataset_tags))
else:
self._created_task = True
task = Task.create(
project_name=dataset_project, task_name=dataset_name, task_type=Task.TaskTypes.data_processing)
task.set_system_tags((task.get_system_tags() or []) + [self.__tag])
if dataset_tags:
task.set_tags((task.get_tags() or []) + list(dataset_tags))
task.mark_started()
# generate the script section
script = \
@ -104,6 +108,7 @@ class Dataset(object):
self._local_base_folder = None # type: Optional[Path]
# dirty flag, set True by any function call changing the dataset (regardless of weather it did anything)
self._dirty = False
self._using_current_task = False
@property
def id(self):
@ -392,6 +397,10 @@ class Dataset(object):
raise ValueError("Cannot finalize dataset, pending uploads. Call Dataset.upload(...)")
return False
status = self._task.get_status()
if status not in ('in_progress', 'created'):
raise ValueError("Cannot finalize dataset, status '{}' is not valid".format(status))
self._task.get_logger().report_text('Finalizing dataset', print_console=False)
# make sure we have no redundant parent versions
@ -402,8 +411,11 @@ class Dataset(object):
self._report_dataset_genealogy()
hashed_nodes = [self._get_dataset_id_hash(k) for k in self._dependency_graph.keys()]
self._task.comment = 'Dependencies: {}\n'.format(hashed_nodes)
self._task.close()
self._task.completed()
if self._using_current_task:
self._task.flush(wait_for_uploads=True)
else:
self._task.close()
self._task.completed()
if self._task_pinger:
self._task_pinger.unregister()
@ -640,8 +652,15 @@ class Dataset(object):
return self._task.output_uri or self._task.get_logger().get_default_upload_destination()
@classmethod
def create(cls, dataset_name, dataset_project=None, parent_datasets=None, use_current_task=False):
# type: (str, Optional[str], Optional[Sequence[Union[str, Dataset]]], bool) -> Dataset
def create(
cls,
dataset_name=None, # type: Optional[str]
dataset_project=None, # type: Optional[str]
dataset_tags=None, # type: Optional[Sequence[str]]
parent_datasets=None, # type: Optional[Sequence[Union[str, Dataset]]]
use_current_task=False # type: bool
):
# type: (...) -> Dataset
"""
Create a new dataset. Multiple dataset parents are supported.
Merging of parent datasets is done based on the order,
@ -650,6 +669,7 @@ class Dataset(object):
:param dataset_name: Naming the new dataset
:param dataset_project: Project containing the dataset.
If not specified, infer project name form parent datasets
:param dataset_tags: Optional, list of tags (strings) to attach to the newly created Dataset
:param parent_datasets: Expand a parent dataset by adding/removing files
:param use_current_task: False (default), a new Dataset task is created.
If True, the dataset is created on the current Task.
@ -660,7 +680,7 @@ class Dataset(object):
raise ValueError("Cannot inherit from a parent that was not finalized/closed")
# get project name
if not dataset_project:
if not dataset_project and not use_current_task:
if not parent_datasets:
raise ValueError("Missing dataset project name. Could not infer project name from parent dataset.")
# get project name from parent dataset
@ -675,7 +695,9 @@ class Dataset(object):
instance = cls(_private=cls.__private_magic,
dataset_project=dataset_project,
dataset_name=dataset_name,
dataset_tags=dataset_tags,
task=Task.current_task() if use_current_task else None)
instance._using_current_task = use_current_task
instance._task.get_logger().report_text('Dataset created', print_console=False)
instance._dataset_file_entries = dataset_file_entries
instance._dependency_graph = dependency_graph
@ -716,6 +738,7 @@ class Dataset(object):
# check if someone is using the datasets
if not force:
# todo: use Task runtime_properties
# noinspection PyProtectedMember
dependencies = Task._query_tasks(
system_tags=[cls.__tag],