Add pipeline step caching

allegroai 2021-04-25 10:41:51 +03:00
parent 696034ac75
commit f9e555b464
2 changed files with 195 additions and 28 deletions

clearml/automation/controller.py

@@ -9,7 +9,6 @@ from attr import attrib, attrs
 from typing import Sequence, Optional, Mapping, Callable, Any, Union
 from ..backend_interface.util import get_or_create_project
-from ..config import get_remote_task_id
 from ..debugging.log import LoggerRoot
 from ..task import Task
 from ..automation import ClearmlJob
@@ -41,6 +40,7 @@ class PipelineController(object):
         clone_task = attrib(type=bool, default=True)
         job = attrib(type=ClearmlJob, default=None)
         skip_job = attrib(type=bool, default=False)
+        cache_executed_step = attrib(type=bool, default=False)

     def __init__(
             self,
@@ -125,7 +125,7 @@
             clone_base_task=True,  # type: bool
             pre_execute_callback=None,  # type: Optional[Callable[[PipelineController, PipelineController.Node, dict], bool]]  # noqa
             post_execute_callback=None,  # type: Optional[Callable[[PipelineController, PipelineController.Node], None]]  # noqa
+            cache_executed_step=False,  # type: bool
     ):
         # type: (...) -> bool
         """
@@ -199,6 +199,13 @@
             ):
                 pass

+        :param cache_executed_step: If True, before launching the new step,
+            after updating it with the latest configuration, check whether an identical Task
+            (same code and parameters) was already executed. If one is found, use it instead of
+            launching a new Task. Default: False (a new cloned copy of base_task is always used).
+            Notice: if the git repo reference does not have a specific commit ID, the Task is never reused.
+            If `clone_base_task` is False there is no cloning, hence the base_task is used as-is.
+
         :return: True if successful
         """
@@ -219,7 +226,17 @@
         if not base_task_id:
             if not base_task_project or not base_task_name:
                 raise ValueError('Either base_task_id or base_task_project/base_task_name must be provided')
-            base_task = Task.get_task(project_name=base_task_project, task_name=base_task_name)
+            base_task = Task.get_task(
+                project_name=base_task_project,
+                task_name=base_task_name,
+                allow_archived=True,
+                task_filter=dict(
+                    status=[str(Task.TaskStatusEnum.created), str(Task.TaskStatusEnum.queued),
+                            str(Task.TaskStatusEnum.in_progress), str(Task.TaskStatusEnum.published),
+                            str(Task.TaskStatusEnum.stopped), str(Task.TaskStatusEnum.completed),
+                            str(Task.TaskStatusEnum.closed)],
+                )
+            )
             if not base_task:
                 raise ValueError('Could not find base_task_project={} base_task_name={}'.format(
                     base_task_project, base_task_name))
@@ -235,6 +252,7 @@
             parameters=parameter_override or {},
             clone_task=clone_base_task,
             task_overrides=task_overrides,
+            cache_executed_step=cache_executed_step,
         )

         if self._task and not self._task.running_locally():
@@ -606,6 +624,7 @@
             parent=self._task.id if self._task else None,
             disable_clone_task=not node.clone_task,
             task_overrides=task_overrides,
+            allow_caching=node.cache_executed_step,
             **extra_args
         )
@@ -619,6 +638,8 @@
             # delete the job we just created
             node.job.delete()
             node.skip_job = True
+        elif node.job.is_cached_task():
+            node.executed = node.job.task_id()
         else:
             node.job.launch(queue_name=node.queue or self._default_execution_queue)
@@ -674,11 +695,8 @@
                 '{}<br />'.format(node.name) +
                 '<br />'.join('{}: {}'.format(k, v if len(str(v)) < 24 else (str(v)[:24]+' ...'))
                               for k, v in (node.parameters or {}).items()))
-            sankey_node['color'].append(
-                ("red" if node.job and node.job.is_failed() else
-                 ("blue" if not node.job or node.job.is_completed() else "royalblue"))
-                if node.executed is not None else
-                ("green" if node.job else ("gray" if node.skip_job else "lightsteelblue")))
+            sankey_node['color'].append(self._get_node_color(node))

             for p in parents:
                 sankey_link['source'].append(p)
@@ -748,6 +766,36 @@
         self._task.get_logger().report_table(
             title='Pipeline Details', series='Execution Details', iteration=0, table_plot=table_values)

+    @staticmethod
+    def _get_node_color(node):
+        # type: (PipelineController.Node) -> str
+        """
+        Return the node color based on the node/job state
+        :param node: A node in the pipeline
+        :return: string representing the color of the node (e.g. "red", "green", etc.)
+        """
+        if not node:
+            return ""
+        if node.executed is not None:
+            if node.job and node.job.is_failed():
+                return "red"  # failed job
+            elif node.job and node.job.is_cached_task():
+                return "darkslateblue"  # cached job
+            elif not node.job or node.job.is_completed():
+                return "blue"  # completed job
+            else:
+                return "royalblue"  # aborted job
+        elif node.job:
+            if node.job.is_pending():
+                return "mediumseagreen"  # pending in queue
+            else:
+                return "green"  # running job
+        elif node.skip_job:
+            return "gray"  # skipped job
+        else:
+            return "lightsteelblue"  # pending job
+
     def _force_task_configuration_update(self):
         pipeline_dag = self._serialize()
         if self._task:
@@ -1036,9 +1084,12 @@
             return "pending"
         if a_node.skip_job:
             return "skipped"
+        if a_node.job and a_node.job.is_cached_task():
+            return "cached"
         if a_node.job and a_node.job.task:
+            # no need to refresh status
             return str(a_node.job.task.data.status)
-        if a_node.job and a_node.job.executed:
+        if a_node.executed:
             return "executed"
         return "pending"

clearml/automation/job.py

@@ -5,6 +5,7 @@ from logging import getLogger
 from time import time, sleep
 from typing import Optional, Mapping, Sequence, Any

+from ..storage.util import hash_dict
 from ..task import Task
 from ..backend_api.services import tasks as tasks_service
@@ -13,6 +14,8 @@ logger = getLogger('clearml.automation.job')

 class ClearmlJob(object):
+    _job_hash_description = 'job_hash={}'
+
     def __init__(
             self,
             base_task_id,  # type: str
@@ -21,6 +24,7 @@ class ClearmlJob(object):
             tags=None,  # type: Optional[Sequence[str]]
             parent=None,  # type: Optional[str]
             disable_clone_task=False,  # type: bool
+            allow_caching=False,  # type: bool
             **kwargs  # type: Any
     ):
         # type: (...) -> ()
@@ -34,10 +38,13 @@ class ClearmlJob(object):
         :param str parent: Set newly created Task parent task field, default: base_task_id.
         :param dict kwargs: additional Task creation parameters
         :param bool disable_clone_task: if False (default) clone base task id.
             If True, use the base_task_id directly (base-task must be in draft-mode / created),
+        :param bool allow_caching: If True, check if a previously executed Task with the same specification
+            exists; if it does, use it and set the internal is_cached flag. Default False (always create a new Task).
         """
+        base_temp_task = Task.get_task(task_id=base_task_id)
         if disable_clone_task:
-            self.task = Task.get_task(task_id=base_task_id)
+            self.task = base_temp_task
             task_status = self.task.status
             if task_status != Task.TaskStatusEnum.created:
                 logger.warning('Task cloning disabled but requested Task [{}] status={}. '
@@ -47,31 +54,56 @@ class ClearmlJob(object):
         elif parent:
             self.task.set_parent(parent)

+        self.task_parameter_override = None
+        task_params = None
+        if parameter_override:
+            task_params = base_temp_task.get_parameters(backwards_compatibility=False)
+            task_params.update(parameter_override)
+            self.task_parameter_override = dict(**parameter_override)
+
+        sections = {}
+        if task_overrides:
+            # set values inside the Task
+            for k, v in task_overrides.items():
+                # notice we can allow ourselves to change the base-task object, as we will not use it any further
+                # noinspection PyProtectedMember
+                base_temp_task._set_task_property(k, v, raise_on_error=False, log_on_error=True)
+                section = k.split('.')[0]
+                sections[section] = getattr(base_temp_task.data, section, None)
+
+        # check for a cached task
+        self._is_cached_task = False
+        task_hash = None
+        if allow_caching and not disable_clone_task or not self.task:
+            # look for a cached copy of the Task:
+            # hash the parameters + task_overrides dict and query by the resulting hash
+            task_hash = self._create_task_hash(
+                base_temp_task, section_overrides=sections, params_override=task_params)
+            task = self._get_cached_task(task_hash)
+            # if we found a task, just use it
+            if task:
+                self._is_cached_task = True
+                self.task = task
+                self.task_started = True
+                self._worker = None
+                return
+
         # check again if we need to clone the Task
         if not disable_clone_task:
             self.task = Task.clone(base_task_id, parent=parent or base_task_id, **kwargs)

         if tags:
             self.task.set_tags(list(set(self.task.get_tags()) | set(tags)))

-        self.task_parameter_override = None
-        if parameter_override:
-            params = self.task.get_parameters(backwards_compatibility=False)
-            params.update(parameter_override)
-            self.task.set_parameters(params)
-            self.task_parameter_override = dict(**parameter_override)
-
-        if task_overrides:
-            sections = {}
-            # set values inside the Task
-            for k, v in task_overrides.items():
-                # noinspection PyProtectedMember
-                self.task._set_task_property(k, v, raise_on_error=False, log_on_error=True)
-                section = k.split('.')[0]
-                sections[section] = getattr(self.task.data, section, None)
+        if task_params:
+            self.task.set_parameters(task_params)

+        if task_overrides and sections:
             # store back Task parameters into backend
             # noinspection PyProtectedMember
             self.task._edit(**sections)
+            self._set_task_cache_hash(self.task, task_hash)

         self.task_started = False
         self._worker = None
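The same mechanism is exposed directly on ClearmlJob through the new allow_caching argument. A short sketch of how a caller might use it; the Task ID, parameter name, and queue below are placeholders, not values from this commit:

    from clearml.automation import ClearmlJob

    job = ClearmlJob(
        base_task_id='aabbcc112233',                      # placeholder Task ID
        parameter_override={'General/batch_size': '64'},  # placeholder parameter
        allow_caching=True,                               # reuse an identical completed Task if one exists
    )
    if job.is_cached_task():
        print('reusing cached task:', job.task_id())      # launch() would be a no-op returning False
    else:
        job.launch(queue_name='default')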
@@ -108,22 +140,29 @@
         return metrics, title, series, values

     def launch(self, queue_name=None):
-        # type: (str) -> ()
+        # type: (str) -> bool
         """
         Send Job for execution on the requested execution queue

         :param str queue_name:
+        :return: False if the Task is not in "created" status (i.e. cannot be enqueued)
         """
+        if self._is_cached_task:
+            return False
         try:
             Task.enqueue(task=self.task, queue_name=queue_name)
+            return True
         except Exception as ex:
             logger.warning(ex)
+        return False

     def abort(self):
         # type: () -> ()
         """
         Abort currently running job (can be called multiple times)
         """
+        if not self.task or self._is_cached_task:
+            return
         try:
             self.task.stopped()
         except Exception as ex:
@@ -234,7 +273,7 @@
     def is_stopped(self):
         # type: () -> bool
         """
-        Return True, if job is has executed and is not any more
+        Return True, if the job finished executing (for any reason)

         :return: True if the task is currently in one of these states: stopped / completed / failed / published.
         """
@@ -298,7 +337,7 @@
         Delete the current temporary job (before launching)
         Return False if the Job/Task could not be deleted
         """
-        if not self.task:
+        if not self.task or self._is_cached_task:
             return False

         if self.task.delete():
@@ -307,6 +346,82 @@
         return False

+    def is_cached_task(self):
+        # type: () -> bool
+        """
+        :return: True if the internal Task is a cached one, False otherwise.
+        """
+        return self._is_cached_task
+
+    @classmethod
+    def _create_task_hash(cls, task, section_overrides=None, params_override=None):
+        # type: (Task, Optional[dict], Optional[dict]) -> Optional[str]
+        """
+        Create a hash (str) representing the state of the Task
+
+        :param task: A Task to hash
+        :param section_overrides: optional dict (keys are the Task's section names) with task overrides
+        :param params_override: alternative to the entire Task's hyper-parameters section
+            (notice this should not be a nested dict but a flat key/value)
+        :return: str crc32 of the Task configuration
+        """
+        if not task:
+            return None
+        if section_overrides and section_overrides.get('script'):
+            script = section_overrides['script']
+            if not isinstance(script, dict):
+                script = script.to_dict()
+        else:
+            script = task.data.script.to_dict() if task.data.script else {}
+
+        # if we have a repository, we must make sure we have a specific version_num to ensure consistency
+        if script.get('repository') and not script.get('version_num') and not script.get('tag'):
+            return None
+
+        # we need to ignore the `requirements` section because it might change from run to run
+        script.pop("requirements", None)
+
+        hyper_params = task.get_parameters() if params_override is None else params_override
+        configs = task.get_configuration_objects()
+        return hash_dict(dict(script=script, hyper_params=hyper_params, configs=configs), hash_func='crc32')
+
+    @classmethod
+    def _get_cached_task(cls, task_hash):
+        # type: (str) -> Optional[Task]
+        """
+        :param task_hash: the Task hash to look up
+        :return: a Task matching the requested task hash, or None
+        """
+        if not task_hash:
+            return None
+        # noinspection PyProtectedMember
+        potential_tasks = Task._query_tasks(
+            status=['completed', 'stopped', 'published'],
+            system_tags=['-{}'.format(Task.archived_tag)],
+            _all_=dict(fields=['comment'], pattern=cls._job_hash_description.format(task_hash)),
+            only_fields=['id'],
+        )
+        for obj in potential_tasks:
+            task = Task.get_task(task_id=obj.id)
+            if task_hash == cls._create_task_hash(task):
+                return task
+        return None
+
+    @classmethod
+    def _set_task_cache_hash(cls, task, task_hash=None):
+        # type: (Task, Optional[str]) -> ()
+        """
+        Store the Task state hash for later querying
+
+        :param task: The Task object that was created
+        :param task_hash: The Task hash (string) to store; if None, generate a new task_hash from the Task
+        """
+        if not task:
+            return
+        if not task_hash:
+            task_hash = cls._create_task_hash(task=task)
+        hash_comment = cls._job_hash_description.format(task_hash) + '\n'
+        task.set_comment(task.comment + '\n' + hash_comment if task.comment else hash_comment)
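Taken together, the cache key is a crc32 hash over the Task's script section (repository, commit, diff, entry point, with requirements stripped), its hyper-parameters, and its configuration objects. The hash is appended to the Task comment as 'job_hash=<hash>', so _get_cached_task() can query candidates by comment pattern and then re-verify the full hash. The sketch below illustrates the deterministic-hashing idea only; it is not the actual clearml.storage.util.hash_dict implementation:

    import json
    import zlib

    def hash_dict_sketch(d):
        # serialize with sorted keys so logically equal dicts produce identical hashes
        blob = json.dumps(d, sort_keys=True, default=str).encode('utf-8')
        return '{:08x}'.format(zlib.crc32(blob) & 0xffffffff)

    # dicts with equal content but different insertion order hash identically
    assert hash_dict_sketch({'a': 1, 'b': 2}) == hash_dict_sketch({'b': 2, 'a': 1})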
@@ -317,6 +432,7 @@ class TrainsJob(ClearmlJob):
             DeprecationWarning,
         )

+# noinspection PyMethodMayBeStatic, PyUnusedLocal
 class _JobStub(object):
     """