Add status change callback to pipelines

This commit is contained in:
allegroai 2023-02-28 17:03:48 +02:00
parent 6b32e1d33a
commit 89e39f0b02
2 changed files with 132 additions and 16 deletions

View File

@ -62,6 +62,7 @@ class PipelineController(object):
_add_to_evaluated_return_values = {} # TID: bool
_retries = {} # Node.name: int
_retries_callbacks = {} # Node.name: Callable[[PipelineController, PipelineController.Node, int], bool] # noqa
_status_change_callbacks = {} # Node.name: Callable[PipelineController, PipelineController.Node, str]
_final_failure = {} # Node.name: bool
_task_template_header = CreateFromFunction.default_task_template_header
@ -331,6 +332,7 @@ class PipelineController(object):
cache_executed_step=False, # type: bool
base_task_factory=None, # type: Optional[Callable[[PipelineController.Node], Task]]
retry_on_failure=None, # type: Optional[Union[int, Callable[[PipelineController, PipelineController.Node, int], bool]]] # noqa
status_change_callback=None # type: Optional[Callable[[PipelineController, PipelineController.Node, str], None]] # noqa
):
# type: (...) -> bool
"""
@ -466,6 +468,20 @@ class PipelineController(object):
# allow up to 5 retries (total of 6 runs)
return retries < 5
:param status_change_callback: Callback function, called when the status of a step (Task) changes.
Use `node.job` to access the ClearmlJob object, or `node.job.task` to directly access the Task object.
The signature of the function must look the following way:
.. code-block:: py
def status_change_callback(
pipeline, # type: PipelineController,
node, # type: PipelineController.Node,
previous_status # type: str
):
pass
:return: True if successful
"""
# always store callback functions (even when running remotely)
@ -527,6 +543,8 @@ class PipelineController(object):
self._retries_callbacks[name] = retry_on_failure if callable(retry_on_failure) else \
(functools.partial(self._default_retry_on_failure_callback, max_retries=retry_on_failure)
if isinstance(retry_on_failure, int) else self._retry_on_failure_callback)
if status_change_callback:
self._status_change_callbacks[name] = status_change_callback
if self._task and not self._task.running_locally():
self.update_execution_plot()
@ -563,6 +581,7 @@ class PipelineController(object):
post_execute_callback=None, # type: Optional[Callable[[PipelineController, PipelineController.Node], None]] # noqa
cache_executed_step=False, # type: bool
retry_on_failure=None, # type: Optional[Union[int, Callable[[PipelineController, PipelineController.Node, int], bool]]] # noqa
status_change_callback=None # type: Optional[Callable[[PipelineController, PipelineController.Node, str], None]] # noqa
):
# type: (...) -> bool
"""
@ -713,6 +732,19 @@ class PipelineController(object):
# allow up to 5 retries (total of 6 runs)
return retries < 5
:param status_change_callback: Callback function, called when the status of a step (Task) changes.
Use `node.job` to access the ClearmlJob object, or `node.job.task` to directly access the Task object.
The signature of the function must look the following way:
.. code-block:: py
def status_change_callback(
pipeline, # type: PipelineController,
node, # type: PipelineController.Node,
previous_status # type: str
):
pass
:return: True if successful
"""
function_kwargs = function_kwargs or {}
@ -750,6 +782,7 @@ class PipelineController(object):
post_execute_callback=post_execute_callback,
cache_executed_step=cache_executed_step,
retry_on_failure=retry_on_failure,
status_change_callback=status_change_callback
)
def start(
@ -1652,6 +1685,7 @@ class PipelineController(object):
post_execute_callback=None, # type: Optional[Callable[[PipelineController, PipelineController.Node], None]] # noqa
cache_executed_step=False, # type: bool
retry_on_failure=None, # type: Optional[Union[int, Callable[[PipelineController, PipelineController.Node, int], bool]]] # noqa
status_change_callback=None # type: Optional[Callable[[PipelineController, PipelineController.Node, str], None]] # noqa
):
# type: (...) -> bool
"""
@ -1802,6 +1836,19 @@ class PipelineController(object):
# allow up to 5 retries (total of 6 runs)
return retries < 5
:param status_change_callback: Callback function, called when the status of a step (Task) changes.
Use `node.job` to access the ClearmlJob object, or `node.job.task` to directly access the Task object.
The signature of the function must look the following way:
.. code-block:: py
def status_change_callback(
pipeline, # type: PipelineController,
node, # type: PipelineController.Node,
previous_status # type: str
):
pass
:return: True if successful
"""
# always store callback functions (even when running remotely)
@ -1809,6 +1856,8 @@ class PipelineController(object):
self._pre_step_callbacks[name] = pre_execute_callback
if post_execute_callback:
self._post_step_callbacks[name] = post_execute_callback
if status_change_callback:
self._status_change_callbacks[name] = status_change_callback
self._verify_node_name(name)
@ -2025,9 +2074,7 @@ class PipelineController(object):
return
nodes = list(self._nodes.values())
# update status
for n in nodes:
self._update_node_status(n)
self._update_nodes_status()
# update the configuration state, so that the UI presents the correct state
self._force_task_configuration_update()
@ -2218,23 +2265,33 @@ class PipelineController(object):
}
return color_lookup.get(node.status, "")
@classmethod
def _update_node_status(cls, node):
# type (self.Mode) -> ()
def _update_nodes_status(self):
    # type: () -> ()
    """
    Update the status of all nodes in the pipeline

    The status of every node that already has a job is refreshed with a single
    ``BaseJob.update_status_batch`` call, then ``_update_node_status`` is called for each node.
    """
    # Work on a snapshot of the nodes: `_update_node_status` may invoke user supplied
    # status change callbacks, and a callback that registers a new step would otherwise
    # change `self._nodes` while it is being iterated (RuntimeError).
    nodes = list(self._nodes.values())

    # only nodes that were already launched have a job whose status can be refreshed
    jobs = [node.job for node in nodes if node.job]
    if jobs:
        BaseJob.update_status_batch(jobs)

    for node in nodes:
        self._update_node_status(node)
def _update_node_status(self, node):
# type (self.Node) -> ()
"""
Update the node status entry based on the node/job state
:param node: A node in the pipeline
"""
if not node:
return
previous_status = node.status
# update job ended:
update_job_ended = node.job_started and not node.job_ended
# refresh status
if node.job and isinstance(node.job, BaseJob):
node.job.status(force=True)
if node.executed is not None:
if node.job and node.job.is_failed():
# failed job
@ -2271,7 +2328,18 @@ class PipelineController(object):
if update_job_ended and node.status in ("aborted", "failed", "completed"):
node.job_ended = time()
assert node.status in cls.valid_job_status
if (
previous_status is not None
and previous_status != node.status
and self._status_change_callbacks.get(node.name)
):
# noinspection PyBroadException
try:
self._status_change_callbacks[node.name](self, node, previous_status)
except Exception as e:
getLogger("clearml.automation.controller").warning(
"Failed calling the status change callback for node '{}'. Error is '{}'".format(node.name, e)
)
def _update_dag_state_artifact(self):
# type: () -> ()
@ -2326,6 +2394,7 @@ class PipelineController(object):
break
self._update_progress()
self._update_nodes_status()
# check the state of all current jobs
# if no job ended, continue
completed_jobs = []
@ -3030,6 +3099,7 @@ class PipelineDecorator(PipelineController):
break
self._update_progress()
self._update_nodes_status()
# check the state of all current jobs
# if no job ended, continue
completed_jobs = []
@ -3322,7 +3392,8 @@ class PipelineDecorator(PipelineController):
monitor_models=None, # type: Optional[List[Union[str, Tuple[str, str]]]]
retry_on_failure=None, # type: Optional[Union[int, Callable[[PipelineController, PipelineController.Node, int], bool]]] # noqa
pre_execute_callback=None, # type: Optional[Callable[[PipelineController, PipelineController.Node, dict], bool]] # noqa
post_execute_callback=None # type: Optional[Callable[[PipelineController, PipelineController.Node], None]] # noqa
post_execute_callback=None, # type: Optional[Callable[[PipelineController, PipelineController.Node], None]] # noqa
status_change_callback=None # type: Optional[Callable[[PipelineController, PipelineController.Node, str], None]] # noqa
):
# type: (...) -> Callable
"""
@ -3444,6 +3515,19 @@ class PipelineDecorator(PipelineController):
):
pass
:param status_change_callback: Callback function, called when the status of a step (Task) changes.
Use `node.job` to access the ClearmlJob object, or `node.job.task` to directly access the Task object.
The signature of the function must look the following way:
.. code-block:: py
def status_change_callback(
pipeline, # type: PipelineController,
node, # type: PipelineController.Node,
previous_status # type: str
):
pass
:return: function wrapper
"""
@ -3484,7 +3568,8 @@ class PipelineDecorator(PipelineController):
monitor_models=monitor_models,
monitor_artifacts=monitor_artifacts,
pre_execute_callback=pre_execute_callback,
post_execute_callback=post_execute_callback
post_execute_callback=post_execute_callback,
status_change_callback=status_change_callback
)
if cls._singleton:

View File

@ -28,6 +28,7 @@ class BaseJob(object):
_job_hash_description = 'job_hash={}'
_job_hash_property = 'pipeline_job_hash'
_hashing_callback = None
_last_batch_status_update_ts = 0
def __init__(self):
# type: () -> ()
@ -174,6 +175,36 @@ class BaseJob(object):
"""
return str(self.task.data.status_message)
@classmethod
def update_status_batch(cls, jobs):
    # type: (Sequence[BaseJob]) -> ()
    """
    Update the status of multiple jobs using a single batched status request

    Jobs without a task are ignored. The request is skipped when the previous batch update
    happened less than one second ago, unless at least one of the jobs has no known status yet.

    :param jobs: The jobs to update the status of
    """
    id_map = {}
    have_job_with_no_status = False
    for job in jobs:
        if not job.task:
            continue
        id_map[job.task.id] = job
        # noinspection PyProtectedMember
        if not job._last_status:
            have_job_with_no_status = True

    if not id_map:
        return
    # throttle the requests, but never delay the first status of a job
    if time() - cls._last_batch_status_update_ts < 1 and not have_job_with_no_status:
        return

    batch_status = Task._get_status(list(id_map))
    last_batch_update_ts = time()
    cls._last_batch_status_update_ts = last_batch_update_ts
    for status in batch_status:
        if not status.status or not status.id:
            continue
        job = id_map.get(status.id)
        if job is None:
            # status entry for a task that was not requested: skip it instead of
            # raising KeyError and dropping the rest of the batch
            continue
        # noinspection PyProtectedMember
        job._last_status = status.status
        # noinspection PyProtectedMember
        job._last_status_ts = last_batch_update_ts
def wait(self, timeout=None, pool_period=30., aborted_nonresponsive_as_running=False):
# type: (Optional[float], float, bool) -> bool
"""