mirror of https://github.com/clearml/clearml
synced 2025-06-26 18:16:07 +00:00

Add pipeline step caching

parent 696034ac75 -> commit f9e555b464
clearml/automation/controller.py
@@ -9,7 +9,6 @@ from attr import attrib, attrs
 from typing import Sequence, Optional, Mapping, Callable, Any, Union
 
 from ..backend_interface.util import get_or_create_project
-from ..config import get_remote_task_id
 from ..debugging.log import LoggerRoot
 from ..task import Task
 from ..automation import ClearmlJob
@@ -41,6 +40,7 @@ class PipelineController(object):
         clone_task = attrib(type=bool, default=True)
         job = attrib(type=ClearmlJob, default=None)
         skip_job = attrib(type=bool, default=False)
+        cache_executed_step = attrib(type=bool, default=False)
 
     def __init__(
             self,
@@ -125,7 +125,7 @@ class PipelineController(object):
             clone_base_task=True,  # type: bool
             pre_execute_callback=None,  # type: Optional[Callable[[PipelineController, PipelineController.Node, dict], bool]]  # noqa
             post_execute_callback=None,  # type: Optional[Callable[[PipelineController, PipelineController.Node], None]]  # noqa
+            cache_executed_step=False,  # type: bool
     ):
         # type: (...) -> bool
         """
@@ -199,6 +199,13 @@ class PipelineController(object):
             ):
                 pass
 
+        :param cache_executed_step: If True, before launching the new step,
+            after updating with the latest configuration, check if an exact Task with the same parameter/code
+            was already executed. If it was found, use it instead of launching a new Task.
+            Default: False, a new cloned copy of base_task is always used.
+            Notice: If the git repo reference does not have a specific commit ID, the Task will never be used.
+            If `clone_base_task` is False there is no cloning, hence the base_task is used.
+
         :return: True if successful
         """
 
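Usage sketch for the new flag. This is illustrative only: the project/task names
and the queue are hypothetical placeholders, while add_step() and
cache_executed_step come from the diff above.

    from clearml.automation.controller import PipelineController

    pipe = PipelineController(default_execution_queue='default')
    pipe.add_step(
        name='stage_process',              # placeholder step name
        base_task_project='examples',      # placeholder project
        base_task_name='pipeline step 1',  # placeholder base Task
        cache_executed_step=True,          # reuse an identical, previously executed Task if found
    )

Per the docstring above, the cache only kicks in when the step's git reference is
pinned to a specific commit ID and clone_base_task is left True.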
@@ -219,7 +226,17 @@ class PipelineController(object):
         if not base_task_id:
             if not base_task_project or not base_task_name:
                 raise ValueError('Either base_task_id or base_task_project/base_task_name must be provided')
-            base_task = Task.get_task(project_name=base_task_project, task_name=base_task_name)
+            base_task = Task.get_task(
+                project_name=base_task_project,
+                task_name=base_task_name,
+                allow_archived=True,
+                task_filter=dict(
+                    status=[str(Task.TaskStatusEnum.created), str(Task.TaskStatusEnum.queued),
+                            str(Task.TaskStatusEnum.in_progress), str(Task.TaskStatusEnum.published),
+                            str(Task.TaskStatusEnum.stopped), str(Task.TaskStatusEnum.completed),
+                            str(Task.TaskStatusEnum.closed)],
+                )
+            )
             if not base_task:
                 raise ValueError('Could not find base_task_project={} base_task_name={}'.format(
                     base_task_project, base_task_name))
@@ -235,6 +252,7 @@ class PipelineController(object):
             parameters=parameter_override or {},
             clone_task=clone_base_task,
             task_overrides=task_overrides,
+            cache_executed_step=cache_executed_step,
         )
 
         if self._task and not self._task.running_locally():
@@ -606,6 +624,7 @@ class PipelineController(object):
             parent=self._task.id if self._task else None,
             disable_clone_task=not node.clone_task,
             task_overrides=task_overrides,
+            allow_caching=node.cache_executed_step,
             **extra_args
         )
 
@@ -619,6 +638,8 @@ class PipelineController(object):
             # delete the job we just created
             node.job.delete()
             node.skip_job = True
+        elif node.job.is_cached_task():
+            node.executed = node.job.task_id()
         else:
             node.job.launch(queue_name=node.queue or self._default_execution_queue)
 
@@ -674,11 +695,8 @@ class PipelineController(object):
                 '{}<br />'.format(node.name) +
                 '<br />'.join('{}: {}'.format(k, v if len(str(v)) < 24 else (str(v)[:24]+' ...'))
                               for k, v in (node.parameters or {}).items()))
-            sankey_node['color'].append(
-                ("red" if node.job and node.job.is_failed() else
-                 ("blue" if not node.job or node.job.is_completed() else "royalblue"))
-                if node.executed is not None else
-                ("green" if node.job else ("gray" if node.skip_job else "lightsteelblue")))
+            sankey_node['color'].append(self._get_node_color(node))
 
             for p in parents:
                 sankey_link['source'].append(p)
@@ -748,6 +766,36 @@ class PipelineController(object):
         self._task.get_logger().report_table(
             title='Pipeline Details', series='Execution Details', iteration=0, table_plot=table_values)
 
+    @staticmethod
+    def _get_node_color(node):
+        # type: (PipelineController.Node) -> str
+        """
+        Return the node color based on the node/job state
+        :param node: A node in the pipeline
+        :return: string representing the color of the node (e.g. "red", "green", etc.)
+        """
+        if not node:
+            return ""
+
+        if node.executed is not None:
+            if node.job and node.job.is_failed():
+                return "red"  # failed job
+            elif node.job and node.job.is_cached_task():
+                return "darkslateblue"  # cached job
+            elif not node.job or node.job.is_completed():
+                return "blue"  # completed job
+            else:
+                return "royalblue"  # aborted job
+        elif node.job:
+            if node.job.is_pending():
+                return "mediumseagreen"  # pending in queue
+            else:
+                return "green"  # running job
+        elif node.skip_job:
+            return "gray"  # skipped job
+        else:
+            return "lightsteelblue"  # pending job
+
     def _force_task_configuration_update(self):
         pipeline_dag = self._serialize()
         if self._task:
@@ -1036,9 +1084,12 @@ class PipelineController(object):
             return "pending"
         if a_node.skip_job:
             return "skipped"
+        if a_node.job and a_node.job.is_cached_task():
+            return "cached"
         if a_node.job and a_node.job.task:
+            # no need to refresh status
             return str(a_node.job.task.data.status)
-        if a_node.job and a_node.job.executed:
+        if a_node.executed:
             return "executed"
         return "pending"
 
clearml/automation/job.py
@@ -5,6 +5,7 @@ from logging import getLogger
 from time import time, sleep
 from typing import Optional, Mapping, Sequence, Any
 
+from ..storage.util import hash_dict
 from ..task import Task
 from ..backend_api.services import tasks as tasks_service
 
@@ -13,6 +14,8 @@ logger = getLogger('clearml.automation.job')
 
 
 class ClearmlJob(object):
+    _job_hash_description = 'job_hash={}'
+
     def __init__(
             self,
             base_task_id,  # type: str
@@ -21,6 +24,7 @@ class ClearmlJob(object):
             tags=None,  # type: Optional[Sequence[str]]
             parent=None,  # type: Optional[str]
             disable_clone_task=False,  # type: bool
+            allow_caching=False,  # type: bool
             **kwargs  # type: Any
     ):
         # type: (...) -> ()
@@ -34,10 +38,13 @@ class ClearmlJob(object):
         :param str parent: Set newly created Task parent task field, default: base_task_id.
         :param dict kwargs: additional Task creation parameters
         :param bool disable_clone_task: if False (default) clone base task id.
             If True, use the base_task_id directly (base-task must be in draft-mode / created),
+        :param bool allow_caching: If True, check whether a previously executed Task with the same specification
+            exists. If it does, use it and set the internal is_cached flag. Default False (always create a new Task).
         """
+        base_temp_task = Task.get_task(task_id=base_task_id)
         if disable_clone_task:
-            self.task = Task.get_task(task_id=base_task_id)
+            self.task = base_temp_task
             task_status = self.task.status
             if task_status != Task.TaskStatusEnum.created:
                 logger.warning('Task cloning disabled but requested Task [{}] status={}. '
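A hedged sketch of the new caching path from the caller's side; the Task ID,
parameter name and queue are placeholders, while ClearmlJob, allow_caching,
is_cached_task() and launch() all appear in this diff:

    from clearml.automation.job import ClearmlJob

    job = ClearmlJob(
        base_task_id='aabbccdd11223344',               # placeholder Task ID
        parameter_override={'General/epochs': '10'},   # placeholder hyper-parameter
        allow_caching=True,                            # look up an identical, previously executed Task
    )
    if job.is_cached_task():
        print('cache hit, reusing task', job.task.id)  # nothing will be enqueued
    else:
        job.launch(queue_name='default')               # cache miss: enqueue the newly cloned Task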
@@ -47,31 +54,56 @@ class ClearmlJob(object):
             elif parent:
                 self.task.set_parent(parent)
 
+        self.task_parameter_override = None
+        task_params = None
+        if parameter_override:
+            task_params = base_temp_task.get_parameters(backwards_compatibility=False)
+            task_params.update(parameter_override)
+            self.task_parameter_override = dict(**parameter_override)
+
+        sections = {}
+        if task_overrides:
+            # set values inside the Task
+            for k, v in task_overrides.items():
+                # notice we can allow ourselves to change the base-task object as we will not use it any further
+                # noinspection PyProtectedMember
+                base_temp_task._set_task_property(k, v, raise_on_error=False, log_on_error=True)
+                section = k.split('.')[0]
+                sections[section] = getattr(base_temp_task.data, section, None)
+
+        # check cached task
+        self._is_cached_task = False
+        task_hash = None
+        if (allow_caching and not disable_clone_task) or not self.task:
+            # look for a cached copy of the Task:
+            # get the parameters + task_overrides as a dict and hash it.
+            task_hash = self._create_task_hash(
+                base_temp_task, section_overrides=sections, params_override=task_params)
+            task = self._get_cached_task(task_hash)
+            # if we found a task, just use it
+            if task:
+                self._is_cached_task = True
+                self.task = task
+                self.task_started = True
+                self._worker = None
+                return
+
         # check again if we need to clone the Task
         if not disable_clone_task:
             self.task = Task.clone(base_task_id, parent=parent or base_task_id, **kwargs)
 
         if tags:
             self.task.set_tags(list(set(self.task.get_tags()) | set(tags)))
-        self.task_parameter_override = None
-        if parameter_override:
-            params = self.task.get_parameters(backwards_compatibility=False)
-            params.update(parameter_override)
-            self.task.set_parameters(params)
-            self.task_parameter_override = dict(**parameter_override)
 
-        if task_overrides:
-            sections = {}
-            # set values inside the Task
-            for k, v in task_overrides.items():
-                # noinspection PyProtectedMember
-                self.task._set_task_property(k, v, raise_on_error=False, log_on_error=True)
-                section = k.split('.')[0]
-                sections[section] = getattr(self.task.data, section, None)
+        if task_params:
+            self.task.set_parameters(task_params)
+
+        if task_overrides and sections:
             # store back Task parameters into backend
             # noinspection PyProtectedMember
             self.task._edit(**sections)
+        self._set_task_cache_hash(self.task, task_hash)
         self.task_started = False
         self._worker = None
 
@@ -108,22 +140,29 @@ class ClearmlJob(object):
         return metrics, title, series, values
 
     def launch(self, queue_name=None):
-        # type: (str) -> ()
+        # type: (str) -> bool
         """
         Send Job for execution on the requested execution queue
 
         :param str queue_name:
+        :return: False if the Task is not in "created" status (i.e. it cannot be enqueued)
         """
+        if self._is_cached_task:
+            return False
         try:
             Task.enqueue(task=self.task, queue_name=queue_name)
+            return True
         except Exception as ex:
             logger.warning(ex)
+        return False
 
     def abort(self):
         # type: () -> ()
         """
         Abort currently running job (can be called multiple times)
         """
+        if not self.task or self._is_cached_task:
+            return
         try:
             self.task.stopped()
         except Exception as ex:
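Because launch() now returns a bool, callers can branch on the result instead of
guessing; a small sketch (the queue name is a placeholder):

    if not job.launch(queue_name='services'):
        # either the job reused a cached Task, or the Task could not be enqueued
        print('nothing was enqueued for task', job.task_id())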
@@ -234,7 +273,7 @@ class ClearmlJob(object):
     def is_stopped(self):
         # type: () -> bool
         """
-        Return True, if job is has executed and is not any more
+        Return True, if the job finished executing (for any reason)
 
         :return: True if the task is currently in one of these states: stopped / completed / failed / published.
         """
@@ -298,7 +337,7 @@ class ClearmlJob(object):
         Delete the current temporary job (before launching)
         Return False if the Job/Task could not be deleted
         """
-        if not self.task:
+        if not self.task or self._is_cached_task:
             return False
 
         if self.task.delete():
@@ -307,6 +346,82 @@ class ClearmlJob(object):
 
         return False
 
+    def is_cached_task(self):
+        # type: () -> bool
+        """
+        :return: True if the internal Task is a cached one, False otherwise.
+        """
+        return self._is_cached_task
+
+    @classmethod
+    def _create_task_hash(cls, task, section_overrides=None, params_override=None):
+        # type: (Task, Optional[dict], Optional[dict]) -> Optional[str]
+        """
+        Create a hash (str) representing the state of the Task
+
+        :param task: A Task to hash
+        :param section_overrides: optional dict (keys are Task's section names) with task overrides.
+        :param params_override: Alternative to the entire Task's hyper parameters section
+            (notice this should not be a nested dict but a flat key/value)
+        :return: str crc32 of the Task configuration
+        """
+        if not task:
+            return None
+        if section_overrides and section_overrides.get('script'):
+            script = section_overrides['script']
+            if not isinstance(script, dict):
+                script = script.to_dict()
+        else:
+            script = task.data.script.to_dict() if task.data.script else {}
+
+        # if we have a repository, we must make sure we have a specific version_num to ensure consistency
+        if script.get('repository') and not script.get('version_num') and not script.get('tag'):
+            return None
+
+        # we need to ignore the `requirements` section because it might change from run to run
+        script.pop("requirements", None)
+
+        hyper_params = task.get_parameters() if params_override is None else params_override
+        configs = task.get_configuration_objects()
+        return hash_dict(dict(script=script, hyper_params=hyper_params, configs=configs), hash_func='crc32')
+
+    @classmethod
+    def _get_cached_task(cls, task_hash):
+        # type: (str) -> Optional[Task]
+        """
+        :param task_hash: The task hash (string) to look up
+        :return: A Task matching the requested task hash, or None if none was found
+        """
+        if not task_hash:
+            return None
+        # noinspection PyProtectedMember
+        potential_tasks = Task._query_tasks(
+            status=['completed', 'stopped', 'published'],
+            system_tags=['-{}'.format(Task.archived_tag)],
+            _all_=dict(fields=['comment'], pattern=cls._job_hash_description.format(task_hash)),
+            only_fields=['id'],
+        )
+        for obj in potential_tasks:
+            task = Task.get_task(task_id=obj.id)
+            if task_hash == cls._create_task_hash(task):
+                return task
+        return None
+
+    @classmethod
+    def _set_task_cache_hash(cls, task, task_hash=None):
+        # type: (Task, Optional[str]) -> ()
+        """
+        Store the task state hash for later querying
+
+        :param task: The Task object that was created
+        :param task_hash: The task hash (string) to store; if None, generate a new task_hash from the Task
+        """
+        if not task:
+            return
+        if not task_hash:
+            task_hash = cls._create_task_hash(task=task)
+        hash_comment = cls._job_hash_description.format(task_hash) + '\n'
+        task.set_comment(task.comment + '\n' + hash_comment if task.comment else hash_comment)
+
 
 class TrainsJob(ClearmlJob):
 
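Putting the three new helpers together, the cache key round-trips as follows:
_create_task_hash() reduces a Task to a crc32 over its script section (minus
requirements), hyper parameters and configuration objects;
_set_task_cache_hash() appends a 'job_hash=...' marker to the Task comment; and
_get_cached_task() queries completed/stopped/published, non-archived Tasks whose
comment contains that marker, then re-verifies the hash. A rough sketch of the
internal flow (some_task and the hash value are illustrative):

    task_hash = ClearmlJob._create_task_hash(some_task)        # e.g. 'a1b2c3d4'; None if no pinned commit
    cached = ClearmlJob._get_cached_task(task_hash)            # a finished Task on a hit, None on a miss
    if cached is None:
        ClearmlJob._set_task_cache_hash(some_task, task_hash)  # mark the Task for future cache lookups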
@@ -317,6 +432,7 @@ class TrainsJob(ClearmlJob):
         DeprecationWarning,
     )
 
+
 # noinspection PyMethodMayBeStatic, PyUnusedLocal
 class _JobStub(object):
     """