diff --git a/clearml/automation/controller.py b/clearml/automation/controller.py index a7fb3b9a..b336d1de 100644 --- a/clearml/automation/controller.py +++ b/clearml/automation/controller.py @@ -62,6 +62,7 @@ class PipelineController(object): _add_to_evaluated_return_values = {} # TID: bool _retries = {} # Node.name: int _retries_callbacks = {} # Node.name: Callable[[PipelineController, PipelineController.Node, int], bool] # noqa + _status_change_callbacks = {} # Node.name: Callable[PipelineController, PipelineController.Node, str] _final_failure = {} # Node.name: bool _task_template_header = CreateFromFunction.default_task_template_header @@ -331,6 +332,7 @@ class PipelineController(object): cache_executed_step=False, # type: bool base_task_factory=None, # type: Optional[Callable[[PipelineController.Node], Task]] retry_on_failure=None, # type: Optional[Union[int, Callable[[PipelineController, PipelineController.Node, int], bool]]] # noqa + status_change_callback=None # type: Optional[Callable[[PipelineController, PipelineController.Node, str], None]] # noqa ): # type: (...) -> bool """ @@ -466,6 +468,20 @@ class PipelineController(object): # allow up to 5 retries (total of 6 runs) return retries < 5 + :param status_change_callback: Callback function, called when the status of a step (Task) changes. + Use `node.job` to access the ClearmlJob object, or `node.job.task` to directly access the Task object. + The signature of the function must look the following way: + + .. 
code-block:: py + + def status_change_callback( + pipeline, # type: PipelineController, + node, # type: PipelineController.Node, + previous_status # type: str + ): + pass + + :return: True if successful """ # always store callback functions (even when running remotely) @@ -527,6 +543,8 @@ class PipelineController(object): self._retries_callbacks[name] = retry_on_failure if callable(retry_on_failure) else \ (functools.partial(self._default_retry_on_failure_callback, max_retries=retry_on_failure) if isinstance(retry_on_failure, int) else self._retry_on_failure_callback) + if status_change_callback: + self._status_change_callbacks[name] = status_change_callback if self._task and not self._task.running_locally(): self.update_execution_plot() @@ -563,6 +581,7 @@ class PipelineController(object): post_execute_callback=None, # type: Optional[Callable[[PipelineController, PipelineController.Node], None]] # noqa cache_executed_step=False, # type: bool retry_on_failure=None, # type: Optional[Union[int, Callable[[PipelineController, PipelineController.Node, int], bool]]] # noqa + status_change_callback=None # type: Optional[Callable[[PipelineController, PipelineController.Node, str], None]] # noqa ): # type: (...) -> bool """ @@ -713,6 +732,19 @@ class PipelineController(object): # allow up to 5 retries (total of 6 runs) return retries < 5 + :param status_change_callback: Callback function, called when the status of a step (Task) changes. + Use `node.job` to access the ClearmlJob object, or `node.job.task` to directly access the Task object. + The signature of the function must look the following way: + + .. 
code-block:: py + + def status_change_callback( + pipeline, # type: PipelineController, + node, # type: PipelineController.Node, + previous_status # type: str + ): + pass + :return: True if successful """ function_kwargs = function_kwargs or {} @@ -750,6 +782,7 @@ class PipelineController(object): post_execute_callback=post_execute_callback, cache_executed_step=cache_executed_step, retry_on_failure=retry_on_failure, + status_change_callback=status_change_callback ) def start( @@ -1652,6 +1685,7 @@ class PipelineController(object): post_execute_callback=None, # type: Optional[Callable[[PipelineController, PipelineController.Node], None]] # noqa cache_executed_step=False, # type: bool retry_on_failure=None, # type: Optional[Union[int, Callable[[PipelineController, PipelineController.Node, int], bool]]] # noqa + status_change_callback=None # type: Optional[Callable[[PipelineController, PipelineController.Node, str], None]] # noqa ): # type: (...) -> bool """ @@ -1802,6 +1836,19 @@ class PipelineController(object): # allow up to 5 retries (total of 6 runs) return retries < 5 + :param status_change_callback: Callback function, called when the status of a step (Task) changes. + Use `node.job` to access the ClearmlJob object, or `node.job.task` to directly access the Task object. + The signature of the function must look the following way: + + .. 
code-block:: py + + def status_change_callback( + pipeline, # type: PipelineController, + node, # type: PipelineController.Node, + previous_status # type: str + ): + pass + :return: True if successful """ # always store callback functions (even when running remotely) @@ -1809,6 +1856,8 @@ class PipelineController(object): self._pre_step_callbacks[name] = pre_execute_callback if post_execute_callback: self._post_step_callbacks[name] = post_execute_callback + if status_change_callback: + self._status_change_callbacks[name] = status_change_callback self._verify_node_name(name) @@ -2025,9 +2074,7 @@ class PipelineController(object): return nodes = list(self._nodes.values()) - # update status - for n in nodes: - self._update_node_status(n) + self._update_nodes_status() # update the configuration state, so that the UI is presents the correct state self._force_task_configuration_update() @@ -2218,23 +2265,33 @@ class PipelineController(object): } return color_lookup.get(node.status, "") - @classmethod - def _update_node_status(cls, node): - # type (self.Mode) -> () + def _update_nodes_status(self): + # type () -> () + """ + Update the status of all nodes in the pipeline + """ + jobs = [] + previous_status_map = {} + for name, node in self._nodes.items(): + if not node.job: + continue + # noinspection PyProtectedMember + previous_status_map[name] = node.job._last_status + jobs.append(node.job) + BaseJob.update_status_batch(jobs) + for node in self._nodes.values(): + self._update_node_status(node) + + def _update_node_status(self, node): + # type (self.Node) -> () """ Update the node status entry based on the node/job state :param node: A node in the pipeline """ - if not node: - return + previous_status = node.status - # update job ended: update_job_ended = node.job_started and not node.job_ended - # refresh status - if node.job and isinstance(node.job, BaseJob): - node.job.status(force=True) - if node.executed is not None: if node.job and node.job.is_failed(): # failed 
job @@ -2271,7 +2328,18 @@ class PipelineController(object): if update_job_ended and node.status in ("aborted", "failed", "completed"): node.job_ended = time() - assert node.status in cls.valid_job_status + if ( + previous_status is not None + and previous_status != node.status + and self._status_change_callbacks.get(node.name) + ): + # noinspection PyBroadException + try: + self._status_change_callbacks[node.name](self, node, previous_status) + except Exception as e: + getLogger("clearml.automation.controller").warning( + "Failed calling the status change callback for node '{}'. Error is '{}'".format(node.name, e) + ) def _update_dag_state_artifact(self): # type: () -> () @@ -2326,6 +2394,7 @@ class PipelineController(object): break self._update_progress() + self._update_nodes_status() # check the state of all current jobs # if no a job ended, continue completed_jobs = [] @@ -3030,6 +3099,7 @@ class PipelineDecorator(PipelineController): break self._update_progress() + self._update_nodes_status() # check the state of all current jobs # if no a job ended, continue completed_jobs = [] @@ -3322,7 +3392,8 @@ class PipelineDecorator(PipelineController): monitor_models=None, # type: Optional[List[Union[str, Tuple[str, str]]]] retry_on_failure=None, # type: Optional[Union[int, Callable[[PipelineController, PipelineController.Node, int], bool]]] # noqa pre_execute_callback=None, # type: Optional[Callable[[PipelineController, PipelineController.Node, dict], bool]] # noqa - post_execute_callback=None # type: Optional[Callable[[PipelineController, PipelineController.Node], None]] # noqa + post_execute_callback=None, # type: Optional[Callable[[PipelineController, PipelineController.Node], None]] # noqa + status_change_callback=None # type: Optional[Callable[[PipelineController, PipelineController.Node, str], None]] # noqa ): # type: (...) 
@classmethod
def update_status_batch(cls, jobs):
    # type: (Sequence[BaseJob]) -> ()
    """
    Refresh the cached status of several jobs with a single status query.

    Jobs without a task are ignored. The query is skipped when the previous batch
    refresh (timestamp stored on the class) happened less than one second ago,
    unless at least one of the jobs has no cached status yet.

    :param jobs: The jobs to update the status of
    """
    tracked = {}  # task id -> job
    needs_first_status = False
    for job in jobs:
        task = job.task
        if task:
            tracked[task.id] = job
            # noinspection PyProtectedMember
            needs_first_status = needs_first_status or not job._last_status

    if not tracked:
        return

    # rate limit: at most one batch query per second, unless a job never got a status
    refreshed_recently = time() - cls._last_batch_status_update_ts < 1
    if refreshed_recently and not needs_first_status:
        return

    # noinspection PyProtectedMember
    statuses = Task._get_status(list(tracked))
    refreshed_at = time()
    cls._last_batch_status_update_ts = refreshed_at

    for entry in statuses:
        if entry.status and entry.id:
            job = tracked[entry.id]
            # noinspection PyProtectedMember
            job._last_status = entry.status
            # noinspection PyProtectedMember
            job._last_status_ts = refreshed_at