Fix retry_on_failure callback does nothing when specified on PipelineController.add_step

This commit is contained in:
allegroai 2022-10-14 10:14:29 +03:00
parent a0e19d833d
commit d497869a1f
2 changed files with 78 additions and 34 deletions

View File

@ -49,6 +49,7 @@ class PipelineController(object):
_pipeline_section = 'pipeline'
_pipeline_step_ref = 'pipeline'
_runtime_property_hash = '_pipeline_hash'
_relaunch_status_message = "Relaunching pipeline step..."
_reserved_pipeline_names = (_pipeline_step_ref, )
_task_project_lookup = {}
_clearml_job_class = ClearmlJob
@ -61,6 +62,7 @@ class PipelineController(object):
_add_to_evaluated_return_values = {} # TID: bool
_retries = {} # Node.name: int
_retries_callbacks = {} # Node.name: Callable[[PipelineController, PipelineController.Node, int], bool] # noqa
_final_failure = {} # Node.name: bool
valid_job_status = ["failed", "cached", "completed", "aborted", "queued", "running", "skipped", "pending"]
@ -1657,6 +1659,23 @@ class PipelineController(object):
# return False if we did not cover all the nodes
return not bool(set(self._nodes.keys()) - visited)
def _relaunch_node(self, node):
if not node.job:
getLogger("clearml.automation.controller").warning(
"Could not relaunch node {} (job object is missing)".format(node.name)
)
return
self._retries[node.name] = self._retries.get(node.name, 0) + 1
getLogger("clearml.automation.controller").warning(
"Node '{}' failed. Retrying... (this is retry number {})".format(node.name, self._retries[node.name])
)
node.job.task.mark_stopped(force=True, status_message=self._relaunch_status_message)
node.job.task.set_progress(0)
node.job.task.get_logger().report_text(
"\nNode '{}' failed. Retrying... (this is retry number {})\n".format(node.name, self._retries[node.name])
)
node.job.launch(queue_name=node.queue or self._default_execution_queue)
def _launch_node(self, node):
# type: (PipelineController.Node) -> ()
"""
@ -1909,6 +1928,37 @@ class PipelineController(object):
return table_values
def _call_retries_callback(self, node):
# if this functions returns True, we should relaunch the node
# if False, don't relaunch
if node.name not in self._retries_callbacks:
return False
try:
return self._retries_callbacks[node.name](self, node, self._retries.get(node.name, 0))
except Exception as e:
getLogger("clearml.automation.controller").warning(
"Failed calling the retry callback for node '{}'. Error is '{}'".format(node.name, e)
)
return False
@classmethod
def _wait_for_node(cls, node):
pool_period = 5.0 if cls._debug_execute_step_process else 20.0
while True:
node.job.wait(pool_period=pool_period, aborted_nonresponsive_as_running=True)
job_status = str(node.job.status(force=True))
if (
(
job_status == str(Task.TaskStatusEnum.stopped)
and node.job.status_message() == cls._relaunch_status_message
)
or (job_status == str(Task.TaskStatusEnum.failed) and not cls._final_failure.get(node.name))
or not node.job.is_stopped()
):
sleep(pool_period)
else:
break
@classmethod
def _get_node_color(cls, node):
# type (self.Mode) -> str
@ -2051,15 +2101,12 @@ class PipelineController(object):
continue
if node.job.is_stopped(aborted_nonresponsive_as_running=True):
node_failed = node.job.is_failed()
if node_failed and \
self._retry_on_failure_callback(self, node, self._retries.get(node.name, 0)):
self._task.get_logger().report_text("Node '{}' failed. Retrying...".format(node.name))
node.job = None
node.executed = None
self._running_nodes.remove(j)
self._retries[node.name] = self._retries.get(node.name, 0) + 1
continue
if node_failed:
if self._call_retries_callback(node):
self._relaunch_node(node)
continue
else:
self._final_failure[node.name] = True
completed_jobs.append(j)
node.executed = node.job.task_id() if not node_failed else False
@ -2718,6 +2765,12 @@ class PipelineDecorator(PipelineController):
continue
if node.job.is_stopped(aborted_nonresponsive_as_running=True):
node_failed = node.job.is_failed()
if node_failed:
if self._call_retries_callback(node):
self._relaunch_node(node)
continue
else:
self._final_failure[node.name] = True
completed_jobs.append(j)
node.executed = node.job.task_id() if not node_failed else False
if j in launched_nodes:
@ -3218,7 +3271,7 @@ class PipelineDecorator(PipelineController):
# The actual launch is a bit slow, we run it in the background
launch_thread = Thread(
target=cls._component_launch_with_failover,
target=cls._component_launch,
args=(_node_name, _node, kwargs_artifacts, kwargs, current_thread().ident))
def results_reference(return_name):
@ -3228,6 +3281,8 @@ class PipelineDecorator(PipelineController):
launch_thread.join()
except: # noqa
pass
cls._wait_for_node(_node)
if not _node.job:
if not _node.executed:
raise ValueError("Job was not created and is also not cached/executed")
@ -3247,6 +3302,8 @@ class PipelineDecorator(PipelineController):
launch_thread.join()
except: # noqa
pass
cls._wait_for_node(_node)
if (_node.job.is_failed() and not _node.continue_on_fail) or _node.job.is_aborted():
raise ValueError(
'Pipeline step "{}", Task ID={} failed'.format(_node.name, _node.job.task_id())
@ -3501,8 +3558,7 @@ class PipelineDecorator(PipelineController):
for node in list(a_pipeline._nodes.values()):
if node.executed or not node.job or node.job.is_stopped(aborted_nonresponsive_as_running=True):
continue
node.job.wait(pool_period=5. if cls._debug_execute_step_process else 20.,
aborted_nonresponsive_as_running=True)
cls._wait_for_node(node)
waited = True
# store the pipeline result of we have any:
if return_value and pipeline_result is not None:
@ -3589,28 +3645,6 @@ class PipelineDecorator(PipelineController):
"""
return cls._wait_for_multi_pipelines()
@classmethod
def _component_launch_with_failover(cls, node_name, node, kwargs_artifacts, kwargs, tid):
cls._component_launch(node_name, node, kwargs_artifacts, kwargs, tid)
while True:
if node.job:
node.job.wait(pool_period=5. if cls._debug_execute_step_process else 20.,
aborted_nonresponsive_as_running=True)
else:
sleep(2)
continue
if node.job.is_failed() and node_name in cls._retries_callbacks and \
cls._retries_callbacks[node_name](cls._singleton, node, cls._retries.get(node_name, 0)):
if cls._singleton and cls._singleton._task:
cls._singleton._task.get_logger().report_text("Node '{}' failed. Retrying...".format(node_name))
node.job = None
node.executed = None
cls._retries[node_name] = cls._retries.get(node_name, 0) + 1
cls._component_launch(node_name, node, kwargs_artifacts, kwargs, tid)
else:
break
@classmethod
def _component_launch(cls, node_name, node, kwargs_artifacts, kwargs, tid):
_node_name = node_name

View File

@ -164,6 +164,16 @@ class BaseJob(object):
self._last_status_ts = time()
return self._last_status
def status_message(self):
# type: () -> str
"""
Gets the status message of the task. Note that the message is updated only after `BaseJob.status()`
is called
:return: The status message of the corresponding task as a string
"""
return str(self.task.data.status_message)
def wait(self, timeout=None, pool_period=30., aborted_nonresponsive_as_running=False):
# type: (Optional[float], float, bool) -> bool
"""