diff --git a/examples/pipeline/pipeline_controller.py b/examples/pipeline/pipeline_controller.py
index 4725a49a..ac00d024 100644
--- a/examples/pipeline/pipeline_controller.py
+++ b/examples/pipeline/pipeline_controller.py
@@ -2,14 +2,15 @@
 from trains import Task
 from trains.automation.controller import PipelineController
 
-task = Task.init(project_name='examples', task_name='pipeline demo', task_type=Task.TaskTypes.controller)
+task = Task.init(project_name='examples', task_name='pipeline demo',
+                 task_type=Task.TaskTypes.controller, reuse_last_task_id=False)
 
-pipe = PipelineController(default_execution_queue='default')
+pipe = PipelineController(default_execution_queue='default', add_pipeline_tags=False)
 pipe.add_step(name='stage_data', base_task_project='examples', base_task_name='pipeline step 1 dataset artifact')
 pipe.add_step(name='stage_process', parents=['stage_data', ],
               base_task_project='examples', base_task_name='pipeline step 2 process dataset',
               parameter_override={'General/dataset_url': '${stage_data.artifacts.dataset.url}',
-                                  'General/test_size': '0.25'})
+                                  'General/test_size': 0.25})
 pipe.add_step(name='stage_train', parents=['stage_process', ],
               base_task_project='examples', base_task_name='pipeline step 3 train model',
               parameter_override={'General/dataset_task_id': '${stage_process.id}'})
diff --git a/trains/automation/controller.py b/trains/automation/controller.py
index 3258c2b1..12a350c4 100644
--- a/trains/automation/controller.py
+++ b/trains/automation/controller.py
@@ -9,7 +9,7 @@
 from plotly import graph_objects as go
 from plotly.subplots import make_subplots
 from attr import attrib, attrs
-from typing import Sequence, Optional, Mapping, Callable
+from typing import Sequence, Optional, Mapping, Callable, Any
 
 from trains import Task
 from trains.automation import TrainsJob
@@ -45,6 +45,7 @@ class PipelineController(object):
             pipeline_time_limit=None,  # type: Optional[float]
             auto_connect_task=True,  # type: bool
             always_create_task=False,  # type: bool
+            add_pipeline_tags=False,  # type: bool
     ):
         # type: (...) -> ()
         """
@@ -54,6 +55,19 @@
         :param str default_execution_queue: The execution queue to use if no execution queue is provided
         :param float pipeline_time_limit: The maximum time (minutes) for the entire pipeline process. The default is
             ``None``, indicating no time limit.
+        :param bool auto_connect_task: Store pipeline arguments and configuration in the Task
+            - ``True`` - The pipeline arguments and configuration will be stored in the Task. All arguments will
+              be under the hyper-parameter section as ``opt/``, and the hyper_parameters will be stored in the
+              Task ``connect_configuration`` (see artifacts/hyper-parameter).
+
+            - ``False`` - Do not store with the Task.
+        :param bool always_create_task: Always create a new Task
+            - ``True`` - No current Task initialized. Create a new task named ``Pipeline`` in the ``base_task_id``
+              project.
+
+            - ``False`` - Use the :py:meth:`task.Task.current_task` (if exists) to report statistics.
+        :param bool add_pipeline_tags: (default: False) if True, add `pipe: <pipeline_task_id>` tag to all
+            steps (Tasks) created by this pipeline.
""" self._nodes = {} self._running_nodes = [] @@ -64,6 +78,7 @@ class PipelineController(object): self._thread = None self._stop_event = None self._experiment_created_cb = None + self._add_pipeline_tags = add_pipeline_tags self._task = Task.current_task() self._step_ref_pattern = re.compile(self._step_pattern) if not self._task and always_create_task: @@ -83,7 +98,7 @@ class PipelineController(object): name, # type: str base_task_id=None, # type: Optional[str] parents=None, # type: Optional[Sequence[str]] - parameter_override=None, # type: Optional[Mapping[str, str]] + parameter_override=None, # type: Optional[Mapping[str, Any]] execution_queue=None, # type: Optional[str] time_limit=None, # type: Optional[float] base_task_project=None, # type: Optional[str] @@ -139,7 +154,8 @@ class PipelineController(object): self._nodes[name] = self.Node( name=name, base_task_id=base_task_id, parents=parents or [], - queue=execution_queue, timeout=time_limit, parameters=parameter_override or {}) + queue=execution_queue, timeout=time_limit, + parameters=parameter_override or {}) return True def start(self, run_remotely=False, step_task_created_callback=None): @@ -171,13 +187,16 @@ class PipelineController(object): pipeline_dag = self._serialize() self._task.connect_configuration(pipeline_dag, name=self._config_section) params = {'continue_pipeline': False, - 'default_queue': self._default_execution_queue} + 'default_queue': self._default_execution_queue, + 'add_pipeline_tags': self._add_pipeline_tags, + } self._task.connect(params, name=self._config_section) # deserialize back pipeline state if not params['continue_pipeline']: for k in pipeline_dag: pipeline_dag[k]['executed'] = None self._default_execution_queue = params['default_queue'] + self._add_pipeline_tags = params['add_pipeline_tags'] self._deserialize(pipeline_dag) if not self._verify(): @@ -355,8 +374,9 @@ class PipelineController(object): pattern = self._step_ref_pattern for v in node.parameters.values(): - for g in pattern.findall(v): - self.__verify_step_reference(node, g) + if isinstance(v, str): + for g in pattern.findall(v): + self.__verify_step_reference(node, g) return True @@ -395,7 +415,8 @@ class PipelineController(object): node.job = TrainsJob( base_task_id=node.base_task_id, parameter_override=updated_hyper_parameters, - parent=self._task.id) + tags=['pipe: {}'.format(self._task.id)] if self._add_pipeline_tags and self._task else None, + parent=self._task.id if self._task else None) if self._experiment_created_cb: self._experiment_created_cb(node, updated_hyper_parameters) node.job.launch(queue_name=node.queue or self._default_execution_queue) @@ -706,7 +727,7 @@ class PipelineController(object): return None def _parse_step_ref(self, value): - # type: (str) -> Optional[str] + # type: (Any) -> Optional[str] """ Return the step reference. For example "${step1.parameters.Args/param}" :param value: string @@ -715,8 +736,9 @@ class PipelineController(object): # look for all the step references pattern = self._step_ref_pattern updated_value = value - for g in pattern.findall(value): - # update with actual value - new_val = self.__parse_step_reference(g) - updated_value = updated_value.replace(g, new_val, 1) + if isinstance(value, str): + for g in pattern.findall(value): + # update with actual value + new_val = self.__parse_step_reference(g) + updated_value = updated_value.replace(g, new_val, 1) return updated_value