diff --git a/trains/automation/controller.py b/trains/automation/controller.py index 12a350c4..40048f94 100644 --- a/trains/automation/controller.py +++ b/trains/automation/controller.py @@ -9,7 +9,7 @@ from plotly import graph_objects as go from plotly.subplots import make_subplots from attr import attrib, attrs -from typing import Sequence, Optional, Mapping, Callable, Any +from typing import Sequence, Optional, Mapping, Callable, Any, Union from trains import Task from trains.automation import TrainsJob @@ -159,13 +159,14 @@ class PipelineController(object): return True def start(self, run_remotely=False, step_task_created_callback=None): - # type: (bool, Optional[Callable[[PipelineController.Node, dict], None]]) -> bool + # type: (Union[bool, str], Optional[Callable[[PipelineController.Node, dict], None]]) -> bool """ Start the pipeline controller. If the calling process is stopped, then the controller stops as well. :param bool run_remotely: (default False), If True stop the current process and continue execution - on a remote machine. This is done by calling the Task.execute_remotely with the queue name 'services' + on a remote machine. This is done by calling the Task.execute_remotely with the queue name 'services'. + If `run_remotely` is a string, it will specify the execution queue for the pipeline remote execution. :param Callable step_task_created_callback: Callback function, called when a step (Task) is created and before it is sent for execution. @@ -195,10 +196,17 @@ class PipelineController(object): if not params['continue_pipeline']: for k in pipeline_dag: pipeline_dag[k]['executed'] = None + self._default_execution_queue = params['default_queue'] self._add_pipeline_tags = params['add_pipeline_tags'] self._deserialize(pipeline_dag) + # if we continue the pipeline, make sure that we re-execute failed tasks + if params['continue_pipeline']: + for node in self._nodes.values(): + if node.executed is False: + node.executed = None + if not self._verify(): raise ValueError("Failed verifying pipeline execution graph, " "it has either inaccessible nodes, or contains cycles") @@ -206,7 +214,7 @@ class PipelineController(object): self._update_execution_plot() if run_remotely: - self._task.execute_remotely(queue_name='services') + self._task.execute_remotely(queue_name='services' if not isinstance(run_remotely, str) else run_remotely) # we will not get here if we are not running remotely self._start_time = time() @@ -461,7 +469,7 @@ class PipelineController(object): '
'.join('{}: {}'.format(k, v) for k, v in (node.parameters or {}).items())) sankey_node['color'].append( ("blue" if not node.job or not node.job.is_failed() else "red") - if node.executed else ("green" if node.job else "lightsteelblue")) + if node.executed is not None else ("green" if node.job else "lightsteelblue")) for p in parents: sankey_link['source'].append(p) @@ -592,10 +600,16 @@ class PipelineController(object): break # stop all currently running jobs: + failing_pipeline = False for node in self._nodes.values(): - if node.job and not node.executed: + if node.executed is False: + failing_pipeline = True + if node.job and node.executed and not node.job.is_stopped(): node.job.abort() + if failing_pipeline and self._task: + self._task.mark_failed(status_reason='Pipeline step failed') + if self._stop_event: # noinspection PyBroadException try: