diff --git a/examples/pipeline/pipeline_controller.py b/examples/pipeline/pipeline_controller.py
new file mode 100644
index 00000000..4725a49a
--- /dev/null
+++ b/examples/pipeline/pipeline_controller.py
@@ -0,0 +1,24 @@
+from trains import Task
+from trains.automation.controller import PipelineController
+
+
+task = Task.init(project_name='examples', task_name='pipeline demo', task_type=Task.TaskTypes.controller)
+
+pipe = PipelineController(default_execution_queue='default')
+pipe.add_step(name='stage_data', base_task_project='examples', base_task_name='pipeline step 1 dataset artifact')
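+# '${stage_data.artifacts.dataset.url}' below is resolved at launch time to the URL of the 'dataset'
+# artifact uploaded by the completed 'stage_data' step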
+pipe.add_step(name='stage_process', parents=['stage_data', ],
+ base_task_project='examples', base_task_name='pipeline step 2 process dataset',
+ parameter_override={'General/dataset_url': '${stage_data.artifacts.dataset.url}',
+ 'General/test_size': '0.25'})
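+# '${stage_process.id}' below resolves to the Task ID of the executed 'stage_process' step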
+pipe.add_step(name='stage_train', parents=['stage_process', ],
+ base_task_project='examples', base_task_name='pipeline step 3 train model',
+ parameter_override={'General/dataset_task_id': '${stage_process.id}'})
+
+# Starting the pipeline (in the background)
+pipe.start()
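+# (alternatively, pipe.start(run_remotely=True) stops the local run and enqueues this controller task on the 'services' queue)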
+# Wait until pipeline terminates
+pipe.wait()
+# cleanup everything
+pipe.stop()
+
+print('done')
diff --git a/examples/pipeline/step1_dataset_artifact.py b/examples/pipeline/step1_dataset_artifact.py
new file mode 100644
index 00000000..f1e7a6a3
--- /dev/null
+++ b/examples/pipeline/step1_dataset_artifact.py
@@ -0,0 +1,19 @@
+from trains import Task, StorageManager
+
+# create a dataset experiment
+task = Task.init(project_name="examples", task_name="pipeline step 1 dataset artifact")
+
+# only create the task, we will actually execute it later
+task.execute_remotely()
+
+# simulate a local dataset: download one, so we have something local to work with
+local_iris_pkl = StorageManager.get_local_copy(
+ remote_url='https://github.com/allegroai/events/raw/master/odsc20-east/generic/iris_dataset.pkl')
+
+# add and upload local file containing our toy dataset
+task.upload_artifact('dataset', artifact_object=local_iris_pkl)
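+# the artifact is stored under the name 'dataset', so the pipeline controller can reference its URL
+# as '${stage_data.artifacts.dataset.url}'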
+
+print('uploading artifacts in the background')
+
+# we are done
+print('Done')
diff --git a/examples/pipeline/step2_data_processing.py b/examples/pipeline/step2_data_processing.py
new file mode 100644
index 00000000..ff74c89e
--- /dev/null
+++ b/examples/pipeline/step2_data_processing.py
@@ -0,0 +1,55 @@
+import pickle
+from trains import Task, StorageManager
+from sklearn.model_selection import train_test_split
+
+
+# Connecting TRAINS
+task = Task.init(project_name="examples", task_name="pipeline step 2 process dataset")
+
+# program arguments
+# Use either dataset_task_id to point to a task's artifact or
+# use a direct url with dataset_url
+args = {
+ 'dataset_task_id': '',
+ 'dataset_url': '',
+ 'random_state': 42,
+ 'test_size': 0.2,
+}
+
+# store arguments, later we will be able to change them from outside the code
+task.connect(args)
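+# (the connected parameters appear under the task's 'General' section, which is what the pipeline
+#  controller overrides as 'General/dataset_url' and 'General/test_size')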
+print('Arguments: {}'.format(args))
+
+# only create the task, we will actually execute it later
+task.execute_remotely()
+
+# get dataset from task's artifact
+if args['dataset_task_id']:
+ dataset_upload_task = Task.get_task(task_id=args['dataset_task_id'])
+ print('Input task id={} artifacts {}'.format(args['dataset_task_id'], list(dataset_upload_task.artifacts.keys())))
+ # download the artifact
+ iris_pickle = dataset_upload_task.artifacts['dataset'].get_local_copy()
+# get the dataset from a direct url
+elif args['dataset_url']:
+ iris_pickle = StorageManager.get_local_copy(remote_url=args['dataset_url'])
+else:
+ raise ValueError("Missing dataset link")
+
+# open the local copy
+with open(iris_pickle, 'rb') as f:
+    iris = pickle.load(f)
+
+# "process" data
+X = iris.data
+y = iris.target
+X_train, X_test, y_train, y_test = train_test_split(
+ X, y, test_size=args['test_size'], random_state=args['random_state'])
+
+# upload processed data
+print('Uploading processed dataset')
+task.upload_artifact('X_train', X_train)
+task.upload_artifact('X_test', X_test)
+task.upload_artifact('y_train', y_train)
+task.upload_artifact('y_test', y_test)
+
+print('Notice, artifacts are uploaded in the background')
+print('Done')
diff --git a/examples/pipeline/step3_train_model.py b/examples/pipeline/step3_train_model.py
new file mode 100644
index 00000000..332532b4
--- /dev/null
+++ b/examples/pipeline/step3_train_model.py
@@ -0,0 +1,56 @@
+import joblib
+import matplotlib.pyplot as plt
+import numpy as np
+from sklearn.linear_model import LogisticRegression
+
+from trains import Task
+
+# Connecting TRAINS
+task = Task.init(project_name="examples", task_name="pipeline step 3 train model")
+
+# Arguments
+args = {
+ 'dataset_task_id': 'REPLACE_WITH_DATASET_TASK_ID',
+}
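+# when launched by the pipeline controller, the placeholder above is overridden with
+# 'General/dataset_task_id': '${stage_process.id}' (the Task ID of the completed processing step)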
+task.connect(args)
+
+# only create the task, we will actually execute it later
+task.execute_remotely()
+
+print('Retrieving Iris dataset')
+dataset_task = Task.get_task(task_id=args['dataset_task_id'])
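+# artifact .get() returns the stored object (here, the numpy arrays uploaded by the processing step)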
+X_train = dataset_task.artifacts['X_train'].get()
+X_test = dataset_task.artifacts['X_test'].get()
+y_train = dataset_task.artifacts['y_train'].get()
+y_test = dataset_task.artifacts['y_test'].get()
+print('Iris dataset loaded')
+
+model = LogisticRegression(solver='liblinear', multi_class='auto')
+model.fit(X_train, y_train)
+
+joblib.dump(model, 'model.pkl', compress=True)
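+# trains hooks joblib, so 'model.pkl' should also be registered automatically as this task's output model
+# (allowing a downstream step to reference it, e.g. '${stage_train.models.output.-1.url}')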
+
+loaded_model = joblib.load('model.pkl')
+result = loaded_model.score(X_test, y_test)
+
+print('model trained & stored')
+
+x_min, x_max = X_test[:, 0].min() - .5, X_test[:, 0].max() + .5
+y_min, y_max = X_test[:, 1].min() - .5, X_test[:, 1].max() + .5
+h = .02 # step size in the mesh
+xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))
+plt.figure(1, figsize=(4, 3))
+
+plt.scatter(X_test[:, 0], X_test[:, 1], c=y_test, edgecolors='k', cmap=plt.cm.Paired)
+plt.xlabel('Sepal length')
+plt.ylabel('Sepal width')
+
+plt.xlim(xx.min(), xx.max())
+plt.ylim(yy.min(), yy.max())
+plt.xticks(())
+plt.yticks(())
+
+plt.title('Iris Types')
+plt.show()
+
+print('Done')
diff --git a/trains/automation/controller.py b/trains/automation/controller.py
new file mode 100644
index 00000000..f9925b16
--- /dev/null
+++ b/trains/automation/controller.py
@@ -0,0 +1,722 @@
+import re
+from copy import copy
+from datetime import datetime
+from logging import getLogger
+from threading import Thread, Event
+from time import time
+
+from plotly import graph_objects as go
+from plotly.subplots import make_subplots
+
+from attr import attrib, attrs, Factory
+from typing import Sequence, Optional, Mapping, Callable
+
+from trains import Task
+from trains.automation import TrainsJob
+from trains.model import BaseModel
+
+
+class PipelineController(object):
+ """
+ Pipeline controller.
+    A pipeline is a DAG of base tasks; each task will be cloned (arguments changed as required), executed and monitored.
+ The pipeline process (task) itself can be executed manually or by the trains-agent services queue.
+ Notice: The pipeline controller lives as long as the pipeline itself is being executed.
+ """
+ _tag = 'pipeline'
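+    # matches step references such as '${stage1.artifacts.dataset.url}' or '${stage1.id}'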
+ _step_pattern = r"\${[^}]*}"
+ _config_section = 'Pipeline'
+
+ @attrs
+ class Node(object):
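+        # name - unique step name; base_task_id - the Task to clone for this step; queue - execution queue;
+        # parents - names of steps that must complete before this step is launched; timeout - optional per-step
+        # time limit (seconds); parameters - parameter overrides for the cloned Task; executed - Task ID of the
+        # completed clone; job - the TrainsJob wrapping the cloned Task while it runs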
+ name = attrib(type=str)
+ base_task_id = attrib(type=str)
+ queue = attrib(type=str, default=None)
+        parents = attrib(type=list, default=Factory(list))
+ timeout = attrib(type=float, default=None)
+        parameters = attrib(type=dict, default=Factory(dict))
+ executed = attrib(type=str, default=None)
+ job = attrib(type=TrainsJob, default=None)
+
+ def __init__(
+ self,
+ pool_frequency=0.2, # type: float
+ default_execution_queue=None, # type: Optional[str]
+ pipeline_time_limit=None, # type: Optional[float]
+ auto_connect_task=True, # type: bool
+ always_create_task=False, # type: bool
+ ):
+ # type: (...) -> ()
+ """
+ Create a new pipeline controller. The newly created object will launch and monitor the new experiments.
+
+        :param float pool_frequency: The polling frequency (in minutes) for monitoring experiments / states.
+ :param str default_execution_queue: The execution queue to use if no execution queue is provided
+ :param float pipeline_time_limit: The maximum time (minutes) for the entire pipeline process. The
+ default is ``None``, indicating no time limit.
+ """
+ self._nodes = {}
+ self._running_nodes = []
+ self._start_time = None
+ self._pipeline_time_limit = pipeline_time_limit * 60. if pipeline_time_limit else None
+ self._default_execution_queue = default_execution_queue
+ self._pool_frequency = pool_frequency * 60.
+ self._thread = None
+ self._stop_event = None
+ self._experiment_created_cb = None
+ self._task = Task.current_task()
+ self._step_ref_pattern = re.compile(self._step_pattern)
+ if not self._task and always_create_task:
+ self._task = Task.init(
+ project_name='Pipelines',
+ task_name='Pipeline {}'.format(datetime.now()),
+ task_type=Task.TaskTypes.controller,
+ )
+
+ # make sure all the created tasks are our children, as we are creating them
+ if self._task:
+ self._task.add_tags([self._tag])
+ self._auto_connect_task = auto_connect_task
+
+ def add_step(
+ self,
+ name, # type: str
+ base_task_id=None, # type: Optional[str]
+ parents=None, # type: Optional[Sequence[str]]
+ parameter_override=None, # type: Optional[Mapping[str, str]]
+ execution_queue=None, # type: Optional[str]
+ time_limit=None, # type: Optional[float]
+ base_task_project=None, # type: Optional[str]
+ base_task_name=None, # type: Optional[str]
+ ):
+ # type: (...) -> bool
+ """
+ Add a step to the pipeline execution DAG.
+ Each step must have a unique name (this name will later be used to address the step)
+
+        :param str name: Unique name of the step. For example `stage1`
+ :param str base_task_id: The Task ID to use for the step. Each time the step is executed,
+ the base Task is cloned, then the cloned task will be sent for execution.
+ :param list parents: Optional list of parent nodes in the DAG.
+ The current step in the pipeline will be sent for execution only after all the parent nodes
+ have been executed successfully.
+ :param dict parameter_override: Optional parameter overriding dictionary.
+ The dict values can reference a previously executed step using the following form '${step_name}'
+ Examples:
+ Artifact access
+ parameter_override={'Args/input_file': '${stage1.artifacts.mydata.url}' }
+ Model access (last model used)
+ parameter_override={'Args/input_file': '${stage1.models.output.-1.url}' }
+ Parameter access
+ parameter_override={'Args/input_file': '${stage3.parameters.Args/input_file}' }
+ Task ID
+ parameter_override={'Args/input_file': '${stage3.id}' }
+ :param str execution_queue: Optional, the queue to use for executing this specific step.
+ If not provided, the task will be sent to the default execution queue, as defined on the class
+ :param float time_limit: Default None, no time limit.
+ Step execution time limit, if exceeded the Task is aborted and the pipeline is stopped and marked failed.
+ :param str base_task_project: If base_task_id is not given,
+ use the base_task_project and base_task_name combination to retrieve the base_task_id to use for the step.
+ :param str base_task_name: If base_task_id is not given,
+ use the base_task_project and base_task_name combination to retrieve the base_task_id to use for the step.
+ :return: True if successful
+ """
+ # when running remotely do nothing, we will deserialize ourselves when we start
+ if self._task and not self._task.running_locally() and self._task.is_main_task():
+ return True
+
+ if name in self._nodes:
+ raise ValueError('Node named \'{}\' already exists in the pipeline dag'.format(name))
+
+ if not base_task_id:
+ if not base_task_project or not base_task_name:
+ raise ValueError('Either base_task_id or base_task_project/base_task_name must be provided')
+ base_task = Task.get_task(project_name=base_task_project, task_name=base_task_name)
+ if not base_task:
+ raise ValueError('Could not find base_task_project={} base_task_name={}'.format(
+ base_task_project, base_task_name))
+ base_task_id = base_task.id
+
+ self._nodes[name] = self.Node(
+ name=name, base_task_id=base_task_id, parents=parents or [],
+ queue=execution_queue, timeout=time_limit, parameters=parameter_override or {})
+ return True
+
+ def start(self, run_remotely=False, step_task_created_callback=None):
+ # type: (bool, Optional[Callable[[PipelineController.Node, dict], None]]) -> bool
+ """
+ Start the pipeline controller.
+ If the calling process is stopped, then the controller stops as well.
+
+ :param bool run_remotely: (default False), If True stop the current process and continue execution
+ on a remote machine. This is done by calling the Task.execute_remotely with the queue name 'services'
+ :param Callable step_task_created_callback: Callback function, called when a step (Task) is created
+ and before it is sent for execution.
+
+ .. code-block:: py
+
+ def step_created_callback(
+ node, # type: PipelineController.Node,
+ parameters, # type: dict
+ ):
+ pass
+
+ :return: True, if the controller started. False, if the controller did not start.
+
+ """
+ if self._thread:
+ return True
+
+ # serialize pipeline state
+ pipeline_dag = self._serialize()
+ self._task.connect_configuration(pipeline_dag, name=self._config_section)
+ params = {'continue_pipeline': False,
+ 'default_queue': self._default_execution_queue}
+ self._task.connect(params, name=self._config_section)
+ # deserialize back pipeline state
+ if not params['continue_pipeline']:
+ for k in pipeline_dag:
+ pipeline_dag[k]['executed'] = None
+ self._default_execution_queue = params['default_queue']
+ self._deserialize(pipeline_dag)
+
+ if not self._verify():
+ raise ValueError("Failed verifying pipeline execution graph, "
+ "it has either inaccessible nodes, or contains cycles")
+
+ self._update_execution_plot()
+
+ if run_remotely:
+ self._task.execute_remotely(queue_name='services')
+ # we will not get here if we are not running remotely
+
+ self._start_time = time()
+ self._stop_event = Event()
+ self._experiment_created_cb = step_task_created_callback
+ self._thread = Thread(target=self._daemon)
+ self._thread.daemon = True
+ self._thread.start()
+ return True
+
+ def stop(self, timeout=None):
+ # type: (Optional[float]) -> ()
+ """
+        Stop the pipeline controller and its execution thread.
+
+        :param float timeout: Wait timeout for the pipeline execution thread to exit (minutes).
+            The default is ``None``, indicating do not wait and terminate immediately.
+ """
+        if self._stop_event:
+            # signal the daemon thread to abort any still-running steps and exit
+            self._stop_event.set()
+        if timeout is not None:
+            self.wait(timeout=timeout)
+
+ def wait(self, timeout=None):
+ # type: (Optional[float]) -> bool
+ """
+ Wait for the pipeline to finish.
+
+ .. note::
+ This method does not stop the pipeline. Call :meth:`stop` to terminate the pipeline.
+
+        :param float timeout: The timeout to wait for the pipeline to complete (minutes).
+            If ``None``, wait until the pipeline is completed (or stopped).
+
+ :return: True, if the pipeline finished. False, if the pipeline timed out.
+
+ """
+ if not self.is_running():
+ return True
+
+ if timeout is not None:
+ timeout *= 60.
+
+ _thread = self._thread
+
+ _thread.join(timeout=timeout)
+ if _thread.is_alive():
+ return False
+
+ return True
+
+ def is_running(self):
+ # type: () -> bool
+ """
+ return True if the pipeline controller is running.
+
+ :return: A boolean indicating whether the pipeline controller is active (still running) or stopped.
+ """
+ return self._thread is not None
+
+ def elapsed(self):
+ # type: () -> float
+ """
+        Return the minutes elapsed since the controller start time stamp.
+
+ :return: The minutes from controller start time. A negative value means the process has not started yet.
+ """
+ if self._start_time is None:
+ return -1.0
+ return (time() - self._start_time) / 60.
+
+ def get_pipeline_dag(self):
+ # type: () -> Mapping[str, PipelineController.Node]
+ """
+        Return the pipeline execution graph; each node in the DAG is a PipelineController.Node object.
+ Graph itself is a dictionary of Nodes (key based on the Node name),
+ each node holds links to its parent Nodes (identified by their unique names)
+
+ :return: execution tree, as a nested dictionary
+ Example:
+ {
+ 'stage1' : Node() {
+ name: 'stage1'
+ job: TrainsJob
+ ...
+ },
+ }
+ """
+ return self._nodes
+
+ def get_processed_nodes(self):
+        # type: () -> Mapping[str, PipelineController.Node]
+        """
+        Return a dictionary of the processed (completed) pipeline nodes, each value is a PipelineController.Node object.
+
+        :return: executed (excluding currently executing) nodes, keyed by node name
+ """
+ return {k: n for k, n in self._nodes.items() if n.executed}
+
+ def get_running_nodes(self):
+        # type: () -> Mapping[str, PipelineController.Node]
+        """
+        Return a dictionary of the currently running pipeline nodes,
+        each value is a PipelineController.Node object.
+
+        :return: Currently running nodes, keyed by node name
+ """
+ return {k: n for k, n in self._nodes.items() if k in self._running_nodes}
+
+ def _serialize(self):
+ # type: () -> dict
+ """
+ Store the definition of the pipeline DAG into a dictionary.
+ This dictionary will be used to store the DAG as a configuration on the Task
+ :return:
+ """
+ dag = {name: dict((k, v) for k, v in node.__dict__.items() if k not in ('job', 'name'))
+ for name, node in self._nodes.items()}
+
+ return dag
+
+ def _deserialize(self, dag_dict):
+ # type: (dict) -> ()
+ """
+ Restore the DAG from a dictionary.
+ This will be used to create the DAG from the dict stored on the Task, when running remotely.
+ :return:
+ """
+ self._nodes = {k: self.Node(name=k, **v) for k, v in dag_dict.items()}
+
+ def _verify(self):
+ # type: () -> bool
+ """
+        Verify the DAG (i.e. no cycles and no missing parents).
+ On error raise ValueError with verification details
+
+ :return: return True iff DAG has no errors
+ """
+ # verify nodes
+ for node in self._nodes.values():
+ # raise value error if not verified
+ self._verify_node(node)
+
+ # check the dag itself
+ if not self._verify_dag():
+ return False
+
+ return True
+
+ def _verify_node(self, node):
+ # type: (Node) -> bool
+ """
+ Raise ValueError on verification errors
+
+ :return: Return True iff the specific node is verified
+ """
+ if not node.base_task_id:
+ raise ValueError("Node '{}', base_task_id is empty".format(node.name))
+
+ if not self._default_execution_queue and not node.queue:
+ raise ValueError("Node '{}' missing execution queue, "
+ "no default queue defined and no specific node queue defined".format(node.name))
+
+ task = Task.get_task(task_id=node.base_task_id)
+ if not task:
+ raise ValueError("Node '{}', base_task_id={} is invalid".format(node.name, node.base_task_id))
+
+ pattern = self._step_ref_pattern
+
+ for v in node.parameters.values():
+ for g in pattern.findall(v):
+ self.__verify_step_reference(node, g)
+
+ return True
+
+ def _verify_dag(self):
+ # type: () -> bool
+ """
+ :return: True iff the pipeline dag is fully accessible and contains no cycles
+ """
+ visited = set()
+ prev_visited = None
+ while prev_visited != visited:
+ prev_visited = copy(visited)
+ for k, node in self._nodes.items():
+ if k in visited:
+ continue
+ if not all(p in visited for p in node.parents or []):
+ continue
+ visited.add(k)
+ # return False if we did not cover all the nodes
+ return not bool(set(self._nodes.keys()) - visited)
+
+ def _launch_node(self, node):
+        # type: (Node) -> bool
+ """
+ Launch a single node (create and enqueue a TrainsJob)
+
+ :param node: Node to launch
+ :return: Return True if a new job was launched
+ """
+ if node.job or node.executed:
+ return False
+
+ updated_hyper_parameters = {}
+ for k, v in node.parameters.items():
+ updated_hyper_parameters[k] = self._parse_step_ref(v)
+
+ node.job = TrainsJob(
+ base_task_id=node.base_task_id, parameter_override=updated_hyper_parameters,
+ parent=self._task.id)
+ if self._experiment_created_cb:
+ self._experiment_created_cb(node, updated_hyper_parameters)
+ node.job.launch(queue_name=node.queue or self._default_execution_queue)
+ return True
+
+ def _update_execution_plot(self):
+ # type: () -> ()
+ """
+ Update sankey diagram of the current pipeline
+ """
+ sankey_node = dict(
+ label=[],
+ color=[],
+ hovertemplate='%{label}',
+ # customdata=[],
+            # hovertemplate='%{label}<br />Hyper-Parameters:<br />%{customdata}',
+ )
+ sankey_link = dict(
+ source=[],
+ target=[],
+ value=[],
+ hovertemplate='%{target.label}',
+ )
+ visited = []
+ node_params = []
+ nodes = list(self._nodes.values())
+ while nodes:
+ next_nodes = []
+ for node in nodes:
+ if not all(p in visited for p in node.parents or []):
+ next_nodes.append(node)
+ continue
+ visited.append(node.name)
+ idx = len(visited) - 1
+ parents = [visited.index(p) for p in node.parents or []]
+                node_params.append((node.job.task_parameter_override if node.job else node.parameters) or {})
+                # sankey_node['label'].append(node.name)
+                # sankey_node['customdata'].append(
+                #     '<br />'.join('{}: {}'.format(k, v) for k, v in (node.parameters or {}).items()))
+                sankey_node['label'].append(
+                    '{}<br />'.format(node.name) +
+                    '<br />'.join('{}: {}'.format(k, v) for k, v in (node.parameters or {}).items()))
+                sankey_node['color'].append(
+                    ("blue" if not node.job or not node.job.is_failed() else "red")
+                    if node.executed else ("green" if node.job else "lightsteelblue"))
+
+ for p in parents:
+ sankey_link['source'].append(p)
+ sankey_link['target'].append(idx)
+ sankey_link['value'].append(1)
+
+ nodes = next_nodes
+
+ # make sure we have no independent (unconnected) nodes
+ for i in [n for n in range(len(visited)) if n not in sankey_link['source'] and n not in sankey_link['target']]:
+ sankey_link['source'].append(i)
+ sankey_link['target'].append(i)
+ sankey_link['value'].append(0.1)
+
+ fig = make_subplots(
+ rows=2, cols=1,
+ shared_xaxes=True,
+ vertical_spacing=0.03,
+ specs=[[{"type": "table"}],
+ [{"type": "sankey"}], ]
+ )
+ # noinspection PyUnresolvedReferences
+ fig.add_trace(
+ go.Sankey(
+ node=sankey_node, link=sankey_link, textfont=dict(color='rgba(0,0,0,0)', size=1)
+ ),
+ row=1, col=1
+ )
+ # noinspection PyUnresolvedReferences
+ fig.add_trace(
+ go.Table(
+ header=dict(
+ values=["Pipeline Step", "Task ID", "Parameters"],
+ align="left",
+ ),
+ cells=dict(
+ values=[visited,
+ [self._nodes[v].executed or (self._nodes[v].job.task_id() if self._nodes[v].job else '')
+ for v in visited],
+ [str(p) for p in node_params]],
+ align="left")
+ ),
+ row=2, col=1
+ )
+
+ # fig = go.Figure(data=[go.Sankey(
+ # node=sankey_node, link=sankey_link, textfont=dict(color='rgba(0,0,0,0)', size=1))],)
+ self._task.get_logger().report_plotly(
+ title='Pipeline', series='execution flow', iteration=0, figure=fig)
+
+ def _force_task_configuration_update(self):
+ pipeline_dag = self._serialize()
+ # noinspection PyProtectedMember
+ self._task._set_configuration(
+ name=self._config_section, config_type='dictionary', config_dict=pipeline_dag)
+
+ def _daemon(self):
+ # type: () -> ()
+ """
+ The main pipeline execution loop. This loop is executed on its own dedicated thread.
+ :return:
+ """
+ pooling_counter = 0
+
+ while self._stop_event:
+ # stop request
+ if pooling_counter and self._stop_event.wait(self._pool_frequency):
+ break
+
+ pooling_counter += 1
+
+ # check the pipeline time limit
+ if self._pipeline_time_limit and (time() - self._start_time) > self._pipeline_time_limit:
+ break
+
+ # check the state of all current jobs
+            # if no job ended, continue
+ completed_jobs = []
+ for j in self._running_nodes:
+ node = self._nodes[j]
+ if not node.job:
+ continue
+ if node.job.is_stopped():
+ completed_jobs.append(j)
+ node.executed = node.job.task_id()
+ elif node.timeout:
+ started = node.job.task.data.started
+ if (datetime.now().astimezone(started.tzinfo) - started).total_seconds() > node.timeout:
+ node.job.abort()
+ completed_jobs.append(j)
+ node.executed = node.job.task_id()
+
+ # update running jobs
+ self._running_nodes = [j for j in self._running_nodes if j not in completed_jobs]
+
+ # nothing changed, we can sleep
+ if not completed_jobs and self._running_nodes:
+ continue
+
+ # Pull the next jobs in the pipeline, based on the completed list
+ next_nodes = []
+ for node in self._nodes.values():
+ # check if already processed.
+ if node.job or node.executed:
+ continue
+ completed_parents = [bool(p in self._nodes and self._nodes[p].executed) for p in node.parents or []]
+ if all(completed_parents):
+ next_nodes.append(node.name)
+
+ # update the execution graph
+ for name in next_nodes:
+ if self._launch_node(self._nodes[name]):
+ print('Launching step: {}'.format(name))
+ print('Parameters:\n{}'.format(self._nodes[name].job.task_parameter_override))
+ self._running_nodes.append(name)
+ else:
+ getLogger('trains.automation.controller').error(
+ 'ERROR: Failed launching step \'{}\': {}'.format(name, self._nodes[name]))
+
+ # update current state (in configuration, so that we could later continue an aborted pipeline)
+ self._force_task_configuration_update()
+
+ # visualize pipeline state (plot)
+ self._update_execution_plot()
+
+            # quit if all pipeline nodes are fully executed.
+ if not next_nodes and not self._running_nodes:
+ break
+
+ # stop all currently running jobs:
+ for node in self._nodes.values():
+ if node.job and not node.executed:
+ node.job.abort()
+
+ if self._stop_event:
+ # noinspection PyBroadException
+ try:
+ self._stop_event.set()
+ except Exception:
+ pass
+
+ def __verify_step_reference(self, node, step_ref_string):
+ # type: (Node, str) -> bool
+ """
+ Verify the step reference. For example "${step1.parameters.Args/param}"
+ :param Node node: calling reference node (used for logging)
+ :param str step_ref_string: For example "${step1.parameters.Args/param}"
+ :return: True if valid reference
+ """
+ parts = step_ref_string[2:-1].split('.')
+ v = step_ref_string
+ if len(parts) < 2:
+ raise ValueError("Node '{}', parameter '{}' is invalid".format(node.name, v))
+ prev_step = parts[0]
+ input_type = parts[1]
+ if prev_step not in self._nodes:
+ raise ValueError("Node '{}', parameter '{}', step name '{}' is invalid".format(node.name, v, prev_step))
+ if input_type not in ('artifacts', 'parameters', 'models', 'id'):
+ raise ValueError(
+ "Node {}, parameter '{}', input type '{}' is invalid".format(node.name, v, input_type))
+
+ if input_type != 'id' and len(parts) < 3:
+ raise ValueError("Node '{}', parameter '{}' is invalid".format(node.name, v))
+
+ if input_type == 'models':
+ try:
+ model_type = parts[2].lower()
+ except Exception:
+ raise ValueError(
+ "Node '{}', parameter '{}', input type '{}', model_type is missing {}".format(
+ node.name, v, input_type, parts))
+ if model_type not in ('input', 'output'):
+ raise ValueError(
+ "Node '{}', parameter '{}', input type '{}', "
+ "model_type is invalid (input/output) found {}".format(
+ node.name, v, input_type, model_type))
+
+ if len(parts) < 4:
+ raise ValueError(
+ "Node '{}', parameter '{}', input type '{}', model index is missing".format(
+ node.name, v, input_type))
+
+ # check casting
+ try:
+ int(parts[3])
+ except Exception:
+ raise ValueError(
+ "Node '{}', parameter '{}', input type '{}', model index is missing {}".format(
+ node.name, v, input_type, parts))
+
+ if len(parts) < 5:
+ raise ValueError(
+ "Node '{}', parameter '{}', input type '{}', model property is missing".format(
+ node.name, v, input_type))
+
+ if not hasattr(BaseModel, parts[4]):
+ raise ValueError(
+ "Node '{}', parameter '{}', input type '{}', model property is invalid {}".format(
+ node.name, v, input_type, parts[4]))
+ return True
+
+ def __parse_step_reference(self, step_ref_string):
+ """
+        Return the adjusted value for a "${step...}" reference
+        :param step_ref_string: reference string of the form "${step_name.type.value}"
+        :return: str with the resolved value
+ """
+ parts = step_ref_string[2:-1].split('.')
+ if len(parts) < 2:
+ raise ValueError("Could not parse reference '{}'".format(step_ref_string))
+ prev_step = parts[0]
+ input_type = parts[1].lower()
+ if prev_step not in self._nodes or not self._nodes[prev_step].job:
+ raise ValueError("Could not parse reference '{}', step {} could not be found".format(
+ step_ref_string, prev_step))
+ if input_type not in ('artifacts', 'parameters', 'models', 'id'):
+ raise ValueError("Could not parse reference '{}', type {} not valid".format(step_ref_string, input_type))
+ if input_type != 'id' and len(parts) < 3:
+ raise ValueError("Could not parse reference '{}', missing fields in {}".format(step_ref_string, parts))
+
+ task = self._nodes[prev_step].job.task if self._nodes[prev_step].job \
+ else Task.get_task(task_id=self._nodes[prev_step].executed)
+ task.reload()
+ if input_type == 'artifacts':
+            # artifact names may contain dots: '\.' in the reference is treated as a literal '.' rather than a separator
+ artifact_path = ('.'.join(parts[2:])).replace('\\.', '\\_dot_\\')
+ artifact_path = artifact_path.split('.')
+
+ obj = task.artifacts
+ for p in artifact_path:
+ p = p.replace('\\_dot_\\', '.')
+ if isinstance(obj, dict):
+ obj = obj.get(p)
+ elif hasattr(obj, p):
+ obj = getattr(obj, p)
+ else:
+ raise ValueError("Could not locate artifact {} on previous step {}".format(
+ '.'.join(parts[1:]), prev_step))
+ return str(obj)
+ elif input_type == 'parameters':
+ step_params = task.get_parameters()
+ param_name = '.'.join(parts[2:])
+ if param_name not in step_params:
+ raise ValueError("Could not locate parameter {} on previous step {}".format(
+ '.'.join(parts[1:]), prev_step))
+ return step_params.get(param_name)
+ elif input_type == 'models':
+ model_type = parts[2].lower()
+ if model_type not in ('input', 'output'):
+ raise ValueError("Could not locate model {} on previous step {}".format(
+ '.'.join(parts[1:]), prev_step))
+ try:
+ model_idx = int(parts[3])
+ model = task.models[model_type][model_idx]
+ except Exception:
+ raise ValueError("Could not locate model {} on previous step {}, index {} is invalid".format(
+ '.'.join(parts[1:]), prev_step, parts[3]))
+
+ return str(getattr(model, parts[4]))
+
+ elif input_type == 'id':
+ return task.id
+ return None
+
+ def _parse_step_ref(self, value):
+ # type: (str) -> Optional[str]
+ """
+ Return the step reference. For example "${step1.parameters.Args/param}"
+ :param value: string
+ :return:
+ """
+ # look for all the step references
+ pattern = self._step_ref_pattern
+ updated_value = value
+ for g in pattern.findall(value):
+ # update with actual value
+ new_val = self.__parse_step_reference(g)
+ updated_value = updated_value.replace(g, new_val, 1)
+ return updated_value
diff --git a/trains/binding/matplotlib_bind.py b/trains/binding/matplotlib_bind.py
index d9c48428..d3d78ed6 100644
--- a/trains/binding/matplotlib_bind.py
+++ b/trains/binding/matplotlib_bind.py
@@ -384,7 +384,7 @@ class PatchedMatplotlib:
_pylab_helpers.Gcf.set_active(stored_figure)
# get the main task
- reporter = PatchedMatplotlib._current_task._reporter
+ reporter = PatchedMatplotlib._current_task.__reporter
if reporter is not None:
if mpl_fig.texts:
plot_title = mpl_fig.texts[0].get_text()
diff --git a/trains/task.py b/trains/task.py
index 4894f915..45da80ce 100644
--- a/trains/task.py
+++ b/trains/task.py
@@ -383,7 +383,7 @@ class Task(_Task):
logger.set_flush_period(None)
# create a new logger (to catch stdout/err)
cls.__main_task._logger = None
- cls.__main_task._reporter = None
+ cls.__main_task.__reporter = None
cls.__main_task.get_logger()
cls.__main_task._artifacts_manager = Artifacts(cls.__main_task)
# unregister signal hooks, they cause subprocess to hang
@@ -1165,8 +1165,8 @@ class Task(_Task):
if self._logger:
# noinspection PyProtectedMember
self._logger._flush_stdout_handler()
- if self._reporter:
- self._reporter.flush()
+ if self.__reporter:
+ self.__reporter.flush()
LoggerRoot.flush()
return True
@@ -1435,7 +1435,7 @@ class Task(_Task):
:return: The last reported iteration number.
"""
self._reload_last_iteration()
- return max(self.data.last_iteration or 0, self._reporter.max_iteration if self._reporter else 0)
+ return max(self.data.last_iteration or 0, self.__reporter.max_iteration if self.__reporter else 0)
def set_initial_iteration(self, offset=0):
# type: (int) -> int
@@ -2406,15 +2406,15 @@ class Task(_Task):
# wait for uploads
print_done_waiting = False
if wait_for_uploads and (BackendModel.get_num_results() > 0 or
- (self._reporter and self._reporter.get_num_results() > 0)):
+ (self.__reporter and self.__reporter.get_num_results() > 0)):
self.log.info('Waiting to finish uploads')
print_done_waiting = True
# from here, do not send log in background thread
if wait_for_uploads:
self.flush(wait_for_uploads=True)
# wait until the reporter flush everything
- if self._reporter:
- self._reporter.stop()
+ if self.__reporter:
+ self.__reporter.stop()
if self.is_main_task():
# notice: this will close the reporting for all the Tasks in the system
Metrics.close_async_threads()