Mirror of https://github.com/clearml/clearml (synced 2025-04-22 07:15:57 +00:00)
Fix retry_on_failure callback does nothing when specified on PipelineController.add_step
This commit is contained in: parent a0e19d833d, commit d497869a1f
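
For context, a minimal usage sketch of the behavior this commit fixes: passing a retry callback through the retry_on_failure argument of PipelineController.add_step. The callback signature follows the Callable[[PipelineController, PipelineController.Node, int], bool] type noted in the diff below; the project and task names here are placeholders.

from clearml import PipelineController


def retry_step(pipeline, node, retries):
    # Return True to relaunch the failed step, False to let it fail for good.
    # 'retries' is the number of retries already attempted for this node.
    return retries < 3


pipe = PipelineController(name="retry example", project="examples", version="1.0.0")
pipe.add_step(
    name="train",
    base_task_project="examples",     # placeholder base task
    base_task_name="training task",
    retry_on_failure=retry_step,      # the callback this commit makes effective for add_step
)
pipe.start()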
@@ -49,6 +49,7 @@ class PipelineController(object):
     _pipeline_section = 'pipeline'
     _pipeline_step_ref = 'pipeline'
     _runtime_property_hash = '_pipeline_hash'
+    _relaunch_status_message = "Relaunching pipeline step..."
     _reserved_pipeline_names = (_pipeline_step_ref, )
     _task_project_lookup = {}
     _clearml_job_class = ClearmlJob
@@ -61,6 +62,7 @@ class PipelineController(object):
     _add_to_evaluated_return_values = {}  # TID: bool
     _retries = {}  # Node.name: int
     _retries_callbacks = {}  # Node.name: Callable[[PipelineController, PipelineController.Node, int], bool]  # noqa
+    _final_failure = {}  # Node.name: bool

     valid_job_status = ["failed", "cached", "completed", "aborted", "queued", "running", "skipped", "pending"]

@@ -1657,6 +1659,23 @@ class PipelineController(object):
         # return False if we did not cover all the nodes
         return not bool(set(self._nodes.keys()) - visited)

+    def _relaunch_node(self, node):
+        if not node.job:
+            getLogger("clearml.automation.controller").warning(
+                "Could not relaunch node {} (job object is missing)".format(node.name)
+            )
+            return
+        self._retries[node.name] = self._retries.get(node.name, 0) + 1
+        getLogger("clearml.automation.controller").warning(
+            "Node '{}' failed. Retrying... (this is retry number {})".format(node.name, self._retries[node.name])
+        )
+        node.job.task.mark_stopped(force=True, status_message=self._relaunch_status_message)
+        node.job.task.set_progress(0)
+        node.job.task.get_logger().report_text(
+            "\nNode '{}' failed. Retrying... (this is retry number {})\n".format(node.name, self._retries[node.name])
+        )
+        node.job.launch(queue_name=node.queue or self._default_execution_queue)
+
     def _launch_node(self, node):
         # type: (PipelineController.Node) -> ()
         """
@@ -1909,6 +1928,37 @@ class PipelineController(object):

         return table_values

+    def _call_retries_callback(self, node):
+        # if this functions returns True, we should relaunch the node
+        # if False, don't relaunch
+        if node.name not in self._retries_callbacks:
+            return False
+        try:
+            return self._retries_callbacks[node.name](self, node, self._retries.get(node.name, 0))
+        except Exception as e:
+            getLogger("clearml.automation.controller").warning(
+                "Failed calling the retry callback for node '{}'. Error is '{}'".format(node.name, e)
+            )
+            return False
+
+    @classmethod
+    def _wait_for_node(cls, node):
+        pool_period = 5.0 if cls._debug_execute_step_process else 20.0
+        while True:
+            node.job.wait(pool_period=pool_period, aborted_nonresponsive_as_running=True)
+            job_status = str(node.job.status(force=True))
+            if (
+                (
+                    job_status == str(Task.TaskStatusEnum.stopped)
+                    and node.job.status_message() == cls._relaunch_status_message
+                )
+                or (job_status == str(Task.TaskStatusEnum.failed) and not cls._final_failure.get(node.name))
+                or not node.job.is_stopped()
+            ):
+                sleep(pool_period)
+            else:
+                break
+
     @classmethod
     def _get_node_color(cls, node):
         # type (self.Mode) -> str
@@ -2051,15 +2101,12 @@ class PipelineController(object):
                     continue
                 if node.job.is_stopped(aborted_nonresponsive_as_running=True):
                     node_failed = node.job.is_failed()
-                    if node_failed and \
-                            self._retry_on_failure_callback(self, node, self._retries.get(node.name, 0)):
-                        self._task.get_logger().report_text("Node '{}' failed. Retrying...".format(node.name))
-                        node.job = None
-                        node.executed = None
-                        self._running_nodes.remove(j)
-                        self._retries[node.name] = self._retries.get(node.name, 0) + 1
-                        continue
-
+                    if node_failed:
+                        if self._call_retries_callback(node):
+                            self._relaunch_node(node)
+                            continue
+                        else:
+                            self._final_failure[node.name] = True

                     completed_jobs.append(j)
                     node.executed = node.job.task_id() if not node_failed else False
@@ -2718,6 +2765,12 @@ class PipelineDecorator(PipelineController):
                     continue
                 if node.job.is_stopped(aborted_nonresponsive_as_running=True):
                     node_failed = node.job.is_failed()
+                    if node_failed:
+                        if self._call_retries_callback(node):
+                            self._relaunch_node(node)
+                            continue
+                        else:
+                            self._final_failure[node.name] = True
                     completed_jobs.append(j)
                     node.executed = node.job.task_id() if not node_failed else False
                     if j in launched_nodes:
@@ -3218,7 +3271,7 @@ class PipelineDecorator(PipelineController):

             # The actual launch is a bit slow, we run it in the background
             launch_thread = Thread(
-                target=cls._component_launch_with_failover,
+                target=cls._component_launch,
                 args=(_node_name, _node, kwargs_artifacts, kwargs, current_thread().ident))

             def results_reference(return_name):
@@ -3228,6 +3281,8 @@ class PipelineDecorator(PipelineController):
                         launch_thread.join()
                     except:  # noqa
                         pass
+
+                cls._wait_for_node(_node)
                 if not _node.job:
                     if not _node.executed:
                         raise ValueError("Job was not created and is also not cached/executed")
@@ -3247,6 +3302,8 @@ class PipelineDecorator(PipelineController):
                         launch_thread.join()
                     except:  # noqa
                         pass
+
+                cls._wait_for_node(_node)
                 if (_node.job.is_failed() and not _node.continue_on_fail) or _node.job.is_aborted():
                     raise ValueError(
                         'Pipeline step "{}", Task ID={} failed'.format(_node.name, _node.job.task_id())
@@ -3501,8 +3558,7 @@ class PipelineDecorator(PipelineController):
                 for node in list(a_pipeline._nodes.values()):
                     if node.executed or not node.job or node.job.is_stopped(aborted_nonresponsive_as_running=True):
                         continue
-                    node.job.wait(pool_period=5. if cls._debug_execute_step_process else 20.,
-                                  aborted_nonresponsive_as_running=True)
+                    cls._wait_for_node(node)
                     waited = True
                 # store the pipeline result of we have any:
                 if return_value and pipeline_result is not None:
@@ -3589,28 +3645,6 @@ class PipelineDecorator(PipelineController):
         """
         return cls._wait_for_multi_pipelines()

-    @classmethod
-    def _component_launch_with_failover(cls, node_name, node, kwargs_artifacts, kwargs, tid):
-        cls._component_launch(node_name, node, kwargs_artifacts, kwargs, tid)
-        while True:
-            if node.job:
-                node.job.wait(pool_period=5. if cls._debug_execute_step_process else 20.,
-                              aborted_nonresponsive_as_running=True)
-            else:
-                sleep(2)
-                continue
-            if node.job.is_failed() and node_name in cls._retries_callbacks and \
-                    cls._retries_callbacks[node_name](cls._singleton, node, cls._retries.get(node_name, 0)):
-
-                if cls._singleton and cls._singleton._task:
-                    cls._singleton._task.get_logger().report_text("Node '{}' failed. Retrying...".format(node_name))
-                node.job = None
-                node.executed = None
-                cls._retries[node_name] = cls._retries.get(node_name, 0) + 1
-                cls._component_launch(node_name, node, kwargs_artifacts, kwargs, tid)
-            else:
-                break
-
     @classmethod
     def _component_launch(cls, node_name, node, kwargs_artifacts, kwargs, tid):
         _node_name = node_name
@@ -164,6 +164,16 @@ class BaseJob(object):
         self._last_status_ts = time()
         return self._last_status

+    def status_message(self):
+        # type: () -> str
+        """
+        Gets the status message of the task. Note that the message is updated only after `BaseJob.status()`
+        is called
+
+        :return: The status message of the corresponding task as a string
+        """
+        return str(self.task.data.status_message)
+
     def wait(self, timeout=None, pool_period=30., aborted_nonresponsive_as_running=False):
         # type: (Optional[float], float, bool) -> bool
         """
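
A brief usage sketch of the new BaseJob.status_message() helper added above. Per its docstring, the message is refreshed only after BaseJob.status() is called; `job` here stands for any pipeline step's ClearmlJob/BaseJob instance.

job_status = job.status(force=True)   # refresh the cached status (and status message) from the server
message = job.status_message()        # e.g. "Relaunching pipeline step..." after a relaunch
print(job_status, message)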