Fix task status comparison breaks when using a string

This commit is contained in:
allegroai 2023-06-21 08:54:01 +03:00
parent ab3c148de2
commit 8524874998
2 changed files with 14 additions and 14 deletions

View File

@ -769,7 +769,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
def publish(self, ignore_errors=True):
# type: (bool) -> ()
""" The signal that this task will be published """
if str(self.status) not in (str(tasks.TaskStatusEnum.stopped), str(tasks.TaskStatusEnum.completed)):
if self.status not in (self.TaskStatusEnum.stopped, self.TaskStatusEnum.completed):
raise ValueError("Can't publish, Task is not stopped")
resp = self.send(tasks.PublishRequest(self.id), ignore_errors=ignore_errors)
assert isinstance(resp.response, tasks.PublishResponse)
@ -809,7 +809,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
try:
res = self.send(tasks.GetByIdRequest(self.task_id))
task = res.response.task
if task.status == Task.TaskStatusEnum.published:
if task.status == self.TaskStatusEnum.published:
if raise_on_error:
raise self.DeleteError("Cannot delete published task {}".format(self.task_id))
self.log.error("Cannot delete published task {}".format(self.task_id))
@ -2425,7 +2425,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
:return: list of tuples (status, status_message, task_id)
"""
if cls._offline_mode:
return [(tasks.TaskStatusEnum.created, "offline", i) for i in ids]
return [(cls.TaskStatusEnum.created, "offline", i) for i in ids]
# noinspection PyBroadException
try:
@ -2564,7 +2564,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
# Since we ae using forced update, make sure he task status is valid
status = self._data.status if self._data and self._reload_skip_flag else self.data.status
if not kwargs.pop("force", False) and \
status not in (tasks.TaskStatusEnum.created, tasks.TaskStatusEnum.in_progress):
status not in (self.TaskStatusEnum.created, self.TaskStatusEnum.in_progress):
# the exception being name/comment that we can always change.
if kwargs and all(
k in ("name", "project", "comment", "tags", "system_tags", "runtime") for k in kwargs.keys()
@ -3017,7 +3017,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
def _get_task_status(cls, task_id):
# type: (str) -> (Optional[str], Optional[str])
if cls._offline_mode:
return tasks.TaskStatusEnum.created, 'offline'
return cls.TaskStatusEnum.created, 'offline'
# noinspection PyBroadException
try:

View File

@ -3449,8 +3449,8 @@ class Task(_Task):
task_tags = task.data.system_tags if hasattr(task.data, 'system_tags') else task.data.tags
task_artifacts = task.data.execution.artifacts \
if hasattr(task.data.execution, 'artifacts') else None
if ((str(task._status) in (
str(tasks.TaskStatusEnum.published), str(tasks.TaskStatusEnum.closed)))
if ((task._status in (
cls.TaskStatusEnum.published, cls.TaskStatusEnum.closed))
or task.output_models_id or (cls.archived_tag in task_tags)
or (cls._development_tag not in task_tags)
or task_artifacts):
@ -4609,15 +4609,15 @@ class Task(_Task):
return False
stopped_statuses = (
str(tasks.TaskStatusEnum.stopped),
str(tasks.TaskStatusEnum.published),
str(tasks.TaskStatusEnum.publishing),
str(tasks.TaskStatusEnum.closed),
str(tasks.TaskStatusEnum.failed),
str(tasks.TaskStatusEnum.completed),
cls.TaskStatusEnum.stopped,
cls.TaskStatusEnum.published,
cls.TaskStatusEnum.publishing,
cls.TaskStatusEnum.closed,
cls.TaskStatusEnum.failed,
cls.TaskStatusEnum.completed,
)
if str(task.status) not in stopped_statuses:
if task.status not in stopped_statuses:
cls._send(
cls._get_default_session(),
tasks.StoppedRequest(