From 5beecbb078853a99e04633e550eb47c8ea1c3448 Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Tue, 1 Sep 2020 18:00:19 +0300 Subject: [PATCH] Add Pipeline Controller automation and example --- examples/pipeline/pipeline_controller.py | 24 + examples/pipeline/step1_dataset_artifact.py | 19 + examples/pipeline/step2_data_processing.py | 55 ++ examples/pipeline/step3_train_model.py | 56 ++ trains/automation/controller.py | 722 ++++++++++++++++++++ trains/binding/matplotlib_bind.py | 2 +- trains/task.py | 14 +- 7 files changed, 884 insertions(+), 8 deletions(-) create mode 100644 examples/pipeline/pipeline_controller.py create mode 100644 examples/pipeline/step1_dataset_artifact.py create mode 100644 examples/pipeline/step2_data_processing.py create mode 100644 examples/pipeline/step3_train_model.py create mode 100644 trains/automation/controller.py diff --git a/examples/pipeline/pipeline_controller.py b/examples/pipeline/pipeline_controller.py new file mode 100644 index 00000000..4725a49a --- /dev/null +++ b/examples/pipeline/pipeline_controller.py @@ -0,0 +1,24 @@ +from trains import Task +from trains.automation.controller import PipelineController + + +task = Task.init(project_name='examples', task_name='pipeline demo', task_type=Task.TaskTypes.controller) + +pipe = PipelineController(default_execution_queue='default') +pipe.add_step(name='stage_data', base_task_project='examples', base_task_name='pipeline step 1 dataset artifact') +pipe.add_step(name='stage_process', parents=['stage_data', ], + base_task_project='examples', base_task_name='pipeline step 2 process dataset', + parameter_override={'General/dataset_url': '${stage_data.artifacts.dataset.url}', + 'General/test_size': '0.25'}) +pipe.add_step(name='stage_train', parents=['stage_process', ], + base_task_project='examples', base_task_name='pipeline step 3 train model', + parameter_override={'General/dataset_task_id': '${stage_process.id}'}) + +# Starting the pipeline (in the background) +pipe.start() +# Wait until pipeline terminates +pipe.wait() +# cleanup everything +pipe.stop() + +print('done') diff --git a/examples/pipeline/step1_dataset_artifact.py b/examples/pipeline/step1_dataset_artifact.py new file mode 100644 index 00000000..f1e7a6a3 --- /dev/null +++ b/examples/pipeline/step1_dataset_artifact.py @@ -0,0 +1,19 @@ +from trains import Task, StorageManager + +# create an dataset experiment +task = Task.init(project_name="examples", task_name="pipeline step 1 dataset artifact") + +# only create the task, we will actually execute it later +task.execute_remotely() + +# simulate local dataset, download one, so we have something local +local_iris_pkl = StorageManager.get_local_copy( + remote_url='https://github.com/allegroai/events/raw/master/odsc20-east/generic/iris_dataset.pkl') + +# add and upload local file containing our toy dataset +task.upload_artifact('dataset', artifact_object=local_iris_pkl) + +print('uploading artifacts in the background') + +# we are done +print('Done') diff --git a/examples/pipeline/step2_data_processing.py b/examples/pipeline/step2_data_processing.py new file mode 100644 index 00000000..ff74c89e --- /dev/null +++ b/examples/pipeline/step2_data_processing.py @@ -0,0 +1,55 @@ +import pickle +from trains import Task, StorageManager +from sklearn.model_selection import train_test_split + + +# Connecting TRAINS +task = Task.init(project_name="examples", task_name="pipeline step 2 process dataset") + +# program arguments +# Use either dataset_task_id to point to a tasks artifact or +# use a direct url 
with dataset_url +args = { + 'dataset_task_id': '', + 'dataset_url': '', + 'random_state': 42, + 'test_size': 0.2, +} + +# store arguments, later we will be able to change them from outside the code +task.connect(args) +print('Arguments: {}'.format(args)) + +# only create the task, we will actually execute it later +task.execute_remotely() + +# get dataset from task's artifact +if args['dataset_task_id']: + dataset_upload_task = Task.get_task(task_id=args['dataset_task_id']) + print('Input task id={} artifacts {}'.format(args['dataset_task_id'], list(dataset_upload_task.artifacts.keys()))) + # download the artifact + iris_pickle = dataset_upload_task.artifacts['dataset'].get_local_copy() +# get the dataset from a direct url +elif args['dataset_url']: + iris_pickle = StorageManager.get_local_copy(remote_url=args['dataset_url']) +else: + raise ValueError("Missing dataset link") + +# open the local copy +iris = pickle.load(open(iris_pickle, 'rb')) + +# "process" data +X = iris.data +y = iris.target +X_train, X_test, y_train, y_test = train_test_split( + X, y, test_size=args['test_size'], random_state=args['random_state']) + +# upload processed data +print('Uploading process dataset') +task.upload_artifact('X_train', X_train) +task.upload_artifact('X_test', X_test) +task.upload_artifact('y_train', y_train) +task.upload_artifact('y_test', y_test) + +print('Notice, artifacts are uploaded in the background') +print('Done') diff --git a/examples/pipeline/step3_train_model.py b/examples/pipeline/step3_train_model.py new file mode 100644 index 00000000..332532b4 --- /dev/null +++ b/examples/pipeline/step3_train_model.py @@ -0,0 +1,56 @@ +import joblib +import matplotlib.pyplot as plt +import numpy as np +from sklearn.linear_model import LogisticRegression + +from trains import Task + +# Connecting TRAINS +task = Task.init(project_name="examples", task_name="pipeline step 3 train model") + +# Arguments +args = { + 'dataset_task_id': 'REPLACE_WITH_DATASET_TASK_ID', +} +task.connect(args) + +# only create the task, we will actually execute it later +task.execute_remotely() + +print('Retrieving Iris dataset') +dataset_task = Task.get_task(task_id=args['dataset_task_id']) +X_train = dataset_task.artifacts['X_train'].get() +X_test = dataset_task.artifacts['X_test'].get() +y_train = dataset_task.artifacts['y_train'].get() +y_test = dataset_task.artifacts['y_test'].get() +print('Iris dataset loaded') + +model = LogisticRegression(solver='liblinear', multi_class='auto') +model.fit(X_train, y_train) + +joblib.dump(model, 'model.pkl', compress=True) + +loaded_model = joblib.load('model.pkl') +result = loaded_model.score(X_test, y_test) + +print('model trained & stored') + +x_min, x_max = X_test[:, 0].min() - .5, X_test[:, 0].max() + .5 +y_min, y_max = X_test[:, 1].min() - .5, X_test[:, 1].max() + .5 +h = .02 # step size in the mesh +xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h)) +plt.figure(1, figsize=(4, 3)) + +plt.scatter(X_test[:, 0], X_test[:, 1], c=y_test, edgecolors='k', cmap=plt.cm.Paired) +plt.xlabel('Sepal length') +plt.ylabel('Sepal width') + +plt.xlim(xx.min(), xx.max()) +plt.ylim(yy.min(), yy.max()) +plt.xticks(()) +plt.yticks(()) + +plt.title('Iris Types') +plt.show() + +print('Done') diff --git a/trains/automation/controller.py b/trains/automation/controller.py new file mode 100644 index 00000000..f9925b16 --- /dev/null +++ b/trains/automation/controller.py @@ -0,0 +1,722 @@ +import re +from copy import copy +from datetime import datetime +from logging import 
getLogger +from threading import Thread, Event +from time import time + +from plotly import graph_objects as go +from plotly.subplots import make_subplots + +from attr import attrib, attrs +from typing import Sequence, Optional, Mapping, Callable + +from trains import Task +from trains.automation import TrainsJob +from trains.model import BaseModel + + +class PipelineController(object): + """ + Pipeline controller. + Pipeline is a DAG of base tasks, each task will be cloned (arguments changed as required) executed and monitored + The pipeline process (task) itself can be executed manually or by the trains-agent services queue. + Notice: The pipeline controller lives as long as the pipeline itself is being executed. + """ + _tag = 'pipeline' + _step_pattern = r"\${[^}]*}" + _config_section = 'Pipeline' + + @attrs + class Node(object): + name = attrib(type=str) + base_task_id = attrib(type=str) + queue = attrib(type=str, default=None) + parents = attrib(type=list, default=[]) + timeout = attrib(type=float, default=None) + parameters = attrib(type=dict, default={}) + executed = attrib(type=str, default=None) + job = attrib(type=TrainsJob, default=None) + + def __init__( + self, + pool_frequency=0.2, # type: float + default_execution_queue=None, # type: Optional[str] + pipeline_time_limit=None, # type: Optional[float] + auto_connect_task=True, # type: bool + always_create_task=False, # type: bool + ): + # type: (...) -> () + """ + Create a new pipeline controller. The newly created object will launch and monitor the new experiments. + + :param float pool_frequency: The pooling frequency (in minutes) for monitoring experiments / states. + :param str default_execution_queue: The execution queue to use if no execution queue is provided + :param float pipeline_time_limit: The maximum time (minutes) for the entire pipeline process. The + default is ``None``, indicating no time limit. + """ + self._nodes = {} + self._running_nodes = [] + self._start_time = None + self._pipeline_time_limit = pipeline_time_limit * 60. if pipeline_time_limit else None + self._default_execution_queue = default_execution_queue + self._pool_frequency = pool_frequency * 60. + self._thread = None + self._stop_event = None + self._experiment_created_cb = None + self._task = Task.current_task() + self._step_ref_pattern = re.compile(self._step_pattern) + if not self._task and always_create_task: + self._task = Task.init( + project_name='Pipelines', + task_name='Pipeline {}'.format(datetime.now()), + task_type=Task.TaskTypes.controller, + ) + + # make sure all the created tasks are our children, as we are creating them + if self._task: + self._task.add_tags([self._tag]) + self._auto_connect_task = auto_connect_task + + def add_step( + self, + name, # type: str + base_task_id=None, # type: Optional[str] + parents=None, # type: Optional[Sequence[str]] + parameter_override=None, # type: Optional[Mapping[str, str]] + execution_queue=None, # type: Optional[str] + time_limit=None, # type: Optional[float] + base_task_project=None, # type: Optional[str] + base_task_name=None, # type: Optional[str] + ): + # type: (...) -> bool + """ + Add a step to the pipeline execution DAG. + Each step must have a unique name (this name will later be used to address the step) + + :param str name: Unique of the step. For example `stage1` + :param str base_task_id: The Task ID to use for the step. Each time the step is executed, + the base Task is cloned, then the cloned task will be sent for execution. 
+ :param list parents: Optional list of parent nodes in the DAG. + The current step in the pipeline will be sent for execution only after all the parent nodes + have been executed successfully. + :param dict parameter_override: Optional parameter overriding dictionary. + The dict values can reference a previously executed step using the following form '${step_name}' + Examples: + Artifact access + parameter_override={'Args/input_file': '${stage1.artifacts.mydata.url}' } + Model access (last model used) + parameter_override={'Args/input_file': '${stage1.models.output.-1.url}' } + Parameter access + parameter_override={'Args/input_file': '${stage3.parameters.Args/input_file}' } + Task ID + parameter_override={'Args/input_file': '${stage3.id}' } + :param str execution_queue: Optional, the queue to use for executing this specific step. + If not provided, the task will be sent to the default execution queue, as defined on the class + :param float time_limit: Default None, no time limit. + Step execution time limit, if exceeded the Task is aborted and the pipeline is stopped and marked failed. + :param str base_task_project: If base_task_id is not given, + use the base_task_project and base_task_name combination to retrieve the base_task_id to use for the step. + :param str base_task_name: If base_task_id is not given, + use the base_task_project and base_task_name combination to retrieve the base_task_id to use for the step. + :return: True if successful + """ + # when running remotely do nothing, we will deserialize ourselves when we start + if self._task and not self._task.running_locally() and self._task.is_main_task(): + return True + + if name in self._nodes: + raise ValueError('Node named \'{}\' already exists in the pipeline dag'.format(name)) + + if not base_task_id: + if not base_task_project or not base_task_name: + raise ValueError('Either base_task_id or base_task_project/base_task_name must be provided') + base_task = Task.get_task(project_name=base_task_project, task_name=base_task_name) + if not base_task: + raise ValueError('Could not find base_task_project={} base_task_name={}'.format( + base_task_project, base_task_name)) + base_task_id = base_task.id + + self._nodes[name] = self.Node( + name=name, base_task_id=base_task_id, parents=parents or [], + queue=execution_queue, timeout=time_limit, parameters=parameter_override or {}) + return True + + def start(self, run_remotely=False, step_task_created_callback=None): + # type: (bool, Optional[Callable[[PipelineController.Node, dict], None]]) -> bool + """ + Start the pipeline controller. + If the calling process is stopped, then the controller stops as well. + + :param bool run_remotely: (default False), If True stop the current process and continue execution + on a remote machine. This is done by calling the Task.execute_remotely with the queue name 'services' + :param Callable step_task_created_callback: Callback function, called when a step (Task) is created + and before it is sent for execution. + + .. code-block:: py + + def step_created_callback( + node, # type: PipelineController.Node, + parameters, # type: dict + ): + pass + + :return: True, if the controller started. False, if the controller did not start. 
+ + """ + if self._thread: + return True + + # serialize pipeline state + pipeline_dag = self._serialize() + self._task.connect_configuration(pipeline_dag, name=self._config_section) + params = {'continue_pipeline': False, + 'default_queue': self._default_execution_queue} + self._task.connect(params, name=self._config_section) + # deserialize back pipeline state + if not params['continue_pipeline']: + for k in pipeline_dag: + pipeline_dag[k]['executed'] = None + self._default_execution_queue = params['default_queue'] + self._deserialize(pipeline_dag) + + if not self._verify(): + raise ValueError("Failed verifying pipeline execution graph, " + "it has either inaccessible nodes, or contains cycles") + + self._update_execution_plot() + + if run_remotely: + self._task.execute_remotely(queue_name='services') + # we will not get here if we are not running remotely + + self._start_time = time() + self._stop_event = Event() + self._experiment_created_cb = step_task_created_callback + self._thread = Thread(target=self._daemon) + self._thread.daemon = True + self._thread.start() + return True + + def stop(self, timeout=None): + # type: (Optional[float]) -> () + """ + Stop the pipeline controller and the optimization thread. + + :param float timeout: Wait timeout for the optimization thread to exit (minutes). + The default is ``None``, indicating do not wait terminate immediately. + """ + pass + + def wait(self, timeout=None): + # type: (Optional[float]) -> bool + """ + Wait for the pipeline to finish. + + .. note:: + This method does not stop the pipeline. Call :meth:`stop` to terminate the pipeline. + + :param float timeout: The timeout to wait for the pipeline to complete (minutes). + If ``None``, then wait until we reached the timeout, or pipeline completed. + + :return: True, if the pipeline finished. False, if the pipeline timed out. + + """ + if not self.is_running(): + return True + + if timeout is not None: + timeout *= 60. + + _thread = self._thread + + _thread.join(timeout=timeout) + if _thread.is_alive(): + return False + + return True + + def is_running(self): + # type: () -> bool + """ + return True if the pipeline controller is running. + + :return: A boolean indicating whether the pipeline controller is active (still running) or stopped. + """ + return self._thread is not None + + def elapsed(self): + # type: () -> float + """ + Return minutes elapsed from controller stating time stamp. + + :return: The minutes from controller start time. A negative value means the process has not started yet. + """ + if self._start_time is None: + return -1.0 + return (time() - self._start_time) / 60. + + def get_pipeline_dag(self): + # type: () -> Mapping[str, PipelineController.Node] + """ + Return the pipeline execution graph, each node in the DAG is PipelineController.Node object. + Graph itself is a dictionary of Nodes (key based on the Node name), + each node holds links to its parent Nodes (identified by their unique names) + + :return: execution tree, as a nested dictionary + Example: + { + 'stage1' : Node() { + name: 'stage1' + job: TrainsJob + ... + }, + } + """ + return self._nodes + + def get_processed_nodes(self): + # type: () -> Sequence[PipelineController.Node] + """ + Return the a list of the processed pipeline nodes, each entry in the list is PipelineController.Node object. 
+ + :return: executed (excluding currently executing) nodes list + """ + return {k: n for k, n in self._nodes.items() if n.executed} + + def get_running_nodes(self): + # type: () -> Sequence[PipelineController.Node] + """ + Return the a list of the currently running pipeline nodes, + each entry in the list is PipelineController.Node object. + + :return: Currently running nodes list + """ + return {k: n for k, n in self._nodes.items() if k in self._running_nodes} + + def _serialize(self): + # type: () -> dict + """ + Store the definition of the pipeline DAG into a dictionary. + This dictionary will be used to store the DAG as a configuration on the Task + :return: + """ + dag = {name: dict((k, v) for k, v in node.__dict__.items() if k not in ('job', 'name')) + for name, node in self._nodes.items()} + + return dag + + def _deserialize(self, dag_dict): + # type: (dict) -> () + """ + Restore the DAG from a dictionary. + This will be used to create the DAG from the dict stored on the Task, when running remotely. + :return: + """ + self._nodes = {k: self.Node(name=k, **v) for k, v in dag_dict.items()} + + def _verify(self): + # type: () -> bool + """ + Verify the DAG, (i.e. no cycles and no missing parents) + On error raise ValueError with verification details + + :return: return True iff DAG has no errors + """ + # verify nodes + for node in self._nodes.values(): + # raise value error if not verified + self._verify_node(node) + + # check the dag itself + if not self._verify_dag(): + return False + + return True + + def _verify_node(self, node): + # type: (Node) -> bool + """ + Raise ValueError on verification errors + + :return: Return True iff the specific node is verified + """ + if not node.base_task_id: + raise ValueError("Node '{}', base_task_id is empty".format(node.name)) + + if not self._default_execution_queue and not node.queue: + raise ValueError("Node '{}' missing execution queue, " + "no default queue defined and no specific node queue defined".format(node.name)) + + task = Task.get_task(task_id=node.base_task_id) + if not task: + raise ValueError("Node '{}', base_task_id={} is invalid".format(node.name, node.base_task_id)) + + pattern = self._step_ref_pattern + + for v in node.parameters.values(): + for g in pattern.findall(v): + self.__verify_step_reference(node, g) + + return True + + def _verify_dag(self): + # type: () -> bool + """ + :return: True iff the pipeline dag is fully accessible and contains no cycles + """ + visited = set() + prev_visited = None + while prev_visited != visited: + prev_visited = copy(visited) + for k, node in self._nodes.items(): + if k in visited: + continue + if not all(p in visited for p in node.parents or []): + continue + visited.add(k) + # return False if we did not cover all the nodes + return not bool(set(self._nodes.keys()) - visited) + + def _launch_node(self, node): + # type: (Node) -> () + """ + Launch a single node (create and enqueue a TrainsJob) + + :param node: Node to launch + :return: Return True if a new job was launched + """ + if node.job or node.executed: + return False + + updated_hyper_parameters = {} + for k, v in node.parameters.items(): + updated_hyper_parameters[k] = self._parse_step_ref(v) + + node.job = TrainsJob( + base_task_id=node.base_task_id, parameter_override=updated_hyper_parameters, + parent=self._task.id) + if self._experiment_created_cb: + self._experiment_created_cb(node, updated_hyper_parameters) + node.job.launch(queue_name=node.queue or self._default_execution_queue) + return True + + def 
_update_execution_plot(self):
+        # type: () -> ()
+        """
+        Update the Sankey diagram of the current pipeline
+        """
+        sankey_node = dict(
+            label=[],
+            color=[],
+            hovertemplate='%{label}',
+            # customdata=[],
+            # hovertemplate='%{label}
<br />Hyper-Parameters:<br />
%{customdata}',
+        )
+        sankey_link = dict(
+            source=[],
+            target=[],
+            value=[],
+            hovertemplate='%{target.label}',
+        )
+        visited = []
+        node_params = []
+        nodes = list(self._nodes.values())
+        while nodes:
+            next_nodes = []
+            for node in nodes:
+                if not all(p in visited for p in node.parents or []):
+                    next_nodes.append(node)
+                    continue
+                visited.append(node.name)
+                idx = len(visited) - 1
+                parents = [visited.index(p) for p in node.parents or []]
+                node_params.append((node.job.task_parameter_override if node.job else node.parameters) or {})
+                # sankey_node['label'].append(node.name)
+                # sankey_node['customdata'].append(
+                #     '
<br />'.join('{}: {}'.format(k, v) for k, v in (node.parameters or {}).items()))
+                sankey_node['label'].append(
+                    '{}<br />'.format(node.name) +
+                    '<br />
'.join('{}: {}'.format(k, v) for k, v in (node.parameters or {}).items())) + sankey_node['color'].append( + ("blue" if not node.job or not node.job.is_failed() else "red") + if node.executed else ("green" if node.job else "lightsteelblue")) + + for p in parents: + sankey_link['source'].append(p) + sankey_link['target'].append(idx) + sankey_link['value'].append(1) + + nodes = next_nodes + + # make sure we have no independent (unconnected) nodes + for i in [n for n in range(len(visited)) if n not in sankey_link['source'] and n not in sankey_link['target']]: + sankey_link['source'].append(i) + sankey_link['target'].append(i) + sankey_link['value'].append(0.1) + + fig = make_subplots( + rows=2, cols=1, + shared_xaxes=True, + vertical_spacing=0.03, + specs=[[{"type": "table"}], + [{"type": "sankey"}], ] + ) + # noinspection PyUnresolvedReferences + fig.add_trace( + go.Sankey( + node=sankey_node, link=sankey_link, textfont=dict(color='rgba(0,0,0,0)', size=1) + ), + row=1, col=1 + ) + # noinspection PyUnresolvedReferences + fig.add_trace( + go.Table( + header=dict( + values=["Pipeline Step", "Task ID", "Parameters"], + align="left", + ), + cells=dict( + values=[visited, + [self._nodes[v].executed or (self._nodes[v].job.task_id() if self._nodes[v].job else '') + for v in visited], + [str(p) for p in node_params]], + align="left") + ), + row=2, col=1 + ) + + # fig = go.Figure(data=[go.Sankey( + # node=sankey_node, link=sankey_link, textfont=dict(color='rgba(0,0,0,0)', size=1))],) + self._task.get_logger().report_plotly( + title='Pipeline', series='execution flow', iteration=0, figure=fig) + + def _force_task_configuration_update(self): + pipeline_dag = self._serialize() + # noinspection PyProtectedMember + self._task._set_configuration( + name=self._config_section, config_type='dictionary', config_dict=pipeline_dag) + + def _daemon(self): + # type: () -> () + """ + The main pipeline execution loop. This loop is executed on its own dedicated thread. + :return: + """ + pooling_counter = 0 + + while self._stop_event: + # stop request + if pooling_counter and self._stop_event.wait(self._pool_frequency): + break + + pooling_counter += 1 + + # check the pipeline time limit + if self._pipeline_time_limit and (time() - self._start_time) > self._pipeline_time_limit: + break + + # check the state of all current jobs + # if no a job ended, continue + completed_jobs = [] + for j in self._running_nodes: + node = self._nodes[j] + if not node.job: + continue + if node.job.is_stopped(): + completed_jobs.append(j) + node.executed = node.job.task_id() + elif node.timeout: + started = node.job.task.data.started + if (datetime.now().astimezone(started.tzinfo) - started).total_seconds() > node.timeout: + node.job.abort() + completed_jobs.append(j) + node.executed = node.job.task_id() + + # update running jobs + self._running_nodes = [j for j in self._running_nodes if j not in completed_jobs] + + # nothing changed, we can sleep + if not completed_jobs and self._running_nodes: + continue + + # Pull the next jobs in the pipeline, based on the completed list + next_nodes = [] + for node in self._nodes.values(): + # check if already processed. 
+ if node.job or node.executed: + continue + completed_parents = [bool(p in self._nodes and self._nodes[p].executed) for p in node.parents or []] + if all(completed_parents): + next_nodes.append(node.name) + + # update the execution graph + for name in next_nodes: + if self._launch_node(self._nodes[name]): + print('Launching step: {}'.format(name)) + print('Parameters:\n{}'.format(self._nodes[name].job.task_parameter_override)) + self._running_nodes.append(name) + else: + getLogger('trains.automation.controller').error( + 'ERROR: Failed launching step \'{}\': {}'.format(name, self._nodes[name])) + + # update current state (in configuration, so that we could later continue an aborted pipeline) + self._force_task_configuration_update() + + # visualize pipeline state (plot) + self._update_execution_plot() + + # quit if all pipelines nodes are fully executed. + if not next_nodes and not self._running_nodes: + break + + # stop all currently running jobs: + for node in self._nodes.values(): + if node.job and not node.executed: + node.job.abort() + + if self._stop_event: + # noinspection PyBroadException + try: + self._stop_event.set() + except Exception: + pass + + def __verify_step_reference(self, node, step_ref_string): + # type: (Node, str) -> bool + """ + Verify the step reference. For example "${step1.parameters.Args/param}" + :param Node node: calling reference node (used for logging) + :param str step_ref_string: For example "${step1.parameters.Args/param}" + :return: True if valid reference + """ + parts = step_ref_string[2:-1].split('.') + v = step_ref_string + if len(parts) < 2: + raise ValueError("Node '{}', parameter '{}' is invalid".format(node.name, v)) + prev_step = parts[0] + input_type = parts[1] + if prev_step not in self._nodes: + raise ValueError("Node '{}', parameter '{}', step name '{}' is invalid".format(node.name, v, prev_step)) + if input_type not in ('artifacts', 'parameters', 'models', 'id'): + raise ValueError( + "Node {}, parameter '{}', input type '{}' is invalid".format(node.name, v, input_type)) + + if input_type != 'id' and len(parts) < 3: + raise ValueError("Node '{}', parameter '{}' is invalid".format(node.name, v)) + + if input_type == 'models': + try: + model_type = parts[2].lower() + except Exception: + raise ValueError( + "Node '{}', parameter '{}', input type '{}', model_type is missing {}".format( + node.name, v, input_type, parts)) + if model_type not in ('input', 'output'): + raise ValueError( + "Node '{}', parameter '{}', input type '{}', " + "model_type is invalid (input/output) found {}".format( + node.name, v, input_type, model_type)) + + if len(parts) < 4: + raise ValueError( + "Node '{}', parameter '{}', input type '{}', model index is missing".format( + node.name, v, input_type)) + + # check casting + try: + int(parts[3]) + except Exception: + raise ValueError( + "Node '{}', parameter '{}', input type '{}', model index is missing {}".format( + node.name, v, input_type, parts)) + + if len(parts) < 5: + raise ValueError( + "Node '{}', parameter '{}', input type '{}', model property is missing".format( + node.name, v, input_type)) + + if not hasattr(BaseModel, parts[4]): + raise ValueError( + "Node '{}', parameter '{}', input type '{}', model property is invalid {}".format( + node.name, v, input_type, parts[4])) + return True + + def __parse_step_reference(self, step_ref_string): + """ + return the adjusted value for "${step...}" + :param step_ref_string: reference string of the form ${step_name.type.value}" + :return: str with value + """ + parts = 
step_ref_string[2:-1].split('.') + if len(parts) < 2: + raise ValueError("Could not parse reference '{}'".format(step_ref_string)) + prev_step = parts[0] + input_type = parts[1].lower() + if prev_step not in self._nodes or not self._nodes[prev_step].job: + raise ValueError("Could not parse reference '{}', step {} could not be found".format( + step_ref_string, prev_step)) + if input_type not in ('artifacts', 'parameters', 'models', 'id'): + raise ValueError("Could not parse reference '{}', type {} not valid".format(step_ref_string, input_type)) + if input_type != 'id' and len(parts) < 3: + raise ValueError("Could not parse reference '{}', missing fields in {}".format(step_ref_string, parts)) + + task = self._nodes[prev_step].job.task if self._nodes[prev_step].job \ + else Task.get_task(task_id=self._nodes[prev_step].executed) + task.reload() + if input_type == 'artifacts': + # fix \. to use . in artifacts + artifact_path = ('.'.join(parts[2:])).replace('\\.', '\\_dot_\\') + artifact_path = artifact_path.split('.') + + obj = task.artifacts + for p in artifact_path: + p = p.replace('\\_dot_\\', '.') + if isinstance(obj, dict): + obj = obj.get(p) + elif hasattr(obj, p): + obj = getattr(obj, p) + else: + raise ValueError("Could not locate artifact {} on previous step {}".format( + '.'.join(parts[1:]), prev_step)) + return str(obj) + elif input_type == 'parameters': + step_params = task.get_parameters() + param_name = '.'.join(parts[2:]) + if param_name not in step_params: + raise ValueError("Could not locate parameter {} on previous step {}".format( + '.'.join(parts[1:]), prev_step)) + return step_params.get(param_name) + elif input_type == 'models': + model_type = parts[2].lower() + if model_type not in ('input', 'output'): + raise ValueError("Could not locate model {} on previous step {}".format( + '.'.join(parts[1:]), prev_step)) + try: + model_idx = int(parts[3]) + model = task.models[model_type][model_idx] + except Exception: + raise ValueError("Could not locate model {} on previous step {}, index {} is invalid".format( + '.'.join(parts[1:]), prev_step, parts[3])) + + return str(getattr(model, parts[4])) + + elif input_type == 'id': + return task.id + return None + + def _parse_step_ref(self, value): + # type: (str) -> Optional[str] + """ + Return the step reference. 
For example "${step1.parameters.Args/param}" + :param value: string + :return: + """ + # look for all the step references + pattern = self._step_ref_pattern + updated_value = value + for g in pattern.findall(value): + # update with actual value + new_val = self.__parse_step_reference(g) + updated_value = updated_value.replace(g, new_val, 1) + return updated_value diff --git a/trains/binding/matplotlib_bind.py b/trains/binding/matplotlib_bind.py index d9c48428..d3d78ed6 100644 --- a/trains/binding/matplotlib_bind.py +++ b/trains/binding/matplotlib_bind.py @@ -384,7 +384,7 @@ class PatchedMatplotlib: _pylab_helpers.Gcf.set_active(stored_figure) # get the main task - reporter = PatchedMatplotlib._current_task._reporter + reporter = PatchedMatplotlib._current_task.__reporter if reporter is not None: if mpl_fig.texts: plot_title = mpl_fig.texts[0].get_text() diff --git a/trains/task.py b/trains/task.py index 4894f915..45da80ce 100644 --- a/trains/task.py +++ b/trains/task.py @@ -383,7 +383,7 @@ class Task(_Task): logger.set_flush_period(None) # create a new logger (to catch stdout/err) cls.__main_task._logger = None - cls.__main_task._reporter = None + cls.__main_task.__reporter = None cls.__main_task.get_logger() cls.__main_task._artifacts_manager = Artifacts(cls.__main_task) # unregister signal hooks, they cause subprocess to hang @@ -1165,8 +1165,8 @@ class Task(_Task): if self._logger: # noinspection PyProtectedMember self._logger._flush_stdout_handler() - if self._reporter: - self._reporter.flush() + if self.__reporter: + self.__reporter.flush() LoggerRoot.flush() return True @@ -1435,7 +1435,7 @@ class Task(_Task): :return: The last reported iteration number. """ self._reload_last_iteration() - return max(self.data.last_iteration or 0, self._reporter.max_iteration if self._reporter else 0) + return max(self.data.last_iteration or 0, self.__reporter.max_iteration if self.__reporter else 0) def set_initial_iteration(self, offset=0): # type: (int) -> int @@ -2406,15 +2406,15 @@ class Task(_Task): # wait for uploads print_done_waiting = False if wait_for_uploads and (BackendModel.get_num_results() > 0 or - (self._reporter and self._reporter.get_num_results() > 0)): + (self.__reporter and self.__reporter.get_num_results() > 0)): self.log.info('Waiting to finish uploads') print_done_waiting = True # from here, do not send log in background thread if wait_for_uploads: self.flush(wait_for_uploads=True) # wait until the reporter flush everything - if self._reporter: - self._reporter.stop() + if self.__reporter: + self.__reporter.stop() if self.is_main_task(): # notice: this will close the reporting for all the Tasks in the system Metrics.close_async_threads()
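
A minimal usage sketch of the step_task_created_callback hook: judging from _launch_node(), the callback runs after each step's Task is cloned but before it is enqueued, so mutating node.queue inside the callback should let a single step run on its own queue. The wiring below mirrors examples/pipeline/pipeline_controller.py; the queue name 'gpu_queue' is an illustrative assumption and not part of the example tasks above.

from trains import Task
from trains.automation.controller import PipelineController

task = Task.init(project_name='examples', task_name='pipeline demo',
                 task_type=Task.TaskTypes.controller)

pipe = PipelineController(default_execution_queue='default')
pipe.add_step(name='stage_data', base_task_project='examples',
              base_task_name='pipeline step 1 dataset artifact')
pipe.add_step(name='stage_process', parents=['stage_data'],
              base_task_project='examples', base_task_name='pipeline step 2 process dataset',
              parameter_override={'General/dataset_url': '${stage_data.artifacts.dataset.url}',
                                  'General/test_size': '0.25'})
pipe.add_step(name='stage_train', parents=['stage_process'],
              base_task_project='examples', base_task_name='pipeline step 3 train model',
              parameter_override={'General/dataset_task_id': '${stage_process.id}'})


def step_created_callback(node, parameters):
    # type: (PipelineController.Node, dict) -> None
    # called just before the cloned step Task is enqueued;
    # 'gpu_queue' is a hypothetical queue name used only for illustration
    if node.name == 'stage_train':
        node.queue = 'gpu_queue'
    print('launching step {} with parameters {}'.format(node.name, parameters))


pipe.start(step_task_created_callback=step_created_callback)
pipe.wait()
pipe.stop()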