Fix task status comparison breaks when using a string

This commit is contained in:
allegroai 2023-06-21 08:54:01 +03:00
parent ab3c148de2
commit 8524874998
2 changed files with 14 additions and 14 deletions

View File

@ -769,7 +769,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
def publish(self, ignore_errors=True): def publish(self, ignore_errors=True):
# type: (bool) -> () # type: (bool) -> ()
""" The signal that this task will be published """ """ The signal that this task will be published """
if str(self.status) not in (str(tasks.TaskStatusEnum.stopped), str(tasks.TaskStatusEnum.completed)): if self.status not in (self.TaskStatusEnum.stopped, self.TaskStatusEnum.completed):
raise ValueError("Can't publish, Task is not stopped") raise ValueError("Can't publish, Task is not stopped")
resp = self.send(tasks.PublishRequest(self.id), ignore_errors=ignore_errors) resp = self.send(tasks.PublishRequest(self.id), ignore_errors=ignore_errors)
assert isinstance(resp.response, tasks.PublishResponse) assert isinstance(resp.response, tasks.PublishResponse)
@ -809,7 +809,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
try: try:
res = self.send(tasks.GetByIdRequest(self.task_id)) res = self.send(tasks.GetByIdRequest(self.task_id))
task = res.response.task task = res.response.task
if task.status == Task.TaskStatusEnum.published: if task.status == self.TaskStatusEnum.published:
if raise_on_error: if raise_on_error:
raise self.DeleteError("Cannot delete published task {}".format(self.task_id)) raise self.DeleteError("Cannot delete published task {}".format(self.task_id))
self.log.error("Cannot delete published task {}".format(self.task_id)) self.log.error("Cannot delete published task {}".format(self.task_id))
@ -2425,7 +2425,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
:return: list of tuples (status, status_message, task_id) :return: list of tuples (status, status_message, task_id)
""" """
if cls._offline_mode: if cls._offline_mode:
return [(tasks.TaskStatusEnum.created, "offline", i) for i in ids] return [(cls.TaskStatusEnum.created, "offline", i) for i in ids]
# noinspection PyBroadException # noinspection PyBroadException
try: try:
@ -2564,7 +2564,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
# Since we ae using forced update, make sure he task status is valid # Since we ae using forced update, make sure he task status is valid
status = self._data.status if self._data and self._reload_skip_flag else self.data.status status = self._data.status if self._data and self._reload_skip_flag else self.data.status
if not kwargs.pop("force", False) and \ if not kwargs.pop("force", False) and \
status not in (tasks.TaskStatusEnum.created, tasks.TaskStatusEnum.in_progress): status not in (self.TaskStatusEnum.created, self.TaskStatusEnum.in_progress):
# the exception being name/comment that we can always change. # the exception being name/comment that we can always change.
if kwargs and all( if kwargs and all(
k in ("name", "project", "comment", "tags", "system_tags", "runtime") for k in kwargs.keys() k in ("name", "project", "comment", "tags", "system_tags", "runtime") for k in kwargs.keys()
@ -3017,7 +3017,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
def _get_task_status(cls, task_id): def _get_task_status(cls, task_id):
# type: (str) -> (Optional[str], Optional[str]) # type: (str) -> (Optional[str], Optional[str])
if cls._offline_mode: if cls._offline_mode:
return tasks.TaskStatusEnum.created, 'offline' return cls.TaskStatusEnum.created, 'offline'
# noinspection PyBroadException # noinspection PyBroadException
try: try:

View File

@ -3449,8 +3449,8 @@ class Task(_Task):
task_tags = task.data.system_tags if hasattr(task.data, 'system_tags') else task.data.tags task_tags = task.data.system_tags if hasattr(task.data, 'system_tags') else task.data.tags
task_artifacts = task.data.execution.artifacts \ task_artifacts = task.data.execution.artifacts \
if hasattr(task.data.execution, 'artifacts') else None if hasattr(task.data.execution, 'artifacts') else None
if ((str(task._status) in ( if ((task._status in (
str(tasks.TaskStatusEnum.published), str(tasks.TaskStatusEnum.closed))) cls.TaskStatusEnum.published, cls.TaskStatusEnum.closed))
or task.output_models_id or (cls.archived_tag in task_tags) or task.output_models_id or (cls.archived_tag in task_tags)
or (cls._development_tag not in task_tags) or (cls._development_tag not in task_tags)
or task_artifacts): or task_artifacts):
@ -4609,15 +4609,15 @@ class Task(_Task):
return False return False
stopped_statuses = ( stopped_statuses = (
str(tasks.TaskStatusEnum.stopped), cls.TaskStatusEnum.stopped,
str(tasks.TaskStatusEnum.published), cls.TaskStatusEnum.published,
str(tasks.TaskStatusEnum.publishing), cls.TaskStatusEnum.publishing,
str(tasks.TaskStatusEnum.closed), cls.TaskStatusEnum.closed,
str(tasks.TaskStatusEnum.failed), cls.TaskStatusEnum.failed,
str(tasks.TaskStatusEnum.completed), cls.TaskStatusEnum.completed,
) )
if str(task.status) not in stopped_statuses: if task.status not in stopped_statuses:
cls._send( cls._send(
cls._get_default_session(), cls._get_default_session(),
tasks.StoppedRequest( tasks.StoppedRequest(