Mirror of https://github.com/clearml/clearml (synced 2025-06-26 18:16:07 +00:00)
	Fix retry_on_failure callback does nothing when specified on PipelineController.add_step
This commit is contained in:
parent a0e19d833d
commit d497869a1f
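For reference, the per-step callback that this commit makes effective has the signature recorded in `_retries_callbacks` in the diff below: `Callable[[PipelineController, PipelineController.Node, int], bool]`, returning True to relaunch the failed step. Before this change, a callback passed via `add_step(retry_on_failure=...)` was stored but never consulted, because the monitoring loop only checked the pipeline-level `_retry_on_failure_callback`. A minimal illustrative sketch of wiring up such a callback (project, task, and step names are hypothetical):

    from clearml import PipelineController

    def retry_step(pipeline, node, retries):
        # Same contract as _call_retries_callback below: return True to relaunch
        # the failed step, False to give up. Here we allow up to two retries.
        return retries < 2

    pipe = PipelineController(name="pipeline demo", project="examples", version="1.0.0")
    pipe.add_step(
        name="stage_process",            # hypothetical step name
        base_task_project="examples",    # hypothetical project holding the base task
        base_task_name="step 1",         # hypothetical base task
        retry_on_failure=retry_step,     # the per-step callback this commit makes effective
    )
    pipe.start()                         # enqueue the controller (defaults to the services queue)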
@@ -49,6 +49,7 @@ class PipelineController(object):
     _pipeline_section = 'pipeline'
     _pipeline_step_ref = 'pipeline'
     _runtime_property_hash = '_pipeline_hash'
+    _relaunch_status_message = "Relaunching pipeline step..."
     _reserved_pipeline_names = (_pipeline_step_ref, )
     _task_project_lookup = {}
     _clearml_job_class = ClearmlJob
@@ -61,6 +62,7 @@ class PipelineController(object):
     _add_to_evaluated_return_values = {}  # TID: bool
     _retries = {}  # Node.name: int
     _retries_callbacks = {}  # Node.name: Callable[[PipelineController, PipelineController.Node, int], bool]  # noqa
+    _final_failure = {}  # Node.name: bool
 
     valid_job_status = ["failed", "cached", "completed", "aborted", "queued", "running", "skipped", "pending"]
 
@@ -1657,6 +1659,23 @@ class PipelineController(object):
         # return False if we did not cover all the nodes
         return not bool(set(self._nodes.keys()) - visited)
 
+    def _relaunch_node(self, node):
+        if not node.job:
+            getLogger("clearml.automation.controller").warning(
+                "Could not relaunch node {} (job object is missing)".format(node.name)
+            )
+            return
+        self._retries[node.name] = self._retries.get(node.name, 0) + 1
+        getLogger("clearml.automation.controller").warning(
+            "Node '{}' failed. Retrying... (this is retry number {})".format(node.name, self._retries[node.name])
+        )
+        node.job.task.mark_stopped(force=True, status_message=self._relaunch_status_message)
+        node.job.task.set_progress(0)
+        node.job.task.get_logger().report_text(
+            "\nNode '{}' failed. Retrying... (this is retry number {})\n".format(node.name, self._retries[node.name])
+        )
+        node.job.launch(queue_name=node.queue or self._default_execution_queue)
+
     def _launch_node(self, node):
         # type: (PipelineController.Node) -> ()
         """
@@ -1909,6 +1928,37 @@ class PipelineController(object):
 
         return table_values
 
+    def _call_retries_callback(self, node):
+        # if this functions returns True, we should relaunch the node
+        # if False, don't relaunch
+        if node.name not in self._retries_callbacks:
+            return False
+        try:
+            return self._retries_callbacks[node.name](self, node, self._retries.get(node.name, 0))
+        except Exception as e:
+            getLogger("clearml.automation.controller").warning(
+                "Failed calling the retry callback for node '{}'. Error is '{}'".format(node.name, e)
+            )
+            return False
+
+    @classmethod
+    def _wait_for_node(cls, node):
+        pool_period = 5.0 if cls._debug_execute_step_process else 20.0
+        while True:
+            node.job.wait(pool_period=pool_period, aborted_nonresponsive_as_running=True)
+            job_status = str(node.job.status(force=True))
+            if (
+                (
+                    job_status == str(Task.TaskStatusEnum.stopped)
+                    and node.job.status_message() == cls._relaunch_status_message
+                )
+                or (job_status == str(Task.TaskStatusEnum.failed) and not cls._final_failure.get(node.name))
+                or not node.job.is_stopped()
+            ):
+                sleep(pool_period)
+            else:
+                break
+
     @classmethod
     def _get_node_color(cls, node):
         # type (self.Mode) -> str
@@ -2051,15 +2101,12 @@ class PipelineController(object):
                     continue
                 if node.job.is_stopped(aborted_nonresponsive_as_running=True):
                     node_failed = node.job.is_failed()
-                    if node_failed and \
-                            self._retry_on_failure_callback(self, node, self._retries.get(node.name, 0)):
-
-                        self._task.get_logger().report_text("Node '{}' failed. Retrying...".format(node.name))
-                        node.job = None
-                        node.executed = None
-                        self._running_nodes.remove(j)
-                        self._retries[node.name] = self._retries.get(node.name, 0) + 1
-                        continue
+                    if node_failed:
+                        if self._call_retries_callback(node):
+                            self._relaunch_node(node)
+                            continue
+                        else:
+                            self._final_failure[node.name] = True
 
                     completed_jobs.append(j)
                     node.executed = node.job.task_id() if not node_failed else False
@@ -2718,6 +2765,12 @@ class PipelineDecorator(PipelineController):
                     continue
                 if node.job.is_stopped(aborted_nonresponsive_as_running=True):
                     node_failed = node.job.is_failed()
+                    if node_failed:
+                        if self._call_retries_callback(node):
+                            self._relaunch_node(node)
+                            continue
+                        else:
+                            self._final_failure[node.name] = True
                     completed_jobs.append(j)
                     node.executed = node.job.task_id() if not node_failed else False
                     if j in launched_nodes:
@@ -3218,7 +3271,7 @@ class PipelineDecorator(PipelineController):
 
                 # The actual launch is a bit slow, we run it in the background
                 launch_thread = Thread(
-                    target=cls._component_launch_with_failover,
+                    target=cls._component_launch,
                     args=(_node_name, _node, kwargs_artifacts, kwargs, current_thread().ident))
 
                 def results_reference(return_name):
@@ -3228,6 +3281,8 @@ class PipelineDecorator(PipelineController):
                             launch_thread.join()
                         except:  # noqa
                             pass
+
+                    cls._wait_for_node(_node)
                     if not _node.job:
                         if not _node.executed:
                             raise ValueError("Job was not created and is also not cached/executed")
@@ -3247,6 +3302,8 @@ class PipelineDecorator(PipelineController):
                             launch_thread.join()
                         except:  # noqa
                             pass
+
+                    cls._wait_for_node(_node)
                     if (_node.job.is_failed() and not _node.continue_on_fail) or _node.job.is_aborted():
                         raise ValueError(
                             'Pipeline step "{}", Task ID={} failed'.format(_node.name, _node.job.task_id())
@@ -3501,8 +3558,7 @@ class PipelineDecorator(PipelineController):
                     for node in list(a_pipeline._nodes.values()):
                         if node.executed or not node.job or node.job.is_stopped(aborted_nonresponsive_as_running=True):
                             continue
-                        node.job.wait(pool_period=5. if cls._debug_execute_step_process else 20.,
-                                      aborted_nonresponsive_as_running=True)
+                        cls._wait_for_node(node)
                         waited = True
                 # store the pipeline result of we have any:
                 if return_value and pipeline_result is not None:
@@ -3589,28 +3645,6 @@ class PipelineDecorator(PipelineController):
         """
         return cls._wait_for_multi_pipelines()
 
-    @classmethod
-    def _component_launch_with_failover(cls, node_name, node, kwargs_artifacts, kwargs, tid):
-        cls._component_launch(node_name, node, kwargs_artifacts, kwargs, tid)
-        while True:
-            if node.job:
-                node.job.wait(pool_period=5. if cls._debug_execute_step_process else 20.,
-                              aborted_nonresponsive_as_running=True)
-            else:
-                sleep(2)
-                continue
-            if node.job.is_failed() and node_name in cls._retries_callbacks and \
-                    cls._retries_callbacks[node_name](cls._singleton, node, cls._retries.get(node_name, 0)):
-
-                if cls._singleton and cls._singleton._task:
-                    cls._singleton._task.get_logger().report_text("Node '{}' failed. Retrying...".format(node_name))
-                node.job = None
-                node.executed = None
-                cls._retries[node_name] = cls._retries.get(node_name, 0) + 1
-                cls._component_launch(node_name, node, kwargs_artifacts, kwargs, tid)
-            else:
-                break
-
     @classmethod
     def _component_launch(cls, node_name, node, kwargs_artifacts, kwargs, tid):
         _node_name = node_name
@@ -164,6 +164,16 @@ class BaseJob(object):
         self._last_status_ts = time()
         return self._last_status
 
+    def status_message(self):
+        # type: () -> str
+        """
+        Gets the status message of the task. Note that the message is updated only after `BaseJob.status()`
+        is called
+
+        :return: The status message of the corresponding task as a string
+        """
+        return str(self.task.data.status_message)
+
     def wait(self, timeout=None, pool_period=30., aborted_nonresponsive_as_running=False):
         # type: (Optional[float], float, bool) -> bool
         """