diff --git a/trains/task.py b/trains/task.py index 1bf606b7..c02521c0 100644 --- a/trains/task.py +++ b/trains/task.py @@ -13,7 +13,7 @@ try: except ImportError: from collections import Callable, Sequence as CollectionsSequence -from typing import Optional, Union, Mapping, Sequence, Any, Dict, TYPE_CHECKING +from typing import Optional, Union, Mapping, Sequence, Any, Dict, TYPE_CHECKING, Iterable import psutil import six @@ -2484,3 +2484,22 @@ class Task(_Task): exit(0) return + + def wait_for_status(self, status=(tasks.TaskStatusEnum.completed), + raise_on_status=(tasks.TaskStatusEnum.failed), + check_interval_sec=60): + # type: (Iterable[tasks.TaskStatusEnum], Iterable[tasks.TaskStatusEnum], int) -> () + """ + Wait for a task to reach a defined status. + + :param status: Status to wait for. Defaults to ('completed') + :param raise_on_status: Raise RuntimeError if the status of the tasks matches one of these values. + Defaults to ('failed'). + :param check_interval_sec: Interval in seconds between two checks. Defaults to 60 seconds. + :raise: RuntimeError if the status is one of {raise_on_status}. + """ + while self.status not in status or self.status not in raise_on_status: + time.sleep(check_interval_sec) + + if self.status in raise_on_status: + raise RuntimeError("Task {} has status: {}.".format(self.task_id, self.status))