Add PipelineController disable clone base task option

This commit is contained in:
allegroai 2021-02-14 13:50:58 +02:00
parent 0c017a7331
commit e5747d587d
2 changed files with 33 additions and 4 deletions

View File

@ -34,6 +34,7 @@ class PipelineController(object):
timeout = attrib(type=float, default=None)
parameters = attrib(type=dict, default={})
executed = attrib(type=str, default=None)
clone_task = attrib(type=bool, default=True)
job = attrib(type=TrainsJob, default=None)
def __init__(
@ -102,6 +103,7 @@ class PipelineController(object):
time_limit=None, # type: Optional[float]
base_task_project=None, # type: Optional[str]
base_task_name=None, # type: Optional[str]
clone_base_task=True, # type: bool
):
# type: (...) -> bool
"""
@ -133,6 +135,8 @@ class PipelineController(object):
use the base_task_project and base_task_name combination to retrieve the base_task_id to use for the step.
:param str base_task_name: If base_task_id is not given,
use the base_task_project and base_task_name combination to retrieve the base_task_id to use for the step.
:param bool clone_base_task: If True (default), the pipeline will clone the base task, and modify/enqueue
the cloned Task. If False, the base task is used directly; note that it must be in draft mode (created).
:return: True if successful
"""
# when running remotely do nothing, we will deserialize ourselves when we start
@ -158,7 +162,9 @@ class PipelineController(object):
self._nodes[name] = self.Node(
name=name, base_task_id=base_task_id, parents=parents or [],
queue=execution_queue, timeout=time_limit,
parameters=parameter_override or {})
parameters=parameter_override or {},
clone_task=clone_base_task,
)
return True
@ -172,7 +178,11 @@ class PipelineController(object):
on a remote machine. This is done by calling the Task.execute_remotely with the queue name 'services'.
If `run_remotely` is a string, it will specify the execution queue for the pipeline remote execution.
:param Callable step_task_created_callback: Callback function, called when a step (Task) is created
and before it is sent for execution.
and before it is sent for execution. Allows a user to modify the Task before launch.
Use `node.job` to access the TrainsJob object, or `node.job.task` to directly access the Task object.
`parameters` are the configuration arguments passed to the TrainsJob.
Notice the `parameters` are already parsed,
e.g. `${step1.parameters.Args/param}` is replaced with relevant value.
.. code-block:: py
@ -438,7 +448,9 @@ class PipelineController(object):
node.job = TrainsJob(
base_task_id=node.base_task_id, parameter_override=updated_hyper_parameters,
tags=['pipe: {}'.format(self._task.id)] if self._add_pipeline_tags and self._task else None,
parent=self._task.id if self._task else None)
parent=self._task.id if self._task else None,
disable_clone_task=not node.clone_task,
)
if self._experiment_created_cb:
self._experiment_created_cb(node, updated_hyper_parameters)
node.job.launch(queue_name=node.queue or self._default_execution_queue)

View File

@ -19,6 +19,7 @@ class TrainsJob(object):
task_overrides=None, # type: Optional[Mapping[str, str]]
tags=None, # type: Optional[Sequence[str]]
parent=None, # type: Optional[str]
disable_clone_task=False, # type: bool
**kwargs # type: Any
):
# type: (...) -> ()
@ -31,8 +32,24 @@ class TrainsJob(object):
:param list tags: additional tags to add to the newly cloned task
:param str parent: Set newly created Task parent task field, default: base_task_id.
:param dict kwargs: additional Task creation parameters
:param bool disable_clone_task: If False (default), clone the base task.
If True, use the base_task_id directly (the base task must be in draft mode / created; otherwise a warning is logged and the task is cloned).
"""
if disable_clone_task:
self.task = Task.get_task(task_id=base_task_id)
task_status = self.task.status
if task_status != Task.TaskStatusEnum.created:
logger.warning('Task cloning disabled but requested Task [{}] status={}. '
'Reverting to clone Task'.format(base_task_id, task_status))
disable_clone_task = False
self.task = None
elif parent:
self.task.set_parent(parent)
# check again if we need to clone the Task
if not disable_clone_task:
self.task = Task.clone(base_task_id, parent=parent or base_task_id, **kwargs)
if tags:
self.task.set_tags(list(set(self.task.get_tags()) | set(tags)))
self.task_parameter_override = None