Mirror of https://github.com/clearml/clearml (synced 2025-05-31 18:48:16 +00:00)

Add pipeline step caching

This commit is contained in:
parent 696034ac75
commit f9e555b464
clearml/automation/controller.py

@@ -9,7 +9,6 @@ from attr import attrib, attrs
from typing import Sequence, Optional, Mapping, Callable, Any, Union

from ..backend_interface.util import get_or_create_project
from ..config import get_remote_task_id
from ..debugging.log import LoggerRoot
from ..task import Task
from ..automation import ClearmlJob
@@ -41,6 +40,7 @@ class PipelineController(object):
        clone_task = attrib(type=bool, default=True)
        job = attrib(type=ClearmlJob, default=None)
        skip_job = attrib(type=bool, default=False)
        cache_executed_step = attrib(type=bool, default=False)

    def __init__(
            self,
@@ -125,7 +125,7 @@ class PipelineController(object):
            clone_base_task=True,  # type: bool
            pre_execute_callback=None,  # type: Optional[Callable[[PipelineController, PipelineController.Node, dict], bool]]  # noqa
            post_execute_callback=None,  # type: Optional[Callable[[PipelineController, PipelineController.Node], None]]  # noqa
            cache_executed_step=False,  # type: bool
    ):
        # type: (...) -> bool
        """
@@ -199,6 +199,13 @@ class PipelineController(object):
            ):
                pass

        :param cache_executed_step: If True, before launching the new step,
            after updating with the latest configuration, check if an exact Task with the same parameter/code
            was already executed. If it was found, use it instead of launching a new Task.
            Default: False, a new cloned copy of base_task is always used.
            Notice: if the git repo reference does not have a specific commit ID, the Task will never be reused.
            If `clone_base_task` is False there is no cloning, hence the base_task is used as-is.

        :return: True if successful
        """
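For context, a minimal usage sketch of the new `cache_executed_step` flag. The project and task names below are hypothetical, and only arguments this diff touches are shown:

```python
from clearml.automation.controller import PipelineController

# a hedged sketch, assuming draft Tasks named 'step1' and 'step2'
# already exist in a project called 'examples' (names are hypothetical)
pipe = PipelineController(default_execution_queue='default')
pipe.add_step(
    name='stage_data',
    base_task_project='examples',
    base_task_name='step1',
)
pipe.add_step(
    name='stage_process',
    parents=['stage_data'],
    base_task_project='examples',
    base_task_name='step2',
    cache_executed_step=True,  # reuse a prior Task with identical code/parameters, if one exists
)
pipe.start()
```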
@@ -219,7 +226,17 @@ class PipelineController(object):
        if not base_task_id:
            if not base_task_project or not base_task_name:
                raise ValueError('Either base_task_id or base_task_project/base_task_name must be provided')
            base_task = Task.get_task(project_name=base_task_project, task_name=base_task_name)
            base_task = Task.get_task(
                project_name=base_task_project,
                task_name=base_task_name,
                allow_archived=True,
                task_filter=dict(
                    status=[str(Task.TaskStatusEnum.created), str(Task.TaskStatusEnum.queued),
                            str(Task.TaskStatusEnum.in_progress), str(Task.TaskStatusEnum.published),
                            str(Task.TaskStatusEnum.stopped), str(Task.TaskStatusEnum.completed),
                            str(Task.TaskStatusEnum.closed)],
                )
            )
            if not base_task:
                raise ValueError('Could not find base_task_project={} base_task_name={}'.format(
                    base_task_project, base_task_name))
@@ -235,6 +252,7 @@ class PipelineController(object):
            parameters=parameter_override or {},
            clone_task=clone_base_task,
            task_overrides=task_overrides,
            cache_executed_step=cache_executed_step,
        )

        if self._task and not self._task.running_locally():
@@ -606,6 +624,7 @@ class PipelineController(object):
                parent=self._task.id if self._task else None,
                disable_clone_task=not node.clone_task,
                task_overrides=task_overrides,
                allow_caching=node.cache_executed_step,
                **extra_args
            )
@@ -619,6 +638,8 @@ class PipelineController(object):
            # delete the job we just created
            node.job.delete()
            node.skip_job = True
        elif node.job.is_cached_task():
            node.executed = node.job.task_id()
        else:
            node.job.launch(queue_name=node.queue or self._default_execution_queue)
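The branch above means a step can be vetoed by the pre-execute callback (job deleted, `skip_job` set), resolved from cache (`executed` set to the cached task ID), or launched normally. A hedged sketch of a veto callback matching the signature in the `add_step` docstring; the gating rule itself is hypothetical:

```python
def veto_unpinned_steps(pipeline, node, parameters):
    # type: (PipelineController, PipelineController.Node, dict) -> bool
    # sketch only: returning False makes the controller delete the freshly
    # created job and mark the node as skipped (see the branch above)
    script = node.job.task.data.script if node.job and node.job.task else None
    if script and script.repository and not (script.version_num or script.tag):
        return False  # no pinned commit: this step could never be a cache hit anyway
    return True

# hypothetical wiring:
# pipe.add_step(name='stage_process', ..., pre_execute_callback=veto_unpinned_steps)
```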
@@ -674,11 +695,8 @@ class PipelineController(object):
                '{}<br />'.format(node.name) +
                '<br />'.join('{}: {}'.format(k, v if len(str(v)) < 24 else (str(v)[:24] + ' ...'))
                              for k, v in (node.parameters or {}).items()))
            sankey_node['color'].append(
                ("red" if node.job and node.job.is_failed() else
                 ("blue" if not node.job or node.job.is_completed() else "royalblue"))
                if node.executed is not None else
                ("green" if node.job else ("gray" if node.skip_job else "lightsteelblue")))

            sankey_node['color'].append(self._get_node_color(node))

        for p in parents:
            sankey_link['source'].append(p)
@@ -748,6 +766,36 @@ class PipelineController(object):
        self._task.get_logger().report_table(
            title='Pipeline Details', series='Execution Details', iteration=0, table_plot=table_values)

    @staticmethod
    def _get_node_color(node):
        # type: (PipelineController.Node) -> str
        """
        Return the node color based on the node/job state
        :param node: A node in the pipeline
        :return: string representing the color of the node (e.g. "red", "green", etc.)
        """
        if not node:
            return ""

        if node.executed is not None:
            if node.job and node.job.is_failed():
                return "red"  # failed job
            elif node.job and node.job.is_cached_task():
                return "darkslateblue"  # cached job
            elif not node.job or node.job.is_completed():
                return "blue"  # completed job
            else:
                return "royalblue"  # aborted job
        elif node.job:
            if node.job.is_pending():
                return "mediumseagreen"  # pending in queue
            else:
                return "green"  # running job
        elif node.skip_job:
            return "gray"  # skipped job
        else:
            return "lightsteelblue"  # pending job

    def _force_task_configuration_update(self):
        pipeline_dag = self._serialize()
        if self._task:
@@ -1036,9 +1084,12 @@ class PipelineController(object):
            return "pending"
        if a_node.skip_job:
            return "skipped"
        if a_node.job and a_node.job.is_cached_task():
            return "cached"
        if a_node.job and a_node.job.task:
            # no need to refresh status
            return str(a_node.job.task.data.status)
        if a_node.job and a_node.job.executed:
        if a_node.executed:
            return "executed"
        return "pending"
clearml/automation/job.py

@@ -5,6 +5,7 @@ from logging import getLogger
from time import time, sleep
from typing import Optional, Mapping, Sequence, Any

from ..storage.util import hash_dict
from ..task import Task
from ..backend_api.services import tasks as tasks_service
@@ -13,6 +14,8 @@ logger = getLogger('clearml.automation.job')


class ClearmlJob(object):
    _job_hash_description = 'job_hash={}'

    def __init__(
            self,
            base_task_id,  # type: str
@@ -21,6 +24,7 @@ class ClearmlJob(object):
            tags=None,  # type: Optional[Sequence[str]]
            parent=None,  # type: Optional[str]
            disable_clone_task=False,  # type: bool
            allow_caching=False,  # type: bool
            **kwargs  # type: Any
    ):
        # type: (...) -> ()
@@ -34,10 +38,13 @@ class ClearmlJob(object):
        :param str parent: Set the newly created Task's parent task field, default: base_task_id.
        :param dict kwargs: additional Task creation parameters
        :param bool disable_clone_task: if False (default), clone the base task id.
            If True, use the base_task_id directly (the base task must be in draft mode / created),
        :param bool allow_caching: If True, check if we have a previously executed Task with the same specification.
            If we do, use it and set the internal is_cached flag. Default False (always create a new Task).
        """
        base_temp_task = Task.get_task(task_id=base_task_id)
        if disable_clone_task:
            self.task = Task.get_task(task_id=base_task_id)
            self.task = base_temp_task
            task_status = self.task.status
            if task_status != Task.TaskStatusEnum.created:
                logger.warning('Task cloning disabled but requested Task [{}] status={}. '
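A minimal sketch of driving `ClearmlJob` directly with the new `allow_caching` flag; the base task ID and parameter key below are placeholders:

```python
from clearml.automation.job import ClearmlJob

job = ClearmlJob(
    base_task_id='<base-task-id>',                         # placeholder ID
    parameter_override={'General/learning_rate': 0.001},  # hypothetical parameter key
    allow_caching=True,  # reuse a completed Task with an identical hash, if found
)
if job.is_cached_task():
    print('cache hit, reusing task', job.task_id())
else:
    job.launch(queue_name='default')
```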
@@ -47,31 +54,56 @@ class ClearmlJob(object):
            elif parent:
                self.task.set_parent(parent)

        self.task_parameter_override = None
        task_params = None
        if parameter_override:
            task_params = base_temp_task.get_parameters(backwards_compatibility=False)
            task_params.update(parameter_override)
            self.task_parameter_override = dict(**parameter_override)

        sections = {}
        if task_overrides:
            # set values inside the Task
            for k, v in task_overrides.items():
                # notice we can allow ourselves to change the base-task object as we will not use it any further
                # noinspection PyProtectedMember
                base_temp_task._set_task_property(k, v, raise_on_error=False, log_on_error=True)
                section = k.split('.')[0]
                sections[section] = getattr(base_temp_task.data, section, None)

        # check for a cached task
        self._is_cached_task = False
        task_hash = None
        if allow_caching and not disable_clone_task or not self.task:
            # look for a cached copy of the Task
            # get parameters + task_overrides as a dict and hash it.
            task_hash = self._create_task_hash(
                base_temp_task, section_overrides=sections, params_override=task_params)
            task = self._get_cached_task(task_hash)
            # if we found a task, just use it
            if task:
                self._is_cached_task = True
                self.task = task
                self.task_started = True
                self._worker = None
                return

        # check again if we need to clone the Task
        if not disable_clone_task:
            self.task = Task.clone(base_task_id, parent=parent or base_task_id, **kwargs)

        if tags:
            self.task.set_tags(list(set(self.task.get_tags()) | set(tags)))
        self.task_parameter_override = None
        if parameter_override:
            params = self.task.get_parameters(backwards_compatibility=False)
            params.update(parameter_override)
            self.task.set_parameters(params)
            self.task_parameter_override = dict(**parameter_override)

        if task_overrides:
            sections = {}
            # set values inside the Task
            for k, v in task_overrides.items():
                # noinspection PyProtectedMember
                self.task._set_task_property(k, v, raise_on_error=False, log_on_error=True)
                section = k.split('.')[0]
                sections[section] = getattr(self.task.data, section, None)
        if task_params:
            self.task.set_parameters(task_params)

        if task_overrides and sections:
            # store back Task parameters into backend
            # noinspection PyProtectedMember
            self.task._edit(**sections)

        self._set_task_cache_hash(self.task, task_hash)
        self.task_started = False
        self._worker = None
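Note how each `task_overrides` key is split on `'.'` so the prefix selects the Task section later written back via `_edit()`. A hedged sketch of overriding raw script fields; the commit hash and branch are placeholders:

```python
from clearml.automation.job import ClearmlJob

# sketch only: each key's 'section.field' form maps onto task.data attributes
job = ClearmlJob(
    base_task_id='<base-task-id>',             # placeholder ID
    task_overrides={
        'script.version_num': '<commit-sha>',  # pin an explicit commit (placeholder)
        'script.branch': 'main',               # hypothetical branch name
    },
    allow_caching=True,  # with a pinned commit, the task hash is well defined
)
```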
@@ -108,22 +140,29 @@ class ClearmlJob(object):
        return metrics, title, series, values

    def launch(self, queue_name=None):
        # type: (str) -> ()
        # type: (str) -> bool
        """
        Send the Job for execution on the requested execution queue

        :param str queue_name: Name of the queue to enqueue the Task on
        :return: False if the Task is not in "created" status (i.e. cannot be enqueued), or if it is a cached Task
        """
        if self._is_cached_task:
            return False
        try:
            Task.enqueue(task=self.task, queue_name=queue_name)
            return True
        except Exception as ex:
            logger.warning(ex)
        return False

    def abort(self):
        # type: () -> ()
        """
        Abort the currently running job (can be called multiple times)
        """
        if not self.task or self._is_cached_task:
            return
        try:
            self.task.stopped()
        except Exception as ex:
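Since `launch()` now returns a bool, a `False` result is expected (not an error) for cached jobs. Continuing the earlier sketch:

```python
# hedged usage sketch: a cached job is never enqueued, so launch() returns False
if not job.launch(queue_name='default'):
    if job.is_cached_task():
        print('reused cached task', job.task_id())
    else:
        print('could not enqueue task', job.task_id())
```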
@@ -234,7 +273,7 @@ class ClearmlJob(object):
    def is_stopped(self):
        # type: () -> bool
        """
        Return True, if job is has executed and is not any more
        Return True, if the job finished executing (for any reason)

        :return: True if the task is currently in one of these states: stopped / completed / failed / published.
        """
@@ -298,7 +337,7 @@ class ClearmlJob(object):
        Delete the current temporary job (before launching)
        Return False if the Job/Task could not be deleted
        """
        if not self.task:
        if not self.task or self._is_cached_task:
            return False

        if self.task.delete():
@@ -307,6 +346,82 @@ class ClearmlJob(object):

        return False

    def is_cached_task(self):
        # type: () -> bool
        """
        :return: True if the internal Task is a cached one, False otherwise.
        """
        return self._is_cached_task

    @classmethod
    def _create_task_hash(cls, task, section_overrides=None, params_override=None):
        # type: (Task, Optional[dict], Optional[dict]) -> Optional[str]
        """
        Create a hash (str) representing the state of the Task

        :param task: A Task to hash
        :param section_overrides: optional dict (keys are the Task's section names) with task overrides.
        :param params_override: Alternative to the entire Task's hyper-parameters section
            (notice this should not be a nested dict but a flat key/value)
        :return: str crc32 of the Task configuration
        """
        if not task:
            return None
        if section_overrides and section_overrides.get('script'):
            script = section_overrides['script']
            if not isinstance(script, dict):
                script = script.to_dict()
        else:
            script = task.data.script.to_dict() if task.data.script else {}

        # if we have a repository, we must make sure we have a specific version_num to ensure consistency
        if script.get('repository') and not script.get('version_num') and not script.get('tag'):
            return None

        # we need to ignore the `requirements` section because it might change from run to run
        script.pop("requirements", None)

        hyper_params = task.get_parameters() if params_override is None else params_override
        configs = task.get_configuration_objects()
        return hash_dict(dict(script=script, hyper_params=hyper_params, configs=configs), hash_func='crc32')

    @classmethod
    def _get_cached_task(cls, task_hash):
        # type: (str) -> Optional[Task]
        """
        Look for a previously executed Task whose stored hash matches

        :param task_hash: The Task hash (string) to search for
        :return: A Task matching the requested task hash, or None
        """
        if not task_hash:
            return None
        # noinspection PyProtectedMember
        potential_tasks = Task._query_tasks(
            status=['completed', 'stopped', 'published'],
            system_tags=['-{}'.format(Task.archived_tag)],
            _all_=dict(fields=['comment'], pattern=cls._job_hash_description.format(task_hash)),
            only_fields=['id'],
        )
        for obj in potential_tasks:
            task = Task.get_task(task_id=obj.id)
            if task_hash == cls._create_task_hash(task):
                return task
        return None

    @classmethod
    def _set_task_cache_hash(cls, task, task_hash=None):
        # type: (Task, Optional[str]) -> ()
        """
        Store the task state hash for later querying

        :param task: The Task object that was created
        :param task_hash: The Task hash (string) to store; if None, generate a new task_hash from the Task
        """
        if not task:
            return
        if not task_hash:
            task_hash = cls._create_task_hash(task=task)
        hash_comment = cls._job_hash_description.format(task_hash) + '\n'
        task.set_comment(task.comment + '\n' + hash_comment if task.comment else hash_comment)
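The cache round trip works through the Task comment: `_set_task_cache_hash()` appends a `job_hash=<crc32>` line, and `_get_cached_task()` pattern-matches comments before re-verifying the hash on each candidate. For intuition, a hedged sketch of what a crc32-of-dict helper might look like; the real `hash_dict` lives in `clearml.storage.util` and may differ:

```python
import json
import zlib

def hash_dict_sketch(a_dict):
    # sketch only: serialize deterministically (sorted keys), then crc32 the bytes;
    # clearml's hash_dict(..., hash_func='crc32') is assumed to behave similarly
    blob = json.dumps(a_dict, sort_keys=True, default=str).encode('utf-8')
    return '{:x}'.format(zlib.crc32(blob) & 0xffffffff)

# illustrative comment line stored on the Task (hash value made up):
# job_hash=5f3a9c21
```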

class TrainsJob(ClearmlJob):

@@ -317,6 +432,7 @@ class TrainsJob(ClearmlJob):
        DeprecationWarning,
    )


# noinspection PyMethodMayBeStatic, PyUnusedLocal
class _JobStub(object):
    """