Add pipeline controller tag-per-step option. Pipeline controller parameters can now be any type.

This commit is contained in:
allegroai 2020-09-05 16:31:01 +03:00
parent 86aa3aaa98
commit 03e7ebd48c
2 changed files with 38 additions and 15 deletions

View File

@ -2,14 +2,15 @@ from trains import Task
from trains.automation.controller import PipelineController from trains.automation.controller import PipelineController
task = Task.init(project_name='examples', task_name='pipeline demo', task_type=Task.TaskTypes.controller) task = Task.init(project_name='examples', task_name='pipeline demo',
task_type=Task.TaskTypes.controller, reuse_last_task_id=False)
pipe = PipelineController(default_execution_queue='default') pipe = PipelineController(default_execution_queue='default', add_pipeline_tags=False)
pipe.add_step(name='stage_data', base_task_project='examples', base_task_name='pipeline step 1 dataset artifact') pipe.add_step(name='stage_data', base_task_project='examples', base_task_name='pipeline step 1 dataset artifact')
pipe.add_step(name='stage_process', parents=['stage_data', ], pipe.add_step(name='stage_process', parents=['stage_data', ],
base_task_project='examples', base_task_name='pipeline step 2 process dataset', base_task_project='examples', base_task_name='pipeline step 2 process dataset',
parameter_override={'General/dataset_url': '${stage_data.artifacts.dataset.url}', parameter_override={'General/dataset_url': '${stage_data.artifacts.dataset.url}',
'General/test_size': '0.25'}) 'General/test_size': 0.25})
pipe.add_step(name='stage_train', parents=['stage_process', ], pipe.add_step(name='stage_train', parents=['stage_process', ],
base_task_project='examples', base_task_name='pipeline step 3 train model', base_task_project='examples', base_task_name='pipeline step 3 train model',
parameter_override={'General/dataset_task_id': '${stage_process.id}'}) parameter_override={'General/dataset_task_id': '${stage_process.id}'})

View File

@ -9,7 +9,7 @@ from plotly import graph_objects as go
from plotly.subplots import make_subplots from plotly.subplots import make_subplots
from attr import attrib, attrs from attr import attrib, attrs
from typing import Sequence, Optional, Mapping, Callable from typing import Sequence, Optional, Mapping, Callable, Any
from trains import Task from trains import Task
from trains.automation import TrainsJob from trains.automation import TrainsJob
@ -45,6 +45,7 @@ class PipelineController(object):
pipeline_time_limit=None, # type: Optional[float] pipeline_time_limit=None, # type: Optional[float]
auto_connect_task=True, # type: bool auto_connect_task=True, # type: bool
always_create_task=False, # type: bool always_create_task=False, # type: bool
add_pipeline_tags=False, # type: bool
): ):
# type: (...) -> () # type: (...) -> ()
""" """
@ -54,6 +55,19 @@ class PipelineController(object):
:param str default_execution_queue: The execution queue to use if no execution queue is provided :param str default_execution_queue: The execution queue to use if no execution queue is provided
:param float pipeline_time_limit: The maximum time (minutes) for the entire pipeline process. The :param float pipeline_time_limit: The maximum time (minutes) for the entire pipeline process. The
default is ``None``, indicating no time limit. default is ``None``, indicating no time limit.
:param bool auto_connect_task: Store pipeline arguments and configuration in the Task
- ``True`` - The pipeline argument and configuration will be stored in the Task. All arguments will
be under the hyper-parameter section as ``opt/<arg>``, and the hyper_parameters will stored in the
Task ``connect_configuration`` (see artifacts/hyper-parameter).
- ``False`` - Do not store with Task.
:param bool always_create_task: Always create a new Task
- ``True`` - No current Task initialized. Create a new task named ``Pipeline`` in the ``base_task_id``
project.
- ``False`` - Use the :py:meth:`task.Task.current_task` (if exists) to report statistics.
:param bool add_pipeline_tags: (default: False) if True, add `pipe: <pipeline_task_id>` tag to all
steps (Tasks) created by this pipeline.
""" """
self._nodes = {} self._nodes = {}
self._running_nodes = [] self._running_nodes = []
@ -64,6 +78,7 @@ class PipelineController(object):
self._thread = None self._thread = None
self._stop_event = None self._stop_event = None
self._experiment_created_cb = None self._experiment_created_cb = None
self._add_pipeline_tags = add_pipeline_tags
self._task = Task.current_task() self._task = Task.current_task()
self._step_ref_pattern = re.compile(self._step_pattern) self._step_ref_pattern = re.compile(self._step_pattern)
if not self._task and always_create_task: if not self._task and always_create_task:
@ -83,7 +98,7 @@ class PipelineController(object):
name, # type: str name, # type: str
base_task_id=None, # type: Optional[str] base_task_id=None, # type: Optional[str]
parents=None, # type: Optional[Sequence[str]] parents=None, # type: Optional[Sequence[str]]
parameter_override=None, # type: Optional[Mapping[str, str]] parameter_override=None, # type: Optional[Mapping[str, Any]]
execution_queue=None, # type: Optional[str] execution_queue=None, # type: Optional[str]
time_limit=None, # type: Optional[float] time_limit=None, # type: Optional[float]
base_task_project=None, # type: Optional[str] base_task_project=None, # type: Optional[str]
@ -139,7 +154,8 @@ class PipelineController(object):
self._nodes[name] = self.Node( self._nodes[name] = self.Node(
name=name, base_task_id=base_task_id, parents=parents or [], name=name, base_task_id=base_task_id, parents=parents or [],
queue=execution_queue, timeout=time_limit, parameters=parameter_override or {}) queue=execution_queue, timeout=time_limit,
parameters=parameter_override or {})
return True return True
def start(self, run_remotely=False, step_task_created_callback=None): def start(self, run_remotely=False, step_task_created_callback=None):
@ -171,13 +187,16 @@ class PipelineController(object):
pipeline_dag = self._serialize() pipeline_dag = self._serialize()
self._task.connect_configuration(pipeline_dag, name=self._config_section) self._task.connect_configuration(pipeline_dag, name=self._config_section)
params = {'continue_pipeline': False, params = {'continue_pipeline': False,
'default_queue': self._default_execution_queue} 'default_queue': self._default_execution_queue,
'add_pipeline_tags': self._add_pipeline_tags,
}
self._task.connect(params, name=self._config_section) self._task.connect(params, name=self._config_section)
# deserialize back pipeline state # deserialize back pipeline state
if not params['continue_pipeline']: if not params['continue_pipeline']:
for k in pipeline_dag: for k in pipeline_dag:
pipeline_dag[k]['executed'] = None pipeline_dag[k]['executed'] = None
self._default_execution_queue = params['default_queue'] self._default_execution_queue = params['default_queue']
self._add_pipeline_tags = params['add_pipeline_tags']
self._deserialize(pipeline_dag) self._deserialize(pipeline_dag)
if not self._verify(): if not self._verify():
@ -355,8 +374,9 @@ class PipelineController(object):
pattern = self._step_ref_pattern pattern = self._step_ref_pattern
for v in node.parameters.values(): for v in node.parameters.values():
for g in pattern.findall(v): if isinstance(v, str):
self.__verify_step_reference(node, g) for g in pattern.findall(v):
self.__verify_step_reference(node, g)
return True return True
@ -395,7 +415,8 @@ class PipelineController(object):
node.job = TrainsJob( node.job = TrainsJob(
base_task_id=node.base_task_id, parameter_override=updated_hyper_parameters, base_task_id=node.base_task_id, parameter_override=updated_hyper_parameters,
parent=self._task.id) tags=['pipe: {}'.format(self._task.id)] if self._add_pipeline_tags and self._task else None,
parent=self._task.id if self._task else None)
if self._experiment_created_cb: if self._experiment_created_cb:
self._experiment_created_cb(node, updated_hyper_parameters) self._experiment_created_cb(node, updated_hyper_parameters)
node.job.launch(queue_name=node.queue or self._default_execution_queue) node.job.launch(queue_name=node.queue or self._default_execution_queue)
@ -706,7 +727,7 @@ class PipelineController(object):
return None return None
def _parse_step_ref(self, value): def _parse_step_ref(self, value):
# type: (str) -> Optional[str] # type: (Any) -> Optional[str]
""" """
Return the step reference. For example "${step1.parameters.Args/param}" Return the step reference. For example "${step1.parameters.Args/param}"
:param value: string :param value: string
@ -715,8 +736,9 @@ class PipelineController(object):
# look for all the step references # look for all the step references
pattern = self._step_ref_pattern pattern = self._step_ref_pattern
updated_value = value updated_value = value
for g in pattern.findall(value): if isinstance(value, str):
# update with actual value for g in pattern.findall(value):
new_val = self.__parse_step_reference(g) # update with actual value
updated_value = updated_value.replace(g, new_val, 1) new_val = self.__parse_step_reference(g)
updated_value = updated_value.replace(g, new_val, 1)
return updated_value return updated_value