diff --git a/clearml/automation/controller.py b/clearml/automation/controller.py
index 108d47ef..60bdbaa7 100644
--- a/clearml/automation/controller.py
+++ b/clearml/automation/controller.py
@@ -2,7 +2,7 @@ import re
from copy import copy
from datetime import datetime
from logging import getLogger
-from threading import Thread, Event
+from threading import Thread, Event, RLock
from time import time
from attr import attrib, attrs
@@ -24,6 +24,7 @@ class PipelineController(object):
_tag = 'pipeline'
_step_pattern = r"\${[^}]*}"
_config_section = 'Pipeline'
+ _task_project_lookup = {}
@attrs
class Node(object):
@@ -36,6 +37,7 @@ class PipelineController(object):
executed = attrib(type=str, default=None)
clone_task = attrib(type=bool, default=True)
job = attrib(type=TrainsJob, default=None)
+ skip_job = attrib(type=bool, default=False)
def __init__(
self,
@@ -78,9 +80,11 @@ class PipelineController(object):
self._thread = None
self._stop_event = None
self._experiment_created_cb = None
+ self._experiment_completed_cb = None
self._add_pipeline_tags = add_pipeline_tags
self._task = auto_connect_task if isinstance(auto_connect_task, Task) else Task.current_task()
self._step_ref_pattern = re.compile(self._step_pattern)
+ self._reporting_lock = RLock()
if not self._task and always_create_task:
self._task = Task.init(
project_name='Pipelines',
@@ -168,8 +172,13 @@ class PipelineController(object):
return True
- def start(self, run_remotely=False, step_task_created_callback=None):
- # type: (Union[bool, str], Optional[Callable[[PipelineController.Node, dict], None]]) -> bool
+ def start(
+ self,
+ run_remotely=False, # type: Union[bool, str]
+ step_task_created_callback=None, # type: Optional[Callable[[PipelineController, PipelineController.Node, dict], bool]] # noqa
+ step_task_completed_callback=None # type: Optional[Callable[[PipelineController, PipelineController.Node], None]] # noqa
+ ):
+ # type: (...) -> bool
"""
Start the pipeline controller.
If the calling process is stopped, then the controller stops as well.
@@ -181,17 +190,34 @@ class PipelineController(object):
and before it is sent for execution. Allows a user to modify the Task before launch.
Use `node.job` to access the TrainsJob object, or `node.job.task` to directly access the Task object.
`parameters` are the configuration arguments passed to the TrainsJob.
+
+ If the callback returned value is `False`,
+ the Node is skipped and so is any node in the DAG that relies on this node.
+
Notice the `parameters` are already parsed,
e.g. `${step1.parameters.Args/param}` is replaced with relevant value.
.. code-block:: py
def step_created_callback(
+ pipeline, # type: PipelineController,
node, # type: PipelineController.Node,
parameters, # type: dict
):
pass
+ :param Callable step_task_completed_callback: Callback function, called when a step (Task) is completed
+            and before other jobs are executed. Allows a user to modify the Task status after completion.
+
+ .. code-block:: py
+
+ def step_completed_callback(
+ pipeline, # type: PipelineController,
+ node, # type: PipelineController.Node,
+ ):
+ pass
+
+
:return: True, if the controller started. False, if the controller did not start.
"""
@@ -225,7 +251,8 @@ class PipelineController(object):
raise ValueError("Failed verifying pipeline execution graph, "
"it has either inaccessible nodes, or contains cycles")
- self._update_execution_plot()
+        self.update_execution_plot()
if run_remotely:
self._task.execute_remotely(queue_name='services' if not isinstance(run_remotely, str) else run_remotely)
@@ -234,6 +261,7 @@ class PipelineController(object):
self._start_time = time()
self._stop_event = Event()
self._experiment_created_cb = step_task_created_callback
+ self._experiment_completed_cb = step_task_completed_callback
self._thread = Thread(target=self._daemon)
self._thread.daemon = True
self._thread.start()
@@ -451,11 +479,30 @@ class PipelineController(object):
parent=self._task.id if self._task else None,
disable_clone_task=not node.clone_task,
)
+ skip_node = None
if self._experiment_created_cb:
- self._experiment_created_cb(node, updated_hyper_parameters)
- node.job.launch(queue_name=node.queue or self._default_execution_queue)
+ skip_node = self._experiment_created_cb(self, node, updated_hyper_parameters)
+
+ if skip_node is False:
+ # skipping node
+ getLogger('clearml.automation.controller').warning(
+ 'Skipping node {} on callback request'.format(node))
+ # delete the job we just created
+ node.job.delete()
+ node.skip_job = True
+ else:
+ node.job.launch(queue_name=node.queue or self._default_execution_queue)
+
return True
+ def update_execution_plot(self):
+ # type: () -> ()
+ """
+ Update sankey diagram of the current pipeline
+ """
+ with self._reporting_lock:
+ self._update_execution_plot()
+
def _update_execution_plot(self):
# type: () -> ()
"""
@@ -498,7 +545,8 @@ class PipelineController(object):
sankey_node['color'].append(
("red" if node.job and node.job.is_failed() else
("blue" if not node.job or node.job.is_completed() else "royalblue"))
- if node.executed is not None else ("green" if node.job else "lightsteelblue"))
+ if node.executed is not None else
+ ("green" if node.job else ("gray" if node.skip_job else "lightsteelblue")))
for p in parents:
sankey_link['source'].append(p)
@@ -525,29 +573,11 @@ class PipelineController(object):
.replace('/{}/'.format(self._task.project), '/{project}/')\
.replace('/{}/'.format(self._task.id), '/{task}/')
- # create the detailed parameter table
- def create_task_link(a_node):
- task_id = project_id = None
- if a_node.job:
- project_id = a_node.job.task.project
- task_id = a_node.job.task.id
- elif a_node.executed:
- task_id = a_node.executed
- # noinspection PyBroadException
- try:
- project_id = Task.get_task(task_id=task_id).project
- except Exception:
- project_id = '*'
-
- if not task_id:
- return ''
-
- return ' {} '.format(
- task_link_template.format(project=project_id, task=task_id), task_id)
-
- table_values = [["Pipeline Step", "Task ID", "Parameters"]]
+ table_values = [["Pipeline Step", "Task ID", "Status", "Parameters"]]
table_values += [
- [v, create_task_link(self._nodes[v]), str(n)] for v, n in zip(visited, node_params)]
+ [v, self.__create_task_link(self._nodes[v], task_link_template),
+ self.__get_node_status(self._nodes[v]), str(n)]
+ for v, n in zip(visited, node_params)]
# hack, show single node sankey
if single_nodes:
@@ -635,11 +665,16 @@ class PipelineController(object):
if not completed_jobs and self._running_nodes:
continue
+ # callback on completed jobs
+ if self._experiment_completed_cb:
+ for job in completed_jobs:
+                    self._experiment_completed_cb(self, self._nodes[job])
+
# Pull the next jobs in the pipeline, based on the completed list
next_nodes = []
for node in self._nodes.values():
- # check if already processed.
- if node.job or node.executed:
+ # check if already processed or needs to be skipped
+ if node.job or node.executed or node.skip_job:
continue
completed_parents = [bool(p in self._nodes and self._nodes[p].executed) for p in node.parents or []]
if all(completed_parents):
@@ -652,14 +687,14 @@ class PipelineController(object):
print('Parameters:\n{}'.format(self._nodes[name].job.task_parameter_override))
self._running_nodes.append(name)
else:
- getLogger('clearml.automation.controller').error(
- 'ERROR: Failed launching step \'{}\': {}'.format(name, self._nodes[name]))
+ getLogger('clearml.automation.controller').warning(
+ 'Skipping launching step \'{}\': {}'.format(name, self._nodes[name]))
# update current state (in configuration, so that we could later continue an aborted pipeline)
self._force_task_configuration_update()
# visualize pipeline state (plot)
- self._update_execution_plot()
+ self.update_execution_plot()
# quit if all pipelines nodes are fully executed.
if not next_nodes and not self._running_nodes:
@@ -672,10 +707,16 @@ class PipelineController(object):
failing_pipeline = True
if node.job and node.executed and not node.job.is_stopped():
node.job.abort()
+ elif not node.job and not node.executed:
+ # mark Node as skipped if it has no Job object and it is not executed
+ node.skip_job = True
if failing_pipeline and self._task:
self._task.mark_failed(status_reason='Pipeline step failed')
+ # visualize pipeline state (plot)
+ self.update_execution_plot()
+
if self._stop_event:
# noinspection PyBroadException
try:
@@ -822,3 +863,43 @@ class PipelineController(object):
new_val = self.__parse_step_reference(g)
updated_value = updated_value.replace(g, new_val, 1)
return updated_value
+
+ @classmethod
+ def __get_node_status(cls, a_node):
+ # type: (PipelineController.Node) -> str
+ if not a_node:
+ return "pending"
+ if a_node.skip_job:
+ return "skipped"
+ if a_node.job and a_node.job.task:
+ return str(a_node.job.task.data.status)
+ if a_node.job and a_node.job.executed:
+ return "executed"
+ return "pending"
+
+ @classmethod
+ def __create_task_link(cls, a_node, task_link_template):
+ # type: (PipelineController.Node, str) -> str
+ if not a_node:
+ return ''
+ # create the detailed parameter table
+ task_id = project_id = None
+ if a_node.job:
+ project_id = a_node.job.task.project
+ task_id = a_node.job.task.id
+ elif a_node.executed:
+ task_id = a_node.executed
+ if cls._task_project_lookup.get(task_id):
+ project_id = cls._task_project_lookup[task_id]
+ else:
+ # noinspection PyBroadException
+ try:
+ project_id = Task.get_task(task_id=task_id).project
+ except Exception:
+ project_id = '*'
+ cls._task_project_lookup[task_id] = project_id
+
+ if not task_id:
+ return ''
+
+ return ' {} '.format(task_link_template.format(project=project_id, task=task_id), task_id)
diff --git a/clearml/automation/job.py b/clearml/automation/job.py
index 4085063e..b6fd9f52 100644
--- a/clearml/automation/job.py
+++ b/clearml/automation/job.py
@@ -282,6 +282,21 @@ class TrainsJob(object):
self.task_started = True
return True
+ def delete(self):
+ # type: () -> bool
+ """
+ Delete the current temporary job (before launching)
+        Return False if the Job/Task could not be deleted
+ """
+ if not self.task:
+ return False
+
+ if self.task.delete():
+ self.task = None
+ return True
+
+ return False
+
# noinspection PyMethodMayBeStatic, PyUnusedLocal
class _JobStub(object):