diff --git a/trains/automation/controller.py b/trains/automation/controller.py index f2255cc3..3a0aafb9 100644 --- a/trains/automation/controller.py +++ b/trains/automation/controller.py @@ -485,20 +485,19 @@ class PipelineController(object): # make sure we have no independent (unconnected) nodes single_nodes = [] for i in [n for n in range(len(visited)) if n not in sankey_link['source'] and n not in sankey_link['target']]: - # sankey_link['source'].append(i) - # sankey_link['target'].append(i) - # sankey_link['value'].append(0.1) single_nodes.append(i) + # create the sankey graph + dag_flow = go.Sankey( + node=sankey_node, link=sankey_link, textfont=dict(color='rgba(0,0,0,0)', size=1) + ) + + # create the detailed parameter table table_values = [["Pipeline Step", "Task ID", "Parameters"]] table_values += [ [v, self._nodes[v].executed or (self._nodes[v].job.task_id() if self._nodes[v].job else ''), str(n)] for v, n in zip(visited, node_params)] - dag_flow = go.Sankey( - node=sankey_node, link=sankey_link, textfont=dict(color='rgba(0,0,0,0)', size=1) - ) - # hack, show single node sankey if single_nodes: singles_flow = go.Scatter( @@ -516,6 +515,7 @@ class PipelineController(object): if len(single_nodes) == len(sankey_node['label']): fig = go.Figure(singles_flow) else: + # both single nodes and DAG fig = make_subplots( rows=2, cols=1, row_heights=[4, 1], @@ -527,17 +527,20 @@ class PipelineController(object): fig.add_trace(dag_flow, row=1, col=1) fig.add_trace(singles_flow, row=2, col=1) else: + # create the sankey plot fig = go.Figure(dag_flow) + # remove background and axis (for scatter) fig.layout.template.layout.plot_bgcolor = None fig.layout.xaxis.visible = False fig.layout.yaxis.visible = False + + # report DAG self._task.get_logger().report_plotly( title='Pipeline', series='Execution Flow', iteration=0, figure=fig) - - table_plot = create_plotly_table(title='Pipeline Details', series='Execution Details', table_plot=table_values) - self._task.get_logger().report_plotly( - title='Pipeline Details', series='Execution Details', iteration=0, figure=table_plot) + # report detailed table + self._task.get_logger().report_table( + title='Pipeline Details', series='Execution Details', iteration=0, table_plot=table_values) def _force_task_configuration_update(self): pipeline_dag = self._serialize()