Fix TaskTypes/TaskStatusEnum Enum comparison

This commit is contained in:
allegroai 2020-05-22 10:30:06 +03:00
parent 63ca84a84f
commit 2d22efcead
2 changed files with 50 additions and 12 deletions

View File

@ -54,9 +54,31 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
class TaskTypes(Enum): class TaskTypes(Enum):
def __str__(self): def __str__(self):
return str(self.value) return str(self.value)
def __eq__(self, other):
return str(self) == str(other)
training = 'training' training = 'training'
testing = 'testing' testing = 'testing'
class TaskStatusEnum(Enum):
def __str__(self):
return str(self.value)
def __eq__(self, other):
return str(self) == str(other)
created = "created"
queued = "queued"
in_progress = "in_progress"
stopped = "stopped"
published = "published"
publishing = "publishing"
closed = "closed"
failed = "failed"
completed = "completed"
unknown = "unknown"
def __init__(self, session=None, task_id=None, log=None, project_name=None, def __init__(self, session=None, task_id=None, log=None, project_name=None,
task_name=None, task_type=TaskTypes.training, log_to_backend=True, task_name=None, task_type=TaskTypes.training, log_to_backend=True,
raise_on_validation_errors=True, force_create=False): raise_on_validation_errors=True, force_create=False):
@ -345,11 +367,12 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
@property @property
def status(self): def status(self):
""" """
The Task's status. To keep the Task updated, Trains reloads the Task information when this value The Task's status. To keep the Task updated.
is accessed. Trains reloads the Task status information only, when this value is accessed.
return str: TaskStatusEnum status
""" """
self.reload() return self.get_status()
return self._status
@property @property
def _status(self): def _status(self):
@ -826,6 +849,9 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
self._set_task_property("tags", tags) self._set_task_property("tags", tags)
self._edit(tags=self.data.tags) self._edit(tags=self.data.tags)
def get_system_tags(self):
return self._get_task_property("system_tags" if Session.check_min_api_version('2.3') else "tags")
def set_tags(self, tags): def set_tags(self, tags):
assert isinstance(tags, (list, tuple)) assert isinstance(tags, (list, tuple))
if not Session.check_min_api_version('2.3'): if not Session.check_min_api_version('2.3'):
@ -889,6 +915,20 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
""" """
return self._initial_iteration_offset return self._initial_iteration_offset
def get_status(self):
"""
Return The task status without refreshing the entire Task object object (only the status property)
TaskStatusEnum: ["created", "in_progress", "stopped", "closed", "failed", "completed",
"queued", "published", "publishing", "unknown"]
:return str: Task status as string (TaskStatusEnum)
"""
status = self._get_status()[0]
if self._data:
self._data.status = status
return str(status)
def _get_models(self, model_type='output'): def _get_models(self, model_type='output'):
model_type = model_type.lower().strip() model_type = model_type.lower().strip()
assert model_type == 'output' or model_type == 'input' assert model_type == 'output' or model_type == 'input'
@ -1029,7 +1069,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
# clear all artifacts # clear all artifacts
execution['artifacts'] = [e for e in execution['artifacts'] if e.get('mode') == 'input'] execution['artifacts'] = [e for e in execution['artifacts'] if e.get('mode') == 'input']
if not tags and task.tags: if not hasattr(task, 'system_tags') and not tags and task.tags:
tags = [t for t in task.tags if t != cls._development_tag] tags = [t for t in task.tags if t != cls._development_tag]
req = tasks.CreateRequest( req = tasks.CreateRequest(
@ -1037,7 +1077,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
type=task.type, type=task.type,
input=task.input if hasattr(task, 'input') else {'view': {}}, input=task.input if hasattr(task, 'input') else {'view': {}},
tags=tags, tags=tags,
comment=comment or task.comment, comment=comment if comment is not None else task.comment,
parent=parent, parent=parent,
project=project if project else task.project, project=project if project else task.project,
output_dest=output_dest, output_dest=output_dest,

View File

@ -370,12 +370,10 @@ class Task(_Task):
# was not specified, keep legacy default value of TaskTypes.training # was not specified, keep legacy default value of TaskTypes.training
task_type = cls.TaskTypes.training task_type = cls.TaskTypes.training
elif isinstance(task_type, six.string_types): elif isinstance(task_type, six.string_types):
task_type_lookup = {'testing': cls.TaskTypes.testing, 'inference': cls.TaskTypes.testing, if task_type not in Task.TaskTypes.__members__:
'train': cls.TaskTypes.training, 'training': cls.TaskTypes.training,} raise ValueError("Task type '{}' not supported, options are: {}".format(
if task_type not in task_type_lookup: task_type, Task.TaskTypes.__members__.keys()))
raise ValueError("Task type '{}' not supported, options are: {}".format(task_type, task_type = Task.TaskTypes.__members__[task_type]
list(task_type_lookup.keys())))
task_type = task_type_lookup[task_type]
try: try:
if not running_remotely(): if not running_remotely():