Improve pipeline controller

This commit is contained in:
allegroai 2020-09-09 22:11:18 +03:00
parent 2c47e9f248
commit e206232126

View File

@ -9,7 +9,7 @@ from plotly import graph_objects as go
from plotly.subplots import make_subplots from plotly.subplots import make_subplots
from attr import attrib, attrs from attr import attrib, attrs
from typing import Sequence, Optional, Mapping, Callable, Any from typing import Sequence, Optional, Mapping, Callable, Any, Union
from trains import Task from trains import Task
from trains.automation import TrainsJob from trains.automation import TrainsJob
@ -159,13 +159,14 @@ class PipelineController(object):
return True return True
def start(self, run_remotely=False, step_task_created_callback=None): def start(self, run_remotely=False, step_task_created_callback=None):
# type: (bool, Optional[Callable[[PipelineController.Node, dict], None]]) -> bool # type: (Union[bool, str], Optional[Callable[[PipelineController.Node, dict], None]]) -> bool
""" """
Start the pipeline controller. Start the pipeline controller.
If the calling process is stopped, then the controller stops as well. If the calling process is stopped, then the controller stops as well.
:param bool run_remotely: (default False), If True stop the current process and continue execution :param bool run_remotely: (default False), If True stop the current process and continue execution
on a remote machine. This is done by calling the Task.execute_remotely with the queue name 'services' on a remote machine. This is done by calling the Task.execute_remotely with the queue name 'services'.
If `run_remotely` is a string, it will specify the execution queue for the pipeline remote execution.
:param Callable step_task_created_callback: Callback function, called when a step (Task) is created :param Callable step_task_created_callback: Callback function, called when a step (Task) is created
and before it is sent for execution. and before it is sent for execution.
@ -195,10 +196,17 @@ class PipelineController(object):
if not params['continue_pipeline']: if not params['continue_pipeline']:
for k in pipeline_dag: for k in pipeline_dag:
pipeline_dag[k]['executed'] = None pipeline_dag[k]['executed'] = None
self._default_execution_queue = params['default_queue'] self._default_execution_queue = params['default_queue']
self._add_pipeline_tags = params['add_pipeline_tags'] self._add_pipeline_tags = params['add_pipeline_tags']
self._deserialize(pipeline_dag) self._deserialize(pipeline_dag)
# if we continue the pipeline, make sure that we re-execute failed tasks
if params['continue_pipeline']:
for node in self._nodes.values():
if node.executed is False:
node.executed = None
if not self._verify(): if not self._verify():
raise ValueError("Failed verifying pipeline execution graph, " raise ValueError("Failed verifying pipeline execution graph, "
"it has either inaccessible nodes, or contains cycles") "it has either inaccessible nodes, or contains cycles")
@ -206,7 +214,7 @@ class PipelineController(object):
self._update_execution_plot() self._update_execution_plot()
if run_remotely: if run_remotely:
self._task.execute_remotely(queue_name='services') self._task.execute_remotely(queue_name='services' if not isinstance(run_remotely, str) else run_remotely)
# we will not get here if we are not running remotely # we will not get here if we are not running remotely
self._start_time = time() self._start_time = time()
@ -461,7 +469,7 @@ class PipelineController(object):
'<br />'.join('{}: {}'.format(k, v) for k, v in (node.parameters or {}).items())) '<br />'.join('{}: {}'.format(k, v) for k, v in (node.parameters or {}).items()))
sankey_node['color'].append( sankey_node['color'].append(
("blue" if not node.job or not node.job.is_failed() else "red") ("blue" if not node.job or not node.job.is_failed() else "red")
if node.executed else ("green" if node.job else "lightsteelblue")) if node.executed is not None else ("green" if node.job else "lightsteelblue"))
for p in parents: for p in parents:
sankey_link['source'].append(p) sankey_link['source'].append(p)
@ -592,10 +600,16 @@ class PipelineController(object):
break break
# stop all currently running jobs: # stop all currently running jobs:
failing_pipeline = False
for node in self._nodes.values(): for node in self._nodes.values():
if node.job and not node.executed: if node.executed is False:
failing_pipeline = True
if node.job and node.executed and not node.job.is_stopped():
node.job.abort() node.job.abort()
if failing_pipeline and self._task:
self._task.mark_failed(status_reason='Pipeline step failed')
if self._stop_event: if self._stop_event:
# noinspection PyBroadException # noinspection PyBroadException
try: try: