import re
from copy import copy
from datetime import datetime
from logging import getLogger
from threading import Thread, Event
from time import time
from typing import Sequence, Optional, Mapping, Callable

from attr import attrib, attrs, Factory
from plotly import graph_objects as go
from plotly.subplots import make_subplots

from trains import Task
from trains.automation import TrainsJob
from trains.model import BaseModel


class PipelineController(object):
    """
    Pipeline controller.
    A pipeline is a DAG of base tasks: each step clones its base task (overriding arguments as required),
    then the cloned task is enqueued for execution and monitored.
    The pipeline process (Task) itself can be executed manually or by the trains-agent services queue.

    Notice: The pipeline controller lives as long as the pipeline itself is being executed.
    """
    _tag = 'pipeline'
    _step_pattern = r"\${[^}]*}"
    _config_section = 'Pipeline'

    @attrs
    class Node(object):
        name = attrib(type=str)
        base_task_id = attrib(type=str)
        queue = attrib(type=str, default=None)
        # use factories to avoid sharing a single mutable default between Node instances
        parents = attrib(type=list, default=Factory(list))
        timeout = attrib(type=float, default=None)
        parameters = attrib(type=dict, default=Factory(dict))
        executed = attrib(type=str, default=None)
        job = attrib(type=TrainsJob, default=None)

    def __init__(
            self,
            pool_frequency=0.2,  # type: float
            default_execution_queue=None,  # type: Optional[str]
            pipeline_time_limit=None,  # type: Optional[float]
            auto_connect_task=True,  # type: bool
            always_create_task=False,  # type: bool
    ):
        # type: (...) -> ()
        """
        Create a new pipeline controller. The newly created object will launch and monitor the new experiments.

        :param float pool_frequency: The polling frequency (in minutes) for monitoring experiments / states.
        :param str default_execution_queue: The execution queue to use if no execution queue is provided
        :param float pipeline_time_limit: The maximum time (minutes) for the entire pipeline process. The
            default is ``None``, indicating no time limit.
        :param bool auto_connect_task: Store the pipeline arguments and configuration on the controller Task
        :param bool always_create_task: If True and no main Task was initialized in the current process,
            create a new controller Task (in the 'Pipelines' project) to hold the pipeline state.
        """
        self._nodes = {}
        self._running_nodes = []
        self._start_time = None
        self._pipeline_time_limit = pipeline_time_limit * 60. if pipeline_time_limit else None
        self._default_execution_queue = default_execution_queue
        self._pool_frequency = pool_frequency * 60.
        self._thread = None
        self._stop_event = None
        self._experiment_created_cb = None
        self._task = Task.current_task()
        self._step_ref_pattern = re.compile(self._step_pattern)
        if not self._task and always_create_task:
            self._task = Task.init(
                project_name='Pipelines',
                task_name='Pipeline {}'.format(datetime.now()),
                task_type=Task.TaskTypes.controller,
            )

        # make sure all the created tasks are our children, as we are creating them
        if self._task:
            self._task.add_tags([self._tag])

        self._auto_connect_task = auto_connect_task

    def add_step(
            self,
            name,  # type: str
            base_task_id=None,  # type: Optional[str]
            parents=None,  # type: Optional[Sequence[str]]
            parameter_override=None,  # type: Optional[Mapping[str, str]]
            execution_queue=None,  # type: Optional[str]
            time_limit=None,  # type: Optional[float]
            base_task_project=None,  # type: Optional[str]
            base_task_name=None,  # type: Optional[str]
    ):
        # type: (...) -> bool
        """
        Add a step to the pipeline execution DAG.
        Each step must have a unique name (this name will later be used to address the step)

        :param str name: Unique name of the step. For example `stage1`
        :param str base_task_id: The Task ID to use for the step. Each time the step is executed,
            the base Task is cloned, then the cloned task will be sent for execution.
        :param list parents: Optional list of parent nodes in the DAG.
            The current step in the pipeline will be sent for execution only after all the parent nodes
            have been executed successfully.
        :param dict parameter_override: Optional parameter overriding dictionary.
            The dict values can reference a previously executed step using the following form '${step_name}'
            Examples:
                Artifact access
                    parameter_override={'Args/input_file': '${stage1.artifacts.mydata.url}' }
                Model access (last model used)
                    parameter_override={'Args/input_file': '${stage1.models.output.-1.url}' }
                Parameter access
                    parameter_override={'Args/input_file': '${stage3.parameters.Args/input_file}' }
                Task ID
                    parameter_override={'Args/input_file': '${stage3.id}' }
        :param str execution_queue: Optional, the queue to use for executing this specific step.
            If not provided, the task will be sent to the default execution queue, as defined on the class
        :param float time_limit: Default None, no time limit.
            Step execution time limit, if exceeded the Task is aborted and the pipeline is stopped and marked failed.
        :param str base_task_project: If base_task_id is not given,
            use the base_task_project and base_task_name combination to retrieve the base_task_id to use for the step.
        :param str base_task_name: If base_task_id is not given,
            use the base_task_project and base_task_name combination to retrieve the base_task_id to use for the step.

        :return: True if successful
        """
        # when running remotely do nothing, we will deserialize ourselves when we start
        if self._task and not self._task.running_locally() and self._task.is_main_task():
            return True

        if name in self._nodes:
            raise ValueError('Node named \'{}\' already exists in the pipeline dag'.format(name))

        if not base_task_id:
            if not base_task_project or not base_task_name:
                raise ValueError('Either base_task_id or base_task_project/base_task_name must be provided')
            base_task = Task.get_task(project_name=base_task_project, task_name=base_task_name)
            if not base_task:
                raise ValueError('Could not find base_task_project={} base_task_name={}'.format(
                    base_task_project, base_task_name))
            base_task_id = base_task.id

        self._nodes[name] = self.Node(
            name=name, base_task_id=base_task_id, parents=parents or [],
            queue=execution_queue, timeout=time_limit,
            parameters=parameter_override or {})
        return True

    def start(self, run_remotely=False, step_task_created_callback=None):
        # type: (bool, Optional[Callable[[PipelineController.Node, dict], None]]) -> bool
        """
        Start the pipeline controller.
        If the calling process is stopped, the controller stops as well.

        :param bool run_remotely: (default False), If True, stop the current process and continue execution
            on a remote machine. This is done by calling the Task.execute_remotely with the queue name 'services'
        :param Callable step_task_created_callback: Callback function, called when a step (Task) is created
            and before it is sent for execution.

            .. code-block:: py

                def step_created_callback(
                    node,        # type: PipelineController.Node
                    parameters,  # type: dict
                ):
                    pass

        :return: True, if the controller started. False, if the controller did not start.
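
        Example, an illustrative sketch (``pipe`` is assumed to be a controller whose steps were already
        added with :meth:`add_step`; the callback below is a placeholder, not part of the library):

        .. code-block:: py

            def on_step_created(node, parameters):
                # node is a PipelineController.Node, parameters is the resolved override dict
                print('launching step {} with {}'.format(node.name, parameters))

            if pipe.start(step_task_created_callback=on_step_created):
                pipe.wait()  # block until the pipeline completes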
""" if self._thread: return True # serialize pipeline state pipeline_dag = self._serialize() self._task.connect_configuration(pipeline_dag, name=self._config_section) params = {'continue_pipeline': False, 'default_queue': self._default_execution_queue} self._task.connect(params, name=self._config_section) # deserialize back pipeline state if not params['continue_pipeline']: for k in pipeline_dag: pipeline_dag[k]['executed'] = None self._default_execution_queue = params['default_queue'] self._deserialize(pipeline_dag) if not self._verify(): raise ValueError("Failed verifying pipeline execution graph, " "it has either inaccessible nodes, or contains cycles") self._update_execution_plot() if run_remotely: self._task.execute_remotely(queue_name='services') # we will not get here if we are not running remotely self._start_time = time() self._stop_event = Event() self._experiment_created_cb = step_task_created_callback self._thread = Thread(target=self._daemon) self._thread.daemon = True self._thread.start() return True def stop(self, timeout=None): # type: (Optional[float]) -> () """ Stop the pipeline controller and the optimization thread. :param float timeout: Wait timeout for the optimization thread to exit (minutes). The default is ``None``, indicating do not wait terminate immediately. """ pass def wait(self, timeout=None): # type: (Optional[float]) -> bool """ Wait for the pipeline to finish. .. note:: This method does not stop the pipeline. Call :meth:`stop` to terminate the pipeline. :param float timeout: The timeout to wait for the pipeline to complete (minutes). If ``None``, then wait until we reached the timeout, or pipeline completed. :return: True, if the pipeline finished. False, if the pipeline timed out. """ if not self.is_running(): return True if timeout is not None: timeout *= 60. _thread = self._thread _thread.join(timeout=timeout) if _thread.is_alive(): return False return True def is_running(self): # type: () -> bool """ return True if the pipeline controller is running. :return: A boolean indicating whether the pipeline controller is active (still running) or stopped. """ return self._thread is not None def elapsed(self): # type: () -> float """ Return minutes elapsed from controller stating time stamp. :return: The minutes from controller start time. A negative value means the process has not started yet. """ if self._start_time is None: return -1.0 return (time() - self._start_time) / 60. def get_pipeline_dag(self): # type: () -> Mapping[str, PipelineController.Node] """ Return the pipeline execution graph, each node in the DAG is PipelineController.Node object. Graph itself is a dictionary of Nodes (key based on the Node name), each node holds links to its parent Nodes (identified by their unique names) :return: execution tree, as a nested dictionary Example: { 'stage1' : Node() { name: 'stage1' job: TrainsJob ... }, } """ return self._nodes def get_processed_nodes(self): # type: () -> Sequence[PipelineController.Node] """ Return the a list of the processed pipeline nodes, each entry in the list is PipelineController.Node object. :return: executed (excluding currently executing) nodes list """ return {k: n for k, n in self._nodes.items() if n.executed} def get_running_nodes(self): # type: () -> Sequence[PipelineController.Node] """ Return the a list of the currently running pipeline nodes, each entry in the list is PipelineController.Node object. 

        :return: Currently running nodes, keyed by node name
        """
        return {k: n for k, n in self._nodes.items() if k in self._running_nodes}

    def _serialize(self):
        # type: () -> dict
        """
        Store the definition of the pipeline DAG into a dictionary.
        This dictionary will be used to store the DAG as a configuration on the Task

        :return: DAG as a serializable dictionary, keyed by node name
        """
        dag = {name: dict((k, v) for k, v in node.__dict__.items() if k not in ('job', 'name'))
               for name, node in self._nodes.items()}
        return dag

    def _deserialize(self, dag_dict):
        # type: (dict) -> ()
        """
        Restore the DAG from a dictionary.
        This will be used to create the DAG from the dict stored on the Task, when running remotely.

        :return:
        """
        self._nodes = {k: self.Node(name=k, **v) for k, v in dag_dict.items()}

    def _verify(self):
        # type: () -> bool
        """
        Verify the DAG, (i.e. no cycles and no missing parents)
        On error raise ValueError with verification details

        :return: return True iff DAG has no errors
        """
        # verify nodes
        for node in self._nodes.values():
            # raise value error if not verified
            self._verify_node(node)

        # check the dag itself
        if not self._verify_dag():
            return False

        return True

    def _verify_node(self, node):
        # type: (Node) -> bool
        """
        Raise ValueError on verification errors

        :return: Return True iff the specific node is verified
        """
        if not node.base_task_id:
            raise ValueError("Node '{}', base_task_id is empty".format(node.name))

        if not self._default_execution_queue and not node.queue:
            raise ValueError("Node '{}' missing execution queue, "
                             "no default queue defined and no specific node queue defined".format(node.name))

        task = Task.get_task(task_id=node.base_task_id)
        if not task:
            raise ValueError("Node '{}', base_task_id={} is invalid".format(node.name, node.base_task_id))

        pattern = self._step_ref_pattern

        for v in node.parameters.values():
            for g in pattern.findall(v):
                self.__verify_step_reference(node, g)

        return True

    def _verify_dag(self):
        # type: () -> bool
        """
        :return: True iff the pipeline dag is fully accessible and contains no cycles
        """
        visited = set()
        prev_visited = None
        while prev_visited != visited:
            prev_visited = copy(visited)
            for k, node in self._nodes.items():
                if k in visited:
                    continue
                if not all(p in visited for p in node.parents or []):
                    continue
                visited.add(k)
        # return False if we did not cover all the nodes
        return not bool(set(self._nodes.keys()) - visited)

    def _launch_node(self, node):
        # type: (Node) -> bool
        """
        Launch a single node (create and enqueue a TrainsJob)

        :param node: Node to launch
        :return: Return True if a new job was launched
        """
        if node.job or node.executed:
            return False

        updated_hyper_parameters = {}
        for k, v in node.parameters.items():
            updated_hyper_parameters[k] = self._parse_step_ref(v)
        node.job = TrainsJob(
            base_task_id=node.base_task_id, parameter_override=updated_hyper_parameters,
            parent=self._task.id)
        if self._experiment_created_cb:
            self._experiment_created_cb(node, updated_hyper_parameters)
        node.job.launch(queue_name=node.queue or self._default_execution_queue)
        return True

    def _update_execution_plot(self):
        # type: () -> ()
        """
        Update sankey diagram of the current pipeline
        """
        sankey_node = dict(
            label=[],
            color=[],
            hovertemplate='%{label}',
            # customdata=[],
            # hovertemplate='%{label}<br />Hyper-Parameters:<br />%{customdata}',
        )
        sankey_link = dict(
            source=[],
            target=[],
            value=[],
            hovertemplate='%{target.label}',
        )
        visited = []
        node_params = []
        nodes = list(self._nodes.values())
        while nodes:
            next_nodes = []
            for node in nodes:
                if not all(p in visited for p in node.parents or []):
                    next_nodes.append(node)
                    continue
                visited.append(node.name)
                idx = len(visited) - 1
                parents = [visited.index(p) for p in node.parents or []]
                node_params.append((node.job.task_parameter_override if node.job else node.parameters) or {})
                # sankey_node['label'].append(node.name)
                # sankey_node['customdata'].append(
                #     '<br />'.join('{}: {}'.format(k, v) for k, v in (node.parameters or {}).items()))
                sankey_node['label'].append(
                    '{}<br />'.format(node.name) +
                    '<br />'.join('{}: {}'.format(k, v) for k, v in (node.parameters or {}).items()))
                sankey_node['color'].append(
                    ("blue" if not node.job or not node.job.is_failed() else "red")
                    if node.executed else ("green" if node.job else "lightsteelblue"))

                for p in parents:
                    sankey_link['source'].append(p)
                    sankey_link['target'].append(idx)
                    sankey_link['value'].append(1)

            nodes = next_nodes

        # make sure we have no independent (unconnected) nodes
        for i in [n for n in range(len(visited))
                  if n not in sankey_link['source'] and n not in sankey_link['target']]:
            sankey_link['source'].append(i)
            sankey_link['target'].append(i)
            sankey_link['value'].append(0.1)

        fig = make_subplots(
            rows=2, cols=1,
            shared_xaxes=True,
            vertical_spacing=0.03,
            specs=[[{"type": "table"}],
                   [{"type": "sankey"}], ]
        )
        # noinspection PyUnresolvedReferences
        fig.add_trace(
            go.Sankey(
                node=sankey_node, link=sankey_link, textfont=dict(color='rgba(0,0,0,0)', size=1)),
            row=1, col=1
        )
        # noinspection PyUnresolvedReferences
        fig.add_trace(
            go.Table(
                header=dict(
                    values=["Pipeline Step", "Task ID", "Parameters"],
                    align="left",
                ),
                cells=dict(
                    values=[visited,
                            [self._nodes[v].executed or (self._nodes[v].job.task_id() if self._nodes[v].job else '')
                             for v in visited],
                            [str(p) for p in node_params]],
                    align="left")
            ),
            row=2, col=1
        )

        # fig = go.Figure(data=[go.Sankey(
        #     node=sankey_node, link=sankey_link, textfont=dict(color='rgba(0,0,0,0)', size=1))],)

        self._task.get_logger().report_plotly(
            title='Pipeline', series='execution flow', iteration=0, figure=fig)

    def _force_task_configuration_update(self):
        pipeline_dag = self._serialize()
        # noinspection PyProtectedMember
        self._task._set_configuration(
            name=self._config_section, config_type='dictionary', config_dict=pipeline_dag)

    def _daemon(self):
        # type: () -> ()
        """
        The main pipeline execution loop. This loop is executed on its own dedicated thread.
        :return:
        """
        pooling_counter = 0

        while self._stop_event:
            # stop request
            if pooling_counter and self._stop_event.wait(self._pool_frequency):
                break

            pooling_counter += 1

            # check the pipeline time limit
            if self._pipeline_time_limit and (time() - self._start_time) > self._pipeline_time_limit:
                break

            # check the state of all current jobs
            # if no job ended, continue
            completed_jobs = []
            for j in self._running_nodes:
                node = self._nodes[j]
                if not node.job:
                    continue
                if node.job.is_stopped():
                    completed_jobs.append(j)
                    node.executed = node.job.task_id()
                elif node.timeout:
                    started = node.job.task.data.started
                    if (datetime.now().astimezone(started.tzinfo) - started).total_seconds() > node.timeout:
                        node.job.abort()
                        completed_jobs.append(j)
                        node.executed = node.job.task_id()

            # update running jobs
            self._running_nodes = [j for j in self._running_nodes if j not in completed_jobs]

            # nothing changed, we can sleep
            if not completed_jobs and self._running_nodes:
                continue

            # Pull the next jobs in the pipeline, based on the completed list
            next_nodes = []
            for node in self._nodes.values():
                # check if already processed.
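                # a node is ready to launch only when it was not launched before and every one of
                # its parents is already marked executed; anything else stays pending for a later
                # polling iteration, until nothing is running and nothing new can be launched.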
                if node.job or node.executed:
                    continue
                completed_parents = [bool(p in self._nodes and self._nodes[p].executed) for p in node.parents or []]
                if all(completed_parents):
                    next_nodes.append(node.name)

            # update the execution graph
            for name in next_nodes:
                if self._launch_node(self._nodes[name]):
                    print('Launching step: {}'.format(name))
                    print('Parameters:\n{}'.format(self._nodes[name].job.task_parameter_override))
                    self._running_nodes.append(name)
                else:
                    getLogger('trains.automation.controller').error(
                        'ERROR: Failed launching step \'{}\': {}'.format(name, self._nodes[name]))

            # update current state (in configuration, so that we could later continue an aborted pipeline)
            self._force_task_configuration_update()

            # visualize pipeline state (plot)
            self._update_execution_plot()

            # quit if all pipeline nodes are fully executed.
            if not next_nodes and not self._running_nodes:
                break

        # stop all currently running jobs:
        for node in self._nodes.values():
            if node.job and not node.executed:
                node.job.abort()

        if self._stop_event:
            # noinspection PyBroadException
            try:
                self._stop_event.set()
            except Exception:
                pass

    def __verify_step_reference(self, node, step_ref_string):
        # type: (Node, str) -> bool
        """
        Verify the step reference. For example "${step1.parameters.Args/param}"

        :param Node node: calling reference node (used for logging)
        :param str step_ref_string: For example "${step1.parameters.Args/param}"
        :return: True if valid reference
        """
        parts = step_ref_string[2:-1].split('.')
        v = step_ref_string
        if len(parts) < 2:
            raise ValueError("Node '{}', parameter '{}' is invalid".format(node.name, v))
        prev_step = parts[0]
        input_type = parts[1]
        if prev_step not in self._nodes:
            raise ValueError("Node '{}', parameter '{}', step name '{}' is invalid".format(node.name, v, prev_step))
        if input_type not in ('artifacts', 'parameters', 'models', 'id'):
            raise ValueError(
                "Node '{}', parameter '{}', input type '{}' is invalid".format(node.name, v, input_type))

        if input_type != 'id' and len(parts) < 3:
            raise ValueError("Node '{}', parameter '{}' is invalid".format(node.name, v))

        if input_type == 'models':
            try:
                model_type = parts[2].lower()
            except Exception:
                raise ValueError(
                    "Node '{}', parameter '{}', input type '{}', model_type is missing {}".format(
                        node.name, v, input_type, parts))
            if model_type not in ('input', 'output'):
                raise ValueError(
                    "Node '{}', parameter '{}', input type '{}', "
                    "model_type is invalid (input/output) found {}".format(
                        node.name, v, input_type, model_type))

            if len(parts) < 4:
                raise ValueError(
                    "Node '{}', parameter '{}', input type '{}', model index is missing".format(
                        node.name, v, input_type))

            # check casting
            try:
                int(parts[3])
            except Exception:
                raise ValueError(
                    "Node '{}', parameter '{}', input type '{}', model index is invalid {}".format(
                        node.name, v, input_type, parts))

            if len(parts) < 5:
                raise ValueError(
                    "Node '{}', parameter '{}', input type '{}', model property is missing".format(
                        node.name, v, input_type))

            if not hasattr(BaseModel, parts[4]):
                raise ValueError(
                    "Node '{}', parameter '{}', input type '{}', model property is invalid {}".format(
                        node.name, v, input_type, parts[4]))
        return True

    def __parse_step_reference(self, step_ref_string):
        """
        Return the adjusted value for "${step...}"
        :param step_ref_string: reference string of the form ${step_name.type.value}"
        :return: str with value
        """
        parts = step_ref_string[2:-1].split('.')
        if len(parts) < 2:
            raise ValueError("Could not parse reference '{}'".format(step_ref_string))
        prev_step = parts[0]
        input_type = parts[1].lower()
        # a previously executed step (e.g. a continued pipeline) may have no job, only an executed task id
        if prev_step not in self._nodes or (
                not self._nodes[prev_step].job and not self._nodes[prev_step].executed):
            raise ValueError("Could not parse reference '{}', step {} could not be found".format(
                step_ref_string, prev_step))
        if input_type not in ('artifacts', 'parameters', 'models', 'id'):
            raise ValueError("Could not parse reference '{}', type {} not valid".format(step_ref_string, input_type))
        if input_type != 'id' and len(parts) < 3:
            raise ValueError("Could not parse reference '{}', missing fields in {}".format(step_ref_string, parts))

        task = self._nodes[prev_step].job.task if self._nodes[prev_step].job \
            else Task.get_task(task_id=self._nodes[prev_step].executed)
        task.reload()
        if input_type == 'artifacts':
            # fix \. to use . in artifacts
            artifact_path = ('.'.join(parts[2:])).replace('\\.', '\\_dot_\\')
            artifact_path = artifact_path.split('.')

            obj = task.artifacts
            for p in artifact_path:
                p = p.replace('\\_dot_\\', '.')
                if isinstance(obj, dict):
                    obj = obj.get(p)
                elif hasattr(obj, p):
                    obj = getattr(obj, p)
                else:
                    raise ValueError("Could not locate artifact {} on previous step {}".format(
                        '.'.join(parts[1:]), prev_step))
            return str(obj)
        elif input_type == 'parameters':
            step_params = task.get_parameters()
            param_name = '.'.join(parts[2:])
            if param_name not in step_params:
                raise ValueError("Could not locate parameter {} on previous step {}".format(
                    '.'.join(parts[1:]), prev_step))
            return step_params.get(param_name)
        elif input_type == 'models':
            model_type = parts[2].lower()
            if model_type not in ('input', 'output'):
                raise ValueError("Could not locate model {} on previous step {}".format(
                    '.'.join(parts[1:]), prev_step))
            try:
                model_idx = int(parts[3])
                model = task.models[model_type][model_idx]
            except Exception:
                raise ValueError("Could not locate model {} on previous step {}, index {} is invalid".format(
                    '.'.join(parts[1:]), prev_step, parts[3]))

            return str(getattr(model, parts[4]))
        elif input_type == 'id':
            return task.id
        return None

    def _parse_step_ref(self, value):
        # type: (str) -> Optional[str]
        """
        Return the provided value with all step references resolved.
        For example, a value containing "${step1.parameters.Args/param}" is returned with that
        reference replaced by the actual parameter value from step1.

        :param value: string
        :return: string with every step reference replaced by its resolved value
        """
        # look for all the step references
        pattern = self._step_ref_pattern
        updated_value = value
        for g in pattern.findall(value):
            # update with actual value
            new_val = self.__parse_step_reference(g)
            updated_value = updated_value.replace(g, new_val, 1)
        return updated_value
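

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the library API). The project names,
# task names, queue name and parameter key below are placeholders, and the
# referenced base tasks are assumed to already exist on the trains-server.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    pipe = PipelineController(default_execution_queue='default', always_create_task=True)
    # the first step has no parents, so it is launched immediately
    pipe.add_step(
        name='stage_data',
        base_task_project='examples',            # placeholder project
        base_task_name='dataset preprocessing')  # placeholder base task
    # the second step starts only after 'stage_data' completed; its input references
    # an artifact produced by that step using the '${step_name.*}' syntax
    pipe.add_step(
        name='stage_train',
        parents=['stage_data'],
        base_task_project='examples',
        base_task_name='train model',
        parameter_override={'Args/dataset_url': '${stage_data.artifacts.dataset.url}'})
    pipe.start()
    pipe.wait()
    pipe.stop()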