From d497869a1fc35bccfe4f17ec61e1a6687f4f7174 Mon Sep 17 00:00:00 2001
From: allegroai <>
Date: Fri, 14 Oct 2022 10:14:29 +0300
Subject: [PATCH] Fix retry_on_failure callback does nothing when specified on
 PipelineController.add_step

---
 clearml/automation/controller.py | 102 ++++++++++++++++++++-----------
 clearml/automation/job.py        |  10 +++
 2 files changed, 78 insertions(+), 34 deletions(-)

diff --git a/clearml/automation/controller.py b/clearml/automation/controller.py
index cdb221e0..762de15c 100644
--- a/clearml/automation/controller.py
+++ b/clearml/automation/controller.py
@@ -49,6 +49,7 @@ class PipelineController(object):
     _pipeline_section = 'pipeline'
     _pipeline_step_ref = 'pipeline'
     _runtime_property_hash = '_pipeline_hash'
+    _relaunch_status_message = "Relaunching pipeline step..."
     _reserved_pipeline_names = (_pipeline_step_ref, )
     _task_project_lookup = {}
     _clearml_job_class = ClearmlJob
@@ -61,6 +62,7 @@ class PipelineController(object):
     _add_to_evaluated_return_values = {}  # TID: bool
     _retries = {}  # Node.name: int
     _retries_callbacks = {}  # Node.name: Callable[[PipelineController, PipelineController.Node, int], bool]  # noqa
+    _final_failure = {}  # Node.name: bool
 
     valid_job_status = ["failed", "cached", "completed", "aborted", "queued", "running", "skipped", "pending"]
 
@@ -1657,6 +1659,23 @@ class PipelineController(object):
         # return False if we did not cover all the nodes
         return not bool(set(self._nodes.keys()) - visited)
 
+    def _relaunch_node(self, node):
+        if not node.job:
+            getLogger("clearml.automation.controller").warning(
+                "Could not relaunch node {} (job object is missing)".format(node.name)
+            )
+            return
+        self._retries[node.name] = self._retries.get(node.name, 0) + 1
+        getLogger("clearml.automation.controller").warning(
+            "Node '{}' failed. Retrying... (this is retry number {})".format(node.name, self._retries[node.name])
+        )
+        node.job.task.mark_stopped(force=True, status_message=self._relaunch_status_message)
+        node.job.task.set_progress(0)
+        node.job.task.get_logger().report_text(
+            "\nNode '{}' failed. Retrying... (this is retry number {})\n".format(node.name, self._retries[node.name])
+        )
+        node.job.launch(queue_name=node.queue or self._default_execution_queue)
+
     def _launch_node(self, node):
         # type: (PipelineController.Node) -> ()
         """
@@ -1909,6 +1928,37 @@ class PipelineController(object):
 
         return table_values
 
+    def _call_retries_callback(self, node):
+        # if this functions returns True, we should relaunch the node
+        # if False, don't relaunch
+        if node.name not in self._retries_callbacks:
+            return False
+        try:
+            return self._retries_callbacks[node.name](self, node, self._retries.get(node.name, 0))
+        except Exception as e:
+            getLogger("clearml.automation.controller").warning(
+                "Failed calling the retry callback for node '{}'. Error is '{}'".format(node.name, e)
+            )
+            return False
+
+    @classmethod
+    def _wait_for_node(cls, node):
+        pool_period = 5.0 if cls._debug_execute_step_process else 20.0
+        while True:
+            node.job.wait(pool_period=pool_period, aborted_nonresponsive_as_running=True)
+            job_status = str(node.job.status(force=True))
+            if (
+                (
+                    job_status == str(Task.TaskStatusEnum.stopped)
+                    and node.job.status_message() == cls._relaunch_status_message
+                )
+                or (job_status == str(Task.TaskStatusEnum.failed) and not cls._final_failure.get(node.name))
+                or not node.job.is_stopped()
+            ):
+                sleep(pool_period)
+            else:
+                break
+
     @classmethod
     def _get_node_color(cls, node):
         # type (self.Mode) -> str
@@ -2051,15 +2101,12 @@ class PipelineController(object):
                 continue
             if node.job.is_stopped(aborted_nonresponsive_as_running=True):
                 node_failed = node.job.is_failed()
-                if node_failed and \
-                        self._retry_on_failure_callback(self, node, self._retries.get(node.name, 0)):
-
-                    self._task.get_logger().report_text("Node '{}' failed. Retrying...".format(node.name))
-                    node.job = None
-                    node.executed = None
-                    self._running_nodes.remove(j)
-                    self._retries[node.name] = self._retries.get(node.name, 0) + 1
-                    continue
+                if node_failed:
+                    if self._call_retries_callback(node):
+                        self._relaunch_node(node)
+                        continue
+                    else:
+                        self._final_failure[node.name] = True
                 completed_jobs.append(j)
                 node.executed = node.job.task_id() if not node_failed else False
                 if j in launched_nodes:
@@ -2718,6 +2765,12 @@ class PipelineDecorator(PipelineController):
                 continue
             if node.job.is_stopped(aborted_nonresponsive_as_running=True):
                 node_failed = node.job.is_failed()
+                if node_failed:
+                    if self._call_retries_callback(node):
+                        self._relaunch_node(node)
+                        continue
+                    else:
+                        self._final_failure[node.name] = True
                 completed_jobs.append(j)
                 node.executed = node.job.task_id() if not node_failed else False
                 if j in launched_nodes:
@@ -3218,7 +3271,7 @@ class PipelineDecorator(PipelineController):
 
             # The actual launch is a bit slow, we run it in the background
             launch_thread = Thread(
-                target=cls._component_launch_with_failover,
+                target=cls._component_launch,
                 args=(_node_name, _node, kwargs_artifacts, kwargs, current_thread().ident))
 
             def results_reference(return_name):
@@ -3228,6 +3281,8 @@ class PipelineDecorator(PipelineController):
                         launch_thread.join()
                     except:  # noqa
                         pass
+
+                cls._wait_for_node(_node)
                 if not _node.job:
                     if not _node.executed:
                         raise ValueError("Job was not created and is also not cached/executed")
@@ -3247,6 +3302,8 @@ class PipelineDecorator(PipelineController):
                     launch_thread.join()
                 except:  # noqa
                     pass
+
+            cls._wait_for_node(_node)
             if (_node.job.is_failed() and not _node.continue_on_fail) or _node.job.is_aborted():
                 raise ValueError(
                     'Pipeline step "{}", Task ID={} failed'.format(_node.name, _node.job.task_id())
                 )
@@ -3501,8 +3558,7 @@ class PipelineDecorator(PipelineController):
                 for node in list(a_pipeline._nodes.values()):
                     if node.executed or not node.job or node.job.is_stopped(aborted_nonresponsive_as_running=True):
                         continue
-                    node.job.wait(pool_period=5. if cls._debug_execute_step_process else 20.,
-                                  aborted_nonresponsive_as_running=True)
+                    cls._wait_for_node(node)
                     waited = True
             # store the pipeline result of we have any:
             if return_value and pipeline_result is not None:
@@ -3589,28 +3645,6 @@ class PipelineDecorator(PipelineController):
         """
         return cls._wait_for_multi_pipelines()
 
-    @classmethod
-    def _component_launch_with_failover(cls, node_name, node, kwargs_artifacts, kwargs, tid):
-        cls._component_launch(node_name, node, kwargs_artifacts, kwargs, tid)
-        while True:
-            if node.job:
-                node.job.wait(pool_period=5. if cls._debug_execute_step_process else 20.,
-                              aborted_nonresponsive_as_running=True)
-            else:
-                sleep(2)
-                continue
-            if node.job.is_failed() and node_name in cls._retries_callbacks and \
-                    cls._retries_callbacks[node_name](cls._singleton, node, cls._retries.get(node_name, 0)):
-
-                if cls._singleton and cls._singleton._task:
-                    cls._singleton._task.get_logger().report_text("Node '{}' failed. Retrying...".format(node_name))
-                node.job = None
-                node.executed = None
-                cls._retries[node_name] = cls._retries.get(node_name, 0) + 1
-                cls._component_launch(node_name, node, kwargs_artifacts, kwargs, tid)
-            else:
-                break
-
     @classmethod
     def _component_launch(cls, node_name, node, kwargs_artifacts, kwargs, tid):
         _node_name = node_name
diff --git a/clearml/automation/job.py b/clearml/automation/job.py
index 6a2b7e08..9d2829ba 100644
--- a/clearml/automation/job.py
+++ b/clearml/automation/job.py
@@ -164,6 +164,16 @@ class BaseJob(object):
         self._last_status_ts = time()
         return self._last_status
 
+    def status_message(self):
+        # type: () -> str
+        """
+        Gets the status message of the task. Note that the message is updated only after `BaseJob.status()`
+        is called
+
+        :return: The status message of the corresponding task as a string
+        """
+        return str(self.task.data.status_message)
+
     def wait(self, timeout=None, pool_period=30., aborted_nonresponsive_as_running=False):
         # type: (Optional[float], float, bool) -> bool
         """
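
For reference (not part of the patch): a minimal usage sketch of the behavior this change fixes, i.e. a per-step retry callback passed to PipelineController.add_step through retry_on_failure. The callback signature follows the type comment in the diff (Callable[[PipelineController, PipelineController.Node, int], bool]); the project, task, queue, and step names below are placeholders.

    from clearml import PipelineController

    def retry_three_times(pipeline, node, retries):
        # Called after a step fails; return True to relaunch it, False to let it fail permanently.
        return retries < 3

    # Placeholder pipeline/project names
    pipe = PipelineController(name="retry-demo", project="examples", version="1.0.0")
    pipe.add_step(
        name="step_one",                      # placeholder step name
        base_task_project="examples",         # placeholder project
        base_task_name="step one base task",  # placeholder base task
        execution_queue="default",            # placeholder queue
        retry_on_failure=retry_three_times,   # the per-step callback this patch makes effective
    )
    pipe.start()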