Add pipeline step caching

allegroai 2021-04-25 10:41:51 +03:00
parent 696034ac75
commit f9e555b464
2 changed files with 195 additions and 28 deletions

clearml/automation/controller.py

@ -9,7 +9,6 @@ from attr import attrib, attrs
from typing import Sequence, Optional, Mapping, Callable, Any, Union
from ..backend_interface.util import get_or_create_project
from ..config import get_remote_task_id
from ..debugging.log import LoggerRoot
from ..task import Task
from ..automation import ClearmlJob
@ -41,6 +40,7 @@ class PipelineController(object):
clone_task = attrib(type=bool, default=True)
job = attrib(type=ClearmlJob, default=None)
skip_job = attrib(type=bool, default=False)
cache_executed_step = attrib(type=bool, default=False)
def __init__(
self,
@ -125,7 +125,7 @@ class PipelineController(object):
clone_base_task=True, # type: bool
pre_execute_callback=None, # type: Optional[Callable[[PipelineController, PipelineController.Node, dict], bool]] # noqa
post_execute_callback=None, # type: Optional[Callable[[PipelineController, PipelineController.Node], None]] # noqa
cache_executed_step=False, # type: bool
):
# type: (...) -> bool
"""
@ -199,6 +199,13 @@ class PipelineController(object):
):
pass
:param cache_executed_step: If True, before launching the new step,
after applying the latest configuration, check whether an identical Task (same code and parameters)
was already executed. If one is found, use it instead of launching a new Task.
Default: False, a new cloned copy of base_task is always used.
Notice: if the git repo reference does not have a specific commit ID, the step Task will never be cached.
If `clone_base_task` is False there is no cloning, hence the base_task is used as-is.
:return: True if successful
"""
@ -219,7 +226,17 @@ class PipelineController(object):
if not base_task_id:
if not base_task_project or not base_task_name:
raise ValueError('Either base_task_id or base_task_project/base_task_name must be provided')
base_task = Task.get_task(project_name=base_task_project, task_name=base_task_name)
base_task = Task.get_task(
project_name=base_task_project,
task_name=base_task_name,
allow_archived=True,
task_filter=dict(
status=[str(Task.TaskStatusEnum.created), str(Task.TaskStatusEnum.queued),
str(Task.TaskStatusEnum.in_progress), str(Task.TaskStatusEnum.published),
str(Task.TaskStatusEnum.stopped), str(Task.TaskStatusEnum.completed),
str(Task.TaskStatusEnum.closed)],
)
)
if not base_task:
raise ValueError('Could not find base_task_project={} base_task_name={}'.format(
base_task_project, base_task_name))
@ -235,6 +252,7 @@ class PipelineController(object):
parameters=parameter_override or {},
clone_task=clone_base_task,
task_overrides=task_overrides,
cache_executed_step=cache_executed_step,
)
if self._task and not self._task.running_locally():
@ -606,6 +624,7 @@ class PipelineController(object):
parent=self._task.id if self._task else None,
disable_clone_task=not node.clone_task,
task_overrides=task_overrides,
allow_caching=node.cache_executed_step,
**extra_args
)
@ -619,6 +638,8 @@ class PipelineController(object):
# delete the job we just created
node.job.delete()
node.skip_job = True
elif node.job.is_cached_task():
node.executed = node.job.task_id()
else:
node.job.launch(queue_name=node.queue or self._default_execution_queue)
@ -674,11 +695,8 @@ class PipelineController(object):
'{}<br />'.format(node.name) +
'<br />'.join('{}: {}'.format(k, v if len(str(v)) < 24 else (str(v)[:24]+' ...'))
for k, v in (node.parameters or {}).items()))
sankey_node['color'].append(
("red" if node.job and node.job.is_failed() else
("blue" if not node.job or node.job.is_completed() else "royalblue"))
if node.executed is not None else
("green" if node.job else ("gray" if node.skip_job else "lightsteelblue")))
sankey_node['color'].append(self._get_node_color(node))
for p in parents:
sankey_link['source'].append(p)
@ -748,6 +766,36 @@ class PipelineController(object):
self._task.get_logger().report_table(
title='Pipeline Details', series='Execution Details', iteration=0, table_plot=table_values)
@staticmethod
def _get_node_color(node):
# type: (PipelineController.Node) -> str
"""
Return the node color based on the node/job state
:param node: A node in the pipeline
:return: string representing the color of the node (e.g. "red", "green", etc.)
"""
if not node:
return ""
if node.executed is not None:
if node.job and node.job.is_failed():
return "red" # failed job
elif node.job and node.job.is_cached_task():
return "darkslateblue"
elif not node.job or node.job.is_completed():
return "blue" # completed job
else:
return "royalblue" # aborted job
elif node.job:
if node.job.is_pending():
return "mediumseagreen" # pending in queue
else:
return "green" # running job
elif node.skip_job:
return "gray" # skipped job
else:
return "lightsteelblue" # pending job
def _force_task_configuration_update(self):
pipeline_dag = self._serialize()
if self._task:
@ -1036,9 +1084,12 @@ class PipelineController(object):
return "pending"
if a_node.skip_job:
return "skipped"
if a_node.job and a_node.job.is_cached_task():
return "cached"
if a_node.job and a_node.job.task:
# no need to refresh status
return str(a_node.job.task.data.status)
if a_node.job and a_node.job.executed:
if a_node.executed:
return "executed"
return "pending"

clearml/automation/job.py

@ -5,6 +5,7 @@ from logging import getLogger
from time import time, sleep
from typing import Optional, Mapping, Sequence, Any
from ..storage.util import hash_dict
from ..task import Task
from ..backend_api.services import tasks as tasks_service
@ -13,6 +14,8 @@ logger = getLogger('clearml.automation.job')
class ClearmlJob(object):
_job_hash_description = 'job_hash={}'
def __init__(
self,
base_task_id, # type: str
@ -21,6 +24,7 @@ class ClearmlJob(object):
tags=None, # type: Optional[Sequence[str]]
parent=None, # type: Optional[str]
disable_clone_task=False, # type: bool
allow_caching=False, # type: bool
**kwargs # type: Any
):
# type: (...) -> ()
@ -34,10 +38,13 @@ class ClearmlJob(object):
:param str parent: Set newly created Task parent task field, default: base_task_id.
:param dict kwargs: additional Task creation parameters
:param bool disable_clone_task: if False (default), clone the base task id.
If True, use the base_task_id directly (the base task must be in draft mode / "created" status).
:param bool allow_caching: If True, check whether we have a previously executed Task with the same specification.
If we do, use it and set the internal is_cached flag. Default False (always create a new Task).
"""
base_temp_task = Task.get_task(task_id=base_task_id)
if disable_clone_task:
self.task = Task.get_task(task_id=base_task_id)
self.task = base_temp_task
task_status = self.task.status
if task_status != Task.TaskStatusEnum.created:
logger.warning('Task cloning disabled but requested Task [{}] status={}. '
@ -47,31 +54,56 @@ class ClearmlJob(object):
elif parent:
self.task.set_parent(parent)
self.task_parameter_override = None
task_params = None
if parameter_override:
task_params = base_temp_task.get_parameters(backwards_compatibility=False)
task_params.update(parameter_override)
self.task_parameter_override = dict(**parameter_override)
sections = {}
if task_overrides:
# set values inside the Task
for k, v in task_overrides.items():
# notice we can allow ourselves to change the base-task object as we will not use it any further
# noinspection PyProtectedMember
base_temp_task._set_task_property(k, v, raise_on_error=False, log_on_error=True)
section = k.split('.')[0]
sections[section] = getattr(base_temp_task.data, section, None)
# check cached task
self._is_cached_task = False
task_hash = None
if allow_caching and not disable_clone_task or not self.task:
# look for a cached copy of the Task
# get parameters + task_overrides as a dict and hash it
task_hash = self._create_task_hash(
base_temp_task, section_overrides=sections, params_override=task_params)
task = self._get_cached_task(task_hash)
# if we found a task, just use it
if task:
self._is_cached_task = True
self.task = task
self.task_started = True
self._worker = None
return
# check again if we need to clone the Task
if not disable_clone_task:
self.task = Task.clone(base_task_id, parent=parent or base_task_id, **kwargs)
if tags:
self.task.set_tags(list(set(self.task.get_tags()) | set(tags)))
self.task_parameter_override = None
if parameter_override:
params = self.task.get_parameters(backwards_compatibility=False)
params.update(parameter_override)
self.task.set_parameters(params)
self.task_parameter_override = dict(**parameter_override)
if task_overrides:
sections = {}
# set values inside the Task
for k, v in task_overrides.items():
# noinspection PyProtectedMember
self.task._set_task_property(k, v, raise_on_error=False, log_on_error=True)
section = k.split('.')[0]
sections[section] = getattr(self.task.data, section, None)
if task_params:
self.task.set_parameters(task_params)
if task_overrides and sections:
# store back Task parameters into backend
# noinspection PyProtectedMember
self.task._edit(**sections)
self._set_task_cache_hash(self.task, task_hash)
self.task_started = False
self._worker = None
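A sketch of the caching flow above when using ClearmlJob directly (the base task ID and parameter values are placeholders):

    from clearml.automation import ClearmlJob

    job = ClearmlJob(
        base_task_id='aabbccddeeff00112233445566778899',  # placeholder ID
        parameter_override={'General/batch_size': 64},
        allow_caching=True,
    )
    if job.is_cached_task():
        # an identical Task already executed successfully; reuse its results
        print('cache hit, reusing task', job.task_id())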
@ -108,22 +140,29 @@ class ClearmlJob(object):
return metrics, title, series, values
def launch(self, queue_name=None):
# type: (str) -> ()
# type: (str) -> bool
"""
Send Job for execution on the requested execution queue
:param str queue_name: The execution queue to enqueue the job on
:return: False if the Task is cached or not in "created" status (i.e. cannot be enqueued), True if successfully enqueued
"""
if self._is_cached_task:
return False
try:
Task.enqueue(task=self.task, queue_name=queue_name)
return True
except Exception as ex:
logger.warning(ex)
return False
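Callers can now branch on the return value; a short sketch continuing the example above (queue name is a placeholder):

    # launch() returns False for cached tasks, since there is nothing to enqueue
    if job.launch(queue_name='default'):
        job.wait()  # block until the enqueued Task finishes
    else:
        print('not enqueued (cached task or enqueue failure)')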
def abort(self):
# type: () -> ()
"""
Abort currently running job (can be called multiple times)
"""
if not self.task or self._is_cached_task:
return
try:
self.task.stopped()
except Exception as ex:
@ -234,7 +273,7 @@ class ClearmlJob(object):
def is_stopped(self):
# type: () -> bool
"""
Return True if the job has finished executing (for any reason)
:return: True if the task is currently in one of these states: stopped / completed / failed / published.
"""
@ -298,7 +337,7 @@ class ClearmlJob(object):
Delete the current temporary job (before launching)
Return False if the Job/Task could not be deleted
"""
if not self.task:
if not self.task or self._is_cached_task:
return False
if self.task.delete():
@ -307,6 +346,82 @@ class ClearmlJob(object):
return False
def is_cached_task(self):
# type: () -> bool
"""
:return: True if the internal Task is a cached one, False otherwise.
"""
return self._is_cached_task
@classmethod
def _create_task_hash(cls, task, section_overrides=None, params_override=None):
# type: (Task, Optional[dict], Optional[dict]) -> Optional[str]
"""
Create Hash (str) representing the state of the Task
:param task: A Task to hash
:param section_overrides: optional dict (keys are Task's section names) with task overrides.
:param params_override: Optional replacement for the Task's entire hyper-parameters section
(notice this should be a flat key/value dict, not a nested one)
:return: str crc32 of the Task configuration
"""
if not task:
return None
if section_overrides and section_overrides.get('script'):
script = section_overrides['script']
if not isinstance(script, dict):
script = script.to_dict()
else:
script = task.data.script.to_dict() if task.data.script else {}
# if we have a repository, we must make sure we have a specific version_num to ensure consistency
if script.get('repository') and not script.get('version_num') and not script.get('tag'):
return None
# we need to ignore the `requirements` section because it might change from run to run
script.pop("requirements", None)
hyper_params = task.get_parameters() if params_override is None else params_override
configs = task.get_configuration_objects()
return hash_dict(dict(script=script, hyper_params=hyper_params, configs=configs), hash_func='crc32')
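The cache key deliberately covers only state that determines a run's outcome: the script section minus requirements, the flat hyper-parameters, and the configuration objects. A standalone sketch of the same idea using only the standard library (the real code calls the internal hash_dict helper with crc32):

    import hashlib
    import json

    def task_state_hash(script, hyper_params, configs):
        # sketch of the _create_task_hash logic above
        script = dict(script or {})
        # an unpinned repo may move between runs, so refuse to cache
        if script.get('repository') and not (script.get('version_num') or script.get('tag')):
            return None
        # installed requirements may legitimately differ between identical runs
        script.pop('requirements', None)
        state = dict(script=script, hyper_params=hyper_params or {}, configs=configs or {})
        # stable serialization so equal states always produce equal hashes
        blob = json.dumps(state, sort_keys=True, default=str).encode('utf-8')
        return hashlib.sha1(blob).hexdigest()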
@classmethod
def _get_cached_task(cls, task_hash):
# type: (str) -> Optional[Task]
"""
:param task_hash: The Task state hash (as produced by _create_task_hash) to look up
:return: A task matching the requested task hash
"""
if not task_hash:
return None
# noinspection PyProtectedMember
potential_tasks = Task._query_tasks(
status=['completed', 'stopped', 'published'],
system_tags=['-{}'.format(Task.archived_tag)],
_all_=dict(fields=['comment'], pattern=cls._job_hash_description.format(task_hash)),
only_fields=['id'],
)
for obj in potential_tasks:
task = Task.get_task(task_id=obj.id)
if task_hash == cls._create_task_hash(task):
return task
return None
@classmethod
def _set_task_cache_hash(cls, task, task_hash=None):
# type: (Task, Optional[str]) -> ()
"""
Store the task state hash for later querying
:param task: The Task object that was created
:param task_hash: The Task Hash (string) to store, if None generate a new task_hash from the Task
"""
if not task:
return
if not task_hash:
task_hash = cls._create_task_hash(task=task)
hash_comment = cls._job_hash_description.format(task_hash) + '\n'
task.set_comment(task.comment + '\n' + hash_comment if task.comment else hash_comment)
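A short sketch of the round trip, assuming my_task is some existing, completed Task object (hypothetical):

    # the hash rides in the Task comment as 'job_hash=<hash>', so the
    # comment-pattern query in _get_cached_task can find it again later
    task_hash = ClearmlJob._create_task_hash(my_task)
    ClearmlJob._set_task_cache_hash(my_task, task_hash)
    assert ClearmlJob._job_hash_description.format(task_hash) in my_task.comment
    cached = ClearmlJob._get_cached_task(task_hash)  # matches once the task has completed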
class TrainsJob(ClearmlJob):
@ -317,6 +432,7 @@ class TrainsJob(ClearmlJob):
DeprecationWarning,
)
# noinspection PyMethodMayBeStatic, PyUnusedLocal
class _JobStub(object):
"""