Add Pipeline controller caching, improve pipeline plot reporting

allegroai 2021-04-25 10:43:39 +03:00
parent a9f52a468c
commit 9d108d855f


@@ -6,7 +6,7 @@ from threading import Thread, Event, RLock
from time import time
from attr import attrib, attrs
from typing import Sequence, Optional, Mapping, Callable, Any, Union
from typing import Sequence, Optional, Mapping, Callable, Any, Union, List
from ..backend_interface.util import get_or_create_project
from ..debugging.log import LoggerRoot
@@ -99,6 +99,7 @@ class PipelineController(object):
self._task = auto_connect_task if isinstance(auto_connect_task, Task) else Task.current_task()
self._step_ref_pattern = re.compile(self._step_pattern)
self._reporting_lock = RLock()
self._pipeline_task_status_failed = None
if not self._task and always_create_task:
self._task = Task.init(
project_name=pipeline_project or 'Pipelines',
@@ -381,7 +382,11 @@ class PipelineController(object):
:param float timeout: Wait timeout for the pipeline execution thread to exit (minutes).
The default is ``None``, indicating do not wait and terminate immediately.
"""
pass
self.wait(timeout=timeout)
if self._task and self._pipeline_task_status_failed:
print('Setting pipeline controller Task as failed (due to failed steps)!')
self._task.close()
self._task.mark_failed(status_reason='Pipeline step failed', force=True)
def wait(self, timeout=None):
# type: (Optional[float]) -> bool
@@ -418,7 +423,16 @@ class PipelineController(object):
:return: A boolean indicating whether the pipeline controller is active (still running) or stopped.
"""
return self._thread is not None
return self._thread is not None and self._thread.is_alive()
def is_successful(self):
# type: () -> bool
"""
Return True if the pipeline controller finished executing and none of its steps / Tasks failed
:return: A boolean indicating whether the pipeline completed without any failed steps
"""
return bool(self._thread and not self.is_running() and not self._pipeline_task_status_failed)
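Together with wait(), the new stop() / is_successful() pair gives callers a clean shutdown sequence. A minimal usage sketch, assuming step and project names along the lines of the clearml pipeline example (both are illustrative):

from clearml.automation.controller import PipelineController

pipe = PipelineController(default_execution_queue='default', add_pipeline_tags=True)
pipe.add_step(name='stage_data', base_task_project='examples',
              base_task_name='pipeline step 1 dataset artifact')
pipe.start()
pipe.wait()                   # block until every step's Task is stopped
if not pipe.is_successful():  # False if any step / Task failed
    print('pipeline had failed steps')
pipe.stop()                   # marks the controller Task as failed if any step failed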
def elapsed(self):
# type: () -> float
@@ -469,6 +483,14 @@ class PipelineController(object):
"""
return {k: n for k, n in self._nodes.items() if k in self._running_nodes}
def update_execution_plot(self):
# type: () -> ()
"""
Update the Sankey diagram of the current pipeline
"""
with self._reporting_lock:
self._update_execution_plot()
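The public wrapper exists so that both user code and the monitoring thread can request a plot refresh without interleaving reports; the RLock serializes concurrent callers. A small sketch of such a concurrent caller, reusing pipe from the sketch above (the plot_refresher helper is hypothetical):

from threading import Thread
from time import sleep

def plot_refresher(controller, interval_sec=30):
    # hypothetical helper: force a plot refresh every interval_sec
    # while the pipeline is still running
    while controller.is_running():
        controller.update_execution_plot()
        sleep(interval_sec)

Thread(target=plot_refresher, args=(pipe,), daemon=True).start()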
def _serialize_pipeline_task(self):
# type: () -> (dict, dict)
"""
@@ -645,14 +667,6 @@ class PipelineController(object):
return True
def update_execution_plot(self):
# type: () -> ()
"""
Update sankey diagram of the current pipeline
"""
with self._reporting_lock:
self._update_execution_plot()
def _update_execution_plot(self):
# type: () -> ()
"""
@@ -719,15 +733,7 @@ class PipelineController(object):
orientation='h'
)
task_link_template = self._task.get_output_log_web_page()\
.replace('/{}/'.format(self._task.project), '/{project}/')\
.replace('/{}/'.format(self._task.id), '/{task}/')
table_values = [["Pipeline Step", "Task ID", "Status", "Parameters"]]
table_values += [
[v, self.__create_task_link(self._nodes[v], task_link_template),
self.__get_node_status(self._nodes[v]), str(n)]
for v, n in zip(visited, node_params)]
table_values = self._build_table_report(node_params, visited)
# hack, show single node sankey
if single_nodes:
@@ -766,6 +772,42 @@ class PipelineController(object):
self._task.get_logger().report_table(
title='Pipeline Details', series='Execution Details', iteration=0, table_plot=table_values)
def _build_table_report(self, node_params, visited):
# type: (List, List) -> List[List]
"""
Create the detailed table report on all the jobs in the pipeline
:param node_params: list of per-step parameter override dicts
:param visited: list of visited step names
:return: Table as a List of Lists of strings (cells)
"""
task_link_template = self._task.get_output_log_web_page() \
.replace('/{}/'.format(self._task.project), '/{project}/') \
.replace('/{}/'.format(self._task.id), '/{task}/')
table_values = [["Pipeline Step", "Task ID", "Task Name", "Status", "Parameters"]]
for name, param in zip(visited, node_params):
param_str = str(param)
if len(param_str) > 3:
# strip the surrounding {} from the dict's string form
param_str = param_str[1:-1]
step_name = name
if self._nodes[name].base_task_id:
step_name += '\n[<a href="{}"> {} </a>]'.format(
task_link_template.format(project='*', task=self._nodes[name].base_task_id), 'base task')
table_values.append(
[step_name,
self.__create_task_link(self._nodes[name], task_link_template),
self._nodes[name].job.task.name if self._nodes[name].job else '',
self.__get_node_status(self._nodes[name]),
param_str]
)
return table_values
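The report is a plain list-of-lists (header row first, one row per step) handed to Logger.report_table(). A minimal standalone sketch of the same format, with illustrative cell values:

from clearml import Task

table_values = [
    ["Pipeline Step", "Task ID", "Task Name", "Status", "Parameters"],
    ["stage_data", "<task-id>", "pipeline step 1 dataset artifact",
     "completed", "'General/dataset_url': '...'"],
]
Task.current_task().get_logger().report_table(
    title='Pipeline Details', series='Execution Details',
    iteration=0, table_plot=table_values)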
@staticmethod
def _get_node_color(node):
# type: (PipelineController.Node) -> str
@@ -788,7 +830,7 @@ class PipelineController(object):
return "royalblue" # aborted job
elif node.job:
if node.job.is_pending():
return "mediumseagreen" # pending in queue
return "#bdf5bd" # lightgreen, pending in queue
else:
return "green" # running job
elif node.skip_job:
@@ -810,10 +852,11 @@ class PipelineController(object):
:return:
"""
pooling_counter = 0
launched_nodes = set()
last_plot_report = time()
while self._stop_event:
# stop request
if pooling_counter and self._stop_event.wait(self._pool_frequency):
if self._stop_event.wait(self._pool_frequency if pooling_counter else 0.01):
break
pooling_counter += 1
@@ -825,6 +868,7 @@ class PipelineController(object):
# check the state of all current jobs
# if no job ended, continue
completed_jobs = []
force_execution_plot_update = False
for j in self._running_nodes:
node = self._nodes[j]
if not node.job:
@@ -832,18 +876,29 @@ class PipelineController(object):
if node.job.is_stopped():
completed_jobs.append(j)
node.executed = node.job.task_id() if not node.job.is_failed() else False
if j in launched_nodes:
launched_nodes.remove(j)
elif node.timeout:
started = node.job.task.data.started
if (datetime.now().astimezone(started.tzinfo) - started).total_seconds() > node.timeout:
node.job.abort()
completed_jobs.append(j)
node.executed = node.job.task_id()
elif j in launched_nodes and node.job.is_running():
# make sure to update the execution graph once the job starts running
# (otherwise it would still be shown as queued)
launched_nodes.remove(j)
force_execution_plot_update = True
# update running jobs
self._running_nodes = [j for j in self._running_nodes if j not in completed_jobs]
# nothing changed, we can sleep
if not completed_jobs and self._running_nodes:
# force updating the pipeline state (plot) at least every 5 min.
if force_execution_plot_update or time()-last_plot_report > 5.*60:
last_plot_report = time()
self.update_execution_plot()
continue
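The loop now refreshes the plot when a launched job transitions from queued to running, and in any case at least every five minutes. A minimal sketch of that timed-refresh pattern in isolation (stop handling simplified; refresh_plot is a hypothetical callable):

from threading import Event
from time import time

def monitor(stop_event, pool_frequency_sec, refresh_plot):
    last_plot_report = time()
    while not stop_event.wait(timeout=pool_frequency_sec):
        # ... poll job states here ...
        if time() - last_plot_report > 5. * 60:
            # force an update even if nothing changed
            last_plot_report = time()
            refresh_plot()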
# callback on completed jobs
@@ -873,6 +928,10 @@ class PipelineController(object):
print('Launching step: {}'.format(name))
print('Parameters:\n{}'.format(self._nodes[name].job.task_parameter_override))
self._running_nodes.append(name)
launched_nodes.add(name)
# if the node is cached (already marked executed), do not wait for the poll event; run the loop again
if self._nodes[name].executed:
pooling_counter = 0
else:
getLogger('clearml.automation.controller').warning(
'Skipping launching step \'{}\': {}'.format(name, self._nodes[name]))
@@ -888,19 +947,15 @@ class PipelineController(object):
break
# stop all currently running jobs:
failing_pipeline = False
for node in self._nodes.values():
if node.executed is False:
failing_pipeline = True
self._pipeline_task_status_failed = True
if node.job and node.executed and not node.job.is_stopped():
node.job.abort()
elif not node.job and not node.executed:
# mark Node as skipped if it has no Job object and it is not executed
node.skip_job = True
if failing_pipeline and self._task:
self._task.mark_failed(status_reason='Pipeline step failed')
# visualize pipeline state (plot)
self.update_execution_plot()
@@ -911,6 +966,36 @@ class PipelineController(object):
except Exception:
pass
def _parse_step_ref(self, value):
# type: (Any) -> Optional[str]
"""
Resolve any step references inside a value, for example "${step1.parameters.Args/param}"
:param value: string (any other type is returned unchanged)
:return: the value with step references replaced by their actual values
"""
# look for all the step references
pattern = self._step_ref_pattern
updated_value = value
if isinstance(value, str):
for g in pattern.findall(value):
# update with actual value
new_val = self.__parse_step_reference(g)
updated_value = updated_value.replace(g, new_val, 1)
return updated_value
def _parse_task_overrides(self, task_overrides):
# type: (dict) -> dict
"""
Resolve all step references inside the given task overrides,
for example "${step1.parameters.Args/param}"
:param task_overrides: dict of task parameter overrides
:return: dict with step references replaced by their actual values
"""
updated_overrides = {}
for k, v in task_overrides.items():
updated_overrides[k] = self._parse_step_ref(v)
return updated_overrides
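These helpers are what turn ${step.<section>.<name>} references in parameter overrides into concrete values just before a step launches. A sketch of how such a reference is typically supplied, following the clearml pipeline example (step and project names are illustrative):

# '${stage_data.artifacts.dataset.url}' is resolved by _parse_step_ref /
# _parse_task_overrides at launch time, after 'stage_data' has executed
pipe.add_step(
    name='stage_process',
    parents=['stage_data'],
    base_task_project='examples',
    base_task_name='pipeline step 2 process dataset',
    parameter_override={'General/dataset_url': '${stage_data.artifacts.dataset.url}'},
)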
def __verify_step_reference(self, node, step_ref_string):
# type: (PipelineController.Node, str) -> bool
"""
@@ -1047,36 +1132,6 @@ class PipelineController(object):
return None
def _parse_step_ref(self, value):
# type: (Any) -> Optional[str]
"""
Return the step reference. For example "${step1.parameters.Args/param}"
:param value: string
:return:
"""
# look for all the step references
pattern = self._step_ref_pattern
updated_value = value
if isinstance(value, str):
for g in pattern.findall(value):
# update with actual value
new_val = self.__parse_step_reference(g)
updated_value = updated_value.replace(g, new_val, 1)
return updated_value
def _parse_task_overrides(self, task_overrides):
# type: (dict) -> dict
"""
Return the step reference. For example "${step1.parameters.Args/param}"
:param task_overrides: string
:return:
"""
updated_overrides = {}
for k, v in task_overrides.items():
updated_overrides[k] = self._parse_step_ref(v)
return updated_overrides
@classmethod
def __get_node_status(cls, a_node):
# type: (PipelineController.Node) -> str