Improve pipeline controller

This commit is contained in:
allegroai 2020-09-09 22:11:18 +03:00
parent 2c47e9f248
commit e206232126

View File

@ -9,7 +9,7 @@ from plotly import graph_objects as go
from plotly.subplots import make_subplots
from attr import attrib, attrs
from typing import Sequence, Optional, Mapping, Callable, Any
from typing import Sequence, Optional, Mapping, Callable, Any, Union
from trains import Task
from trains.automation import TrainsJob
@ -159,13 +159,14 @@ class PipelineController(object):
return True
def start(self, run_remotely=False, step_task_created_callback=None):
# type: (bool, Optional[Callable[[PipelineController.Node, dict], None]]) -> bool
# type: (Union[bool, str], Optional[Callable[[PipelineController.Node, dict], None]]) -> bool
"""
Start the pipeline controller.
If the calling process is stopped, then the controller stops as well.
:param bool run_remotely: (default False), If True stop the current process and continue execution
on a remote machine. This is done by calling the Task.execute_remotely with the queue name 'services'
on a remote machine. This is done by calling the Task.execute_remotely with the queue name 'services'.
If `run_remotely` is a string, it will specify the execution queue for the pipeline remote execution.
:param Callable step_task_created_callback: Callback function, called when a step (Task) is created
and before it is sent for execution.
@ -195,10 +196,17 @@ class PipelineController(object):
if not params['continue_pipeline']:
for k in pipeline_dag:
pipeline_dag[k]['executed'] = None
self._default_execution_queue = params['default_queue']
self._add_pipeline_tags = params['add_pipeline_tags']
self._deserialize(pipeline_dag)
# if we continue the pipeline, make sure that we re-execute failed tasks
if params['continue_pipeline']:
for node in self._nodes.values():
if node.executed is False:
node.executed = None
if not self._verify():
raise ValueError("Failed verifying pipeline execution graph, "
"it has either inaccessible nodes, or contains cycles")
@ -206,7 +214,7 @@ class PipelineController(object):
self._update_execution_plot()
if run_remotely:
self._task.execute_remotely(queue_name='services')
self._task.execute_remotely(queue_name='services' if not isinstance(run_remotely, str) else run_remotely)
# we will not get here if we are not running remotely
self._start_time = time()
@ -461,7 +469,7 @@ class PipelineController(object):
'<br />'.join('{}: {}'.format(k, v) for k, v in (node.parameters or {}).items()))
sankey_node['color'].append(
("blue" if not node.job or not node.job.is_failed() else "red")
if node.executed else ("green" if node.job else "lightsteelblue"))
if node.executed is not None else ("green" if node.job else "lightsteelblue"))
for p in parents:
sankey_link['source'].append(p)
@ -592,10 +600,16 @@ class PipelineController(object):
break
# stop all currently running jobs:
failing_pipeline = False
for node in self._nodes.values():
if node.job and not node.executed:
if node.executed is False:
failing_pipeline = True
if node.job and node.executed and not node.job.is_stopped():
node.job.abort()
if failing_pipeline and self._task:
self._task.mark_failed(status_reason='Pipeline step failed')
if self._stop_event:
# noinspection PyBroadException
try: