Fix TaskTypes/TaskStatusEnum Enum comparison

This commit is contained in:
allegroai 2020-05-22 10:30:06 +03:00
parent 63ca84a84f
commit 2d22efcead
2 changed files with 50 additions and 12 deletions

View File

@ -54,9 +54,31 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
class TaskTypes(Enum):
def __str__(self):
return str(self.value)
def __eq__(self, other):
return str(self) == str(other)
training = 'training'
testing = 'testing'
class TaskStatusEnum(Enum):
def __str__(self):
return str(self.value)
def __eq__(self, other):
return str(self) == str(other)
created = "created"
queued = "queued"
in_progress = "in_progress"
stopped = "stopped"
published = "published"
publishing = "publishing"
closed = "closed"
failed = "failed"
completed = "completed"
unknown = "unknown"
def __init__(self, session=None, task_id=None, log=None, project_name=None,
task_name=None, task_type=TaskTypes.training, log_to_backend=True,
raise_on_validation_errors=True, force_create=False):
@ -345,11 +367,12 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
@property
def status(self):
"""
The Task's status. To keep the Task updated, Trains reloads the Task information when this value
is accessed.
The Task's status. To keep the Task updated.
Trains reloads the Task status information only, when this value is accessed.
return str: TaskStatusEnum status
"""
self.reload()
return self._status
return self.get_status()
@property
def _status(self):
@ -826,6 +849,9 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
self._set_task_property("tags", tags)
self._edit(tags=self.data.tags)
def get_system_tags(self):
return self._get_task_property("system_tags" if Session.check_min_api_version('2.3') else "tags")
def set_tags(self, tags):
assert isinstance(tags, (list, tuple))
if not Session.check_min_api_version('2.3'):
@ -889,6 +915,20 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
"""
return self._initial_iteration_offset
def get_status(self):
"""
Return The task status without refreshing the entire Task object object (only the status property)
TaskStatusEnum: ["created", "in_progress", "stopped", "closed", "failed", "completed",
"queued", "published", "publishing", "unknown"]
:return str: Task status as string (TaskStatusEnum)
"""
status = self._get_status()[0]
if self._data:
self._data.status = status
return str(status)
def _get_models(self, model_type='output'):
model_type = model_type.lower().strip()
assert model_type == 'output' or model_type == 'input'
@ -1029,7 +1069,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
# clear all artifacts
execution['artifacts'] = [e for e in execution['artifacts'] if e.get('mode') == 'input']
if not tags and task.tags:
if not hasattr(task, 'system_tags') and not tags and task.tags:
tags = [t for t in task.tags if t != cls._development_tag]
req = tasks.CreateRequest(
@ -1037,7 +1077,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
type=task.type,
input=task.input if hasattr(task, 'input') else {'view': {}},
tags=tags,
comment=comment or task.comment,
comment=comment if comment is not None else task.comment,
parent=parent,
project=project if project else task.project,
output_dest=output_dest,

View File

@ -370,12 +370,10 @@ class Task(_Task):
# was not specified, keep legacy default value of TaskTypes.training
task_type = cls.TaskTypes.training
elif isinstance(task_type, six.string_types):
task_type_lookup = {'testing': cls.TaskTypes.testing, 'inference': cls.TaskTypes.testing,
'train': cls.TaskTypes.training, 'training': cls.TaskTypes.training,}
if task_type not in task_type_lookup:
raise ValueError("Task type '{}' not supported, options are: {}".format(task_type,
list(task_type_lookup.keys())))
task_type = task_type_lookup[task_type]
if task_type not in Task.TaskTypes.__members__:
raise ValueError("Task type '{}' not supported, options are: {}".format(
task_type, Task.TaskTypes.__members__.keys()))
task_type = Task.TaskTypes.__members__[task_type]
try:
if not running_remotely():