From 8a44c2a8b7411d767d164d4fe2007c53c6e19a68 Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Thu, 22 Sep 2022 20:54:19 +0300 Subject: [PATCH] Add callback option for pipeline step retry --- clearml/automation/controller.py | 147 ++++++++++++++++++++++++++++--- 1 file changed, 134 insertions(+), 13 deletions(-) diff --git a/clearml/automation/controller.py b/clearml/automation/controller.py index 928412e9..3fc09835 100644 --- a/clearml/automation/controller.py +++ b/clearml/automation/controller.py @@ -60,6 +60,7 @@ class PipelineController(object): _evaluated_return_values = {} # TID: pipeline_name _add_to_evaluated_return_values = {} # TID: bool _retries_left = {} # Node.name: int + _retries_callbacks = {} # Node.name: Callable[[PipelineController, PipelineController.Node, int], bool] # noqa valid_job_status = ["failed", "cached", "completed", "aborted", "queued", "running", "skipped", "pending"] @@ -133,7 +134,8 @@ class PipelineController(object): auto_version_bump=True, # type: bool abort_on_failure=False, # type: bool add_run_number=True, # type: bool - retry_on_failure=None # type: Optional[int] + retry_on_failure=None, # type: Optional[int] + retry_on_failure_callback=None # type: Optional[Callable[[PipelineController, PipelineController.Node, int], bool]] # noqa ): # type: (...) -> None """ @@ -161,6 +163,21 @@ class PipelineController(object): Example, the second time we launch the pipeline "best pipeline", we rename it to "best pipeline #2" :param retry_on_failure: In case of node failure, retry the node the number of times indicated by this parameter. + :param retry_on_failure_callback: A function called on node failure. Takes as parameters: + the PipelineController instance, the PipelineController.Node that failed and an int + representing the number of retries left for the node that failed + (default 0, can be set with `retry_on_failure` parameter). + The function must return a `bool`: True if the node should be retried and False otherwise. + If True, the node will be requeued and the number of retries left will be decremented by 1. + By default, if this callback is not specified, the function will be retried the number of + times indicated by `retry_on_failure`. + + .. code-block:: py + + def example_retry_on_failure_callback(pipeline, node, retries): + print(node.name, ' failed') + # do something with the pipeline controller + return retries > 0 """ self._nodes = {} self._running_nodes = [] @@ -227,6 +244,7 @@ class PipelineController(object): self._monitored_nodes = {} # type: Dict[str, dict] self._abort_running_steps_on_failure = abort_on_failure self._retry_on_failure = retry_on_failure + self._retry_on_failure_callback = retry_on_failure_callback or PipelineController._default_retry_on_failure_callback # add direct link to the pipeline page if self._pipeline_as_sub_project and self._task: @@ -281,7 +299,8 @@ class PipelineController(object): post_execute_callback=None, # type: Optional[Callable[[PipelineController, PipelineController.Node], None]] # noqa cache_executed_step=False, # type: bool base_task_factory=None, # type: Optional[Callable[[PipelineController.Node], Task]] - retry_on_failure=None # type: int + retry_on_failure=None, # type: int + retry_on_failure_callback=None # type: Optional[Callable[[PipelineController, PipelineController.Node, int], bool]] # noqa ): # type: (...) -> bool """ @@ -402,6 +421,21 @@ class PipelineController(object): provide a Callable function to create the Task (returns Task object) :param retry_on_failure: In case of node failure, retry the node the number of times indicated by this parameter. + :param retry_on_failure_callback: A function called on node failure. Takes as parameters: + the PipelineController instance, the PipelineController.Node that failed and an int + representing the number of retries left for the node that failed + (default 0, can be set with `retry_on_failure` parameter). + The function must return a `bool`: True if the node should be retried and False otherwise. + If True, the node will be requeued and the number of retries left will be decremented by 1. + By default, if this callback is not specified, the function will be retried the number of + times indicated by `retry_on_failure`. + + .. code-block:: py + + def example_retry_on_failure_callback(pipeline, node, retries): + print(node.name, ' failed') + # do something with the pipeline controller + return retries > 0 :return: True if successful """ @@ -466,6 +500,11 @@ class PipelineController(object): monitor_models=monitor_models or [], ) self._retries_left[name] = retry_on_failure or self._retry_on_failure or 0 + self._retries_callbacks[name] = ( + retry_on_failure_callback + or self._retry_on_failure_callback + or PipelineController._default_retry_on_failure_callback + ) if self._task and not self._task.running_locally(): self.update_execution_plot() @@ -501,7 +540,8 @@ class PipelineController(object): pre_execute_callback=None, # type: Optional[Callable[[PipelineController, PipelineController.Node, dict], bool]] # noqa post_execute_callback=None, # type: Optional[Callable[[PipelineController, PipelineController.Node], None]] # noqa cache_executed_step=False, # type: bool - retry_on_failure=None # type: Optional[int] + retry_on_failure=None, # type: Optional[int] + retry_on_failure_callback=None # type: Optional[Callable[[PipelineController, PipelineController.Node, int], bool]] # noqa ): # type: (...) -> bool """ @@ -637,6 +677,21 @@ class PipelineController(object): :param retry_on_failure: In case of node failure, retry the node the number of times indicated by this parameter. + :param retry_on_failure_callback: A function called on node failure. Takes as parameters: + the PipelineController instance, the PipelineController.Node that failed and an int + representing the number of retries left for the node that failed + (default 0, can be set with `retry_on_failure` parameter). + The function must return a `bool`: True if the node should be retried and False otherwise. + If True, the node will be requeued and the number of retries left will be decremented by 1. + By default, if this callback is not specified, the function will be retried the number of + times indicated by `retry_on_failure`. + + .. code-block:: py + + def example_retry_on_failure_callback(pipeline, node, retries): + print(node.name, ' failed') + # do something with the pipeline controller + return retries > 0 :return: True if successful """ # always store callback functions (even when running remotely) @@ -730,6 +785,11 @@ class PipelineController(object): job_code_section=job_code_section, ) self._retries_left[name] = retry_on_failure or self._retry_on_failure or 0 + self._retries_callbacks[name] = ( + retry_on_failure_callback + or self._retry_on_failure_callback + or PipelineController._default_retry_on_failure_callback + ) return True @@ -1988,12 +2048,13 @@ class PipelineController(object): continue if node.job.is_stopped(aborted_nonresponsive_as_running=True): node_failed = node.job.is_failed() - if node_failed and self._retries_left.get(node.name) and self._retries_left[node.name] > 0: + if node_failed and self._retry_on_failure_callback(self, node, self._retries_left.get(node.name, 0)): self._task.get_logger().report_text("Node '{}' failed. Retrying...".format(node.name)) node.job = None node.executed = None self._running_nodes.remove(j) - self._retries_left[node.name] -= 1 + if node.name in self._retries_left: + self._retries_left[node.name] -= 1 continue completed_jobs.append(j) node.executed = node.job.task_id() if not node_failed else False @@ -2521,6 +2582,10 @@ class PipelineController(object): return ' {} '.format(task_link_template.format(project=project_id, task=task_id), task_id) + @staticmethod + def _default_retry_on_failure_callback(_pipeline_controller, _node, retries): + return retries > 0 + class PipelineDecorator(PipelineController): _added_decorator = [] # type: List[dict] @@ -2545,7 +2610,8 @@ class PipelineDecorator(PipelineController): target_project=None, # type: Optional[str] abort_on_failure=False, # type: bool add_run_number=True, # type: bool - retry_on_failure=None # type: Optional[int] + retry_on_failure=None, # type: Optional[int] + retry_on_failure_callback=None # type: Optional[Callable[[PipelineController, PipelineController.Node, int], bool]] # noqa ): # type: (...) -> () """ @@ -2569,6 +2635,20 @@ class PipelineDecorator(PipelineController): Example, the second time we launch the pipeline "best pipeline", we rename it to "best pipeline #2" :param retry_on_failure: In case of node failure, retry the node the number of times indicated by this parameter. + :param retry_on_failure_callback: A function called on node failure. Takes as parameters: + the PipelineController instance, the PipelineController.Node that failed and an int + representing the number of retries left for the node that failed (default 0). The + function must return a `bool`: True if the node should be retried and False otherwise. + If True, the node will be requeued and the number of retries left will be decremented by 1. + By default, if this callback is not specified, the function will be retried the number of + times indicated by `retry_on_failure`. + + .. code-block:: py + + def example_retry_on_failure_callback(pipeline, node, retries): + print(node.name, ' failed') + # do something with the pipeline controller + return retries > 0 """ super(PipelineDecorator, self).__init__( name=name, @@ -2579,7 +2659,8 @@ class PipelineDecorator(PipelineController): target_project=target_project, abort_on_failure=abort_on_failure, add_run_number=add_run_number, - retry_on_failure=retry_on_failure + retry_on_failure=retry_on_failure, + retry_on_failure_callback=retry_on_failure_callback ) # if we are in eager execution, make sure parent class knows it @@ -2888,7 +2969,8 @@ class PipelineDecorator(PipelineController): monitor_metrics=None, # type: Optional[List[Union[Tuple[str, str], Tuple[(str, str), (str, str)]]]] monitor_artifacts=None, # type: Optional[List[Union[str, Tuple[str, str]]]] monitor_models=None, # type: Optional[List[Union[str, Tuple[str, str]]]] - retry_on_failure=None # type: Optional[int] + retry_on_failure=None, # type: Optional[int] + retry_on_failure_callback=None # type: Optional[Callable[[PipelineController, PipelineController.Node, int], bool]] # noqa ): # type: (...) -> Callable """ @@ -2964,6 +3046,20 @@ class PipelineDecorator(PipelineController): Example: [('model_weights', 'final_model_weights'), ] :param retry_on_failure: In case of node failure, retry the node the number of times indicated by this parameter. + :param retry_on_failure_callback: A function called on node failure. Takes as parameters: + the PipelineController instance, the PipelineController.Node that failed and an int + representing the number of retries left for the node that failed (default 0). The + function must return a `bool`: True if the node should be retried and False otherwise. + If True, the node will be requeued and the number of retries left will be decremented by 1. + By default, if this callback is not specified, the function will be retried the number of + times indicated by `retry_on_failure`. + + .. code-block:: py + + def example_retry_on_failure_callback(pipeline, node, retries): + print(node.name, ' failed') + # do something with the pipeline controller + return retries > 0 :return: function wrapper """ @@ -3115,6 +3211,11 @@ class PipelineDecorator(PipelineController): _node = cls._singleton._nodes[_node_name] cls._retries_left[_node_name] = \ retry_on_failure or (cls._singleton._retry_on_failure if cls._singleton else 0) or 0 + cls._retries_callbacks[_node_name] = retry_on_failure_callback or ( + cls._singleton._retry_on_failure_callback + if cls._singleton + else cls._default_retry_on_failure_callback + ) # The actual launch is a bit slow, we run it in the background launch_thread = Thread( @@ -3197,7 +3298,8 @@ class PipelineDecorator(PipelineController): add_run_number=True, # type: bool args_map=None, # type: dict[str, List[str]] start_controller_locally=False, # type: bool - retry_on_failure=None # type: Optional[int] + retry_on_failure=None, # type: Optional[int] + retry_on_failure_callback=None # type: Optional[Callable[[PipelineController, PipelineController.Node, int], bool]] # noqa ): # type: (...) -> Callable """ @@ -3252,6 +3354,20 @@ class PipelineDecorator(PipelineController): Default: False :param retry_on_failure: In case of node failure, retry the node the number of times indicated by this parameter. + :param retry_on_failure_callback: A function called on node failure. Takes as parameters: + the PipelineController instance, the PipelineController.Node that failed and an int + representing the number of retries left for the node that failed (default 0). The + function must return a `bool`: True if the node should be retried and False otherwise. + If True, the node will be requeued and the number of retries left will be decremented by 1. + By default, if this callback is not specified, the function will be retried the number of + times indicated by `retry_on_failure`. + + .. code-block:: py + + def example_retry_on_failure_callback(pipeline, node, retries): + print(node.name, ' failed') + # do something with the pipeline controller + return retries > 0 """ def decorator_wrap(func): @@ -3288,7 +3404,8 @@ class PipelineDecorator(PipelineController): target_project=target_project, abort_on_failure=abort_on_failure, add_run_number=add_run_number, - retry_on_failure=retry_on_failure + retry_on_failure=retry_on_failure, + retry_on_failure_callback=retry_on_failure_callback ) ret_val = func(**pipeline_kwargs) LazyEvalWrapper.trigger_all_remote_references() @@ -3330,7 +3447,8 @@ class PipelineDecorator(PipelineController): target_project=target_project, abort_on_failure=abort_on_failure, add_run_number=add_run_number, - retry_on_failure=retry_on_failure + retry_on_failure=retry_on_failure, + retry_on_failure_callback=retry_on_failure_callback ) a_pipeline._args_map = args_map or {} @@ -3484,13 +3602,16 @@ class PipelineDecorator(PipelineController): else: sleep(2) continue - if node.job.is_failed() and cls._retries_left.get(node_name) and cls._retries_left[node_name] > 0: + if node.job.is_failed() and cls._retries_callbacks.get(node_name, cls._default_retry_on_failure_callback)( + cls._singleton, node, cls._retries_left.get(node_name, 0) + ): if cls._singleton and cls._singleton._task: cls._singleton._task.get_logger().report_text("Node '{}' failed. Retrying...".format(node_name)) node.job = None node.executed = None cls._component_launch(node_name, node, kwargs_artifacts, kwargs, tid) - cls._retries_left[node_name] -= 1 + if node_name in cls._retries_left: + cls._retries_left[node_name] -= 1 else: break