Add PipelineController node skip visualization and post execution callback

This commit is contained in:
allegroai 2021-03-25 01:00:56 +02:00
parent 9ff52a8699
commit c76fe55c03
2 changed files with 130 additions and 34 deletions

View File

@ -2,7 +2,7 @@ import re
from copy import copy
from datetime import datetime
from logging import getLogger
from threading import Thread, Event
from threading import Thread, Event, RLock
from time import time
from attr import attrib, attrs
@ -24,6 +24,7 @@ class PipelineController(object):
_tag = 'pipeline'
_step_pattern = r"\${[^}]*}"
_config_section = 'Pipeline'
_task_project_lookup = {}
@attrs
class Node(object):
@ -36,6 +37,7 @@ class PipelineController(object):
executed = attrib(type=str, default=None)
clone_task = attrib(type=bool, default=True)
job = attrib(type=TrainsJob, default=None)
skip_job = attrib(type=bool, default=False)
def __init__(
self,
@ -78,9 +80,11 @@ class PipelineController(object):
self._thread = None
self._stop_event = None
self._experiment_created_cb = None
self._experiment_completed_cb = None
self._add_pipeline_tags = add_pipeline_tags
self._task = auto_connect_task if isinstance(auto_connect_task, Task) else Task.current_task()
self._step_ref_pattern = re.compile(self._step_pattern)
self._reporting_lock = RLock()
if not self._task and always_create_task:
self._task = Task.init(
project_name='Pipelines',
@ -168,8 +172,13 @@ class PipelineController(object):
return True
def start(self, run_remotely=False, step_task_created_callback=None):
# type: (Union[bool, str], Optional[Callable[[PipelineController.Node, dict], None]]) -> bool
def start(
self,
run_remotely=False, # type: Union[bool, str]
step_task_created_callback=None, # type: Optional[Callable[[PipelineController, PipelineController.Node, dict], bool]] # noqa
step_task_completed_callback=None # type: Optional[Callable[[PipelineController, PipelineController.Node], None]] # noqa
):
# type: (...) -> bool
"""
Start the pipeline controller.
If the calling process is stopped, then the controller stops as well.
@ -181,17 +190,34 @@ class PipelineController(object):
and before it is sent for execution. Allows a user to modify the Task before launch.
Use `node.job` to access the TrainsJob object, or `node.job.task` to directly access the Task object.
`parameters` are the configuration arguments passed to the TrainsJob.
If the callback returned value is `False`,
the Node is skipped and so is any node in the DAG that relies on this node.
Notice the `parameters` are already parsed,
e.g. `${step1.parameters.Args/param}` is replaced with relevant value.
.. code-block:: py
def step_created_callback(
pipeline, # type: PipelineController,
node, # type: PipelineController.Node,
parameters, # type: dict
):
pass
:param Callable step_task_completed_callback: Callback function, called when a step (Task) is completed
and before the next steps (Tasks) are launched. Allows a user to modify the Task status after completion.
.. code-block:: py
def step_completed_callback(
pipeline, # type: PipelineController,
node, # type: PipelineController.Node,
):
pass
:return: True, if the controller started. False, if the controller did not start.
"""
@ -225,7 +251,8 @@ class PipelineController(object):
raise ValueError("Failed verifying pipeline execution graph, "
"it has either inaccessible nodes, or contains cycles")
self._update_execution_plot()
self.update_execution_plot()
print('update_execution_plot!!!!')
if run_remotely:
self._task.execute_remotely(queue_name='services' if not isinstance(run_remotely, str) else run_remotely)
@ -234,6 +261,7 @@ class PipelineController(object):
self._start_time = time()
self._stop_event = Event()
self._experiment_created_cb = step_task_created_callback
self._experiment_completed_cb = step_task_completed_callback
self._thread = Thread(target=self._daemon)
self._thread.daemon = True
self._thread.start()
@ -451,11 +479,30 @@ class PipelineController(object):
parent=self._task.id if self._task else None,
disable_clone_task=not node.clone_task,
)
skip_node = None
if self._experiment_created_cb:
self._experiment_created_cb(node, updated_hyper_parameters)
node.job.launch(queue_name=node.queue or self._default_execution_queue)
skip_node = self._experiment_created_cb(self, node, updated_hyper_parameters)
if skip_node is False:
# skipping node
getLogger('clearml.automation.controller').warning(
'Skipping node {} on callback request'.format(node))
# delete the job we just created
node.job.delete()
node.skip_job = True
else:
node.job.launch(queue_name=node.queue or self._default_execution_queue)
return True
def update_execution_plot(self):
    # type: () -> ()
    """
    Redraw the pipeline's sankey diagram, serializing concurrent callers.

    Thread-safe public wrapper around the internal plot update: the
    reporting lock guarantees only one thread rebuilds the plot at a time.
    """
    self._reporting_lock.acquire()
    try:
        self._update_execution_plot()
    finally:
        # always release, even if plotting raised
        self._reporting_lock.release()
def _update_execution_plot(self):
# type: () -> ()
"""
@ -498,7 +545,8 @@ class PipelineController(object):
sankey_node['color'].append(
("red" if node.job and node.job.is_failed() else
("blue" if not node.job or node.job.is_completed() else "royalblue"))
if node.executed is not None else ("green" if node.job else "lightsteelblue"))
if node.executed is not None else
("green" if node.job else ("gray" if node.skip_job else "lightsteelblue")))
for p in parents:
sankey_link['source'].append(p)
@ -525,29 +573,11 @@ class PipelineController(object):
.replace('/{}/'.format(self._task.project), '/{project}/')\
.replace('/{}/'.format(self._task.id), '/{task}/')
# create the detailed parameter table
def create_task_link(a_node):
    # type: (PipelineController.Node) -> str
    """
    Build an HTML anchor linking to the Task behind the given pipeline node.

    Returns an empty string when the node has neither a live job nor a
    record of a previously executed Task.
    """
    task_id = project_id = None
    if a_node.job:
        # live job: project and task id are available directly on the Task
        project_id = a_node.job.task.project
        task_id = a_node.job.task.id
    elif a_node.executed:
        # node ran in a previous pipeline execution; only the Task id was
        # stored, so the project must be looked up from the backend
        task_id = a_node.executed
        # noinspection PyBroadException
        try:
            project_id = Task.get_task(task_id=task_id).project
        except Exception:
            # lookup failed - fall back to the wildcard project id
            project_id = '*'
    if not task_id:
        return ''
    # NOTE(review): `task_link_template` is a free variable from the enclosing
    # scope (web UI URL template with {project}/{task} placeholders)
    return '<a href="{}"> {} </a>'.format(
        task_link_template.format(project=project_id, task=task_id), task_id)
table_values = [["Pipeline Step", "Task ID", "Parameters"]]
table_values = [["Pipeline Step", "Task ID", "Status", "Parameters"]]
table_values += [
[v, create_task_link(self._nodes[v]), str(n)] for v, n in zip(visited, node_params)]
[v, self.__create_task_link(self._nodes[v], task_link_template),
self.__get_node_status(self._nodes[v]), str(n)]
for v, n in zip(visited, node_params)]
# hack, show single node sankey
if single_nodes:
@ -635,11 +665,16 @@ class PipelineController(object):
if not completed_jobs and self._running_nodes:
continue
# callback on completed jobs
if self._experiment_completed_cb:
for job in completed_jobs:
self._experiment_completed_cb(self, job)
# Pull the next jobs in the pipeline, based on the completed list
next_nodes = []
for node in self._nodes.values():
# check if already processed.
if node.job or node.executed:
# check if already processed or needs to be skipped
if node.job or node.executed or node.skip_job:
continue
completed_parents = [bool(p in self._nodes and self._nodes[p].executed) for p in node.parents or []]
if all(completed_parents):
@ -652,14 +687,14 @@ class PipelineController(object):
print('Parameters:\n{}'.format(self._nodes[name].job.task_parameter_override))
self._running_nodes.append(name)
else:
getLogger('clearml.automation.controller').error(
'ERROR: Failed launching step \'{}\': {}'.format(name, self._nodes[name]))
getLogger('clearml.automation.controller').warning(
'Skipping launching step \'{}\': {}'.format(name, self._nodes[name]))
# update current state (in configuration, so that we could later continue an aborted pipeline)
self._force_task_configuration_update()
# visualize pipeline state (plot)
self._update_execution_plot()
self.update_execution_plot()
# quit if all pipelines nodes are fully executed.
if not next_nodes and not self._running_nodes:
@ -672,10 +707,16 @@ class PipelineController(object):
failing_pipeline = True
if node.job and node.executed and not node.job.is_stopped():
node.job.abort()
elif not node.job and not node.executed:
# mark Node as skipped if it has no Job object and it is not executed
node.skip_job = True
if failing_pipeline and self._task:
self._task.mark_failed(status_reason='Pipeline step failed')
# visualize pipeline state (plot)
self.update_execution_plot()
if self._stop_event:
# noinspection PyBroadException
try:
@ -822,3 +863,43 @@ class PipelineController(object):
new_val = self.__parse_step_reference(g)
updated_value = updated_value.replace(g, new_val, 1)
return updated_value
@classmethod
def __get_node_status(cls, a_node):
    # type: (PipelineController.Node) -> str
    """
    Return a human-readable status string for a pipeline node.

    :param a_node: The PipelineController.Node to inspect (may be None).
    :return: "pending", "skipped", "executed", or the backend status of
        the node's Task when a job with a Task exists.
    """
    if not a_node:
        return "pending"
    if a_node.skip_job:
        return "skipped"
    if a_node.job and a_node.job.task:
        # a job with a concrete Task exists - report the backend status
        return str(a_node.job.task.data.status)
    # bug fix: `executed` (the previously-executed Task id) lives on the
    # Node, not on the TrainsJob - the original `a_node.job.executed`
    # check referenced a non-existent job attribute and was unreachable
    if a_node.executed:
        return "executed"
    return "pending"
@classmethod
def __create_task_link(cls, a_node, task_link_template):
    # type: (PipelineController.Node, str) -> str
    """
    Build an HTML anchor linking to the Task behind the given pipeline node.

    Project ids resolved from the backend are memoized in the class-level
    `_task_project_lookup` cache so repeated plot refreshes stay cheap.
    """
    if not a_node:
        return ''
    task_id = project_id = None
    if a_node.job:
        # live job - both ids come straight off the Task object
        project_id = a_node.job.task.project
        task_id = a_node.job.task.id
    elif a_node.executed:
        # previously executed node - only the Task id was stored
        task_id = a_node.executed
        project_id = cls._task_project_lookup.get(task_id)
        if not project_id:
            # noinspection PyBroadException
            try:
                project_id = Task.get_task(task_id=task_id).project
            except Exception:
                # backend lookup failed - wildcard project keeps the link valid
                project_id = '*'
            cls._task_project_lookup[task_id] = project_id
    if not task_id:
        return ''
    return '<a href="{}"> {} </a>'.format(task_link_template.format(project=project_id, task=task_id), task_id)

View File

@ -282,6 +282,21 @@ class TrainsJob(object):
self.task_started = True
return True
def delete(self):
    # type: () -> bool
    """
    Delete the current temporary job (before launching).

    :return: True when the underlying Task was deleted (the reference is
        cleared), False when there is no Task or the deletion failed.
    """
    task = self.task
    if not task or not task.delete():
        # nothing to delete, or the backend refused the deletion
        return False
    # drop the reference so the job cannot be launched afterwards
    self.task = None
    return True
# noinspection PyMethodMayBeStatic, PyUnusedLocal
class _JobStub(object):