Cleanup pipeline DAG plot generation

This commit is contained in:
allegroai 2020-11-01 16:01:47 +02:00
parent b8ebbaa3b8
commit 2847cce18d

View File

@ -485,20 +485,19 @@ class PipelineController(object):
# make sure we have no independent (unconnected) nodes
single_nodes = []
for i in [n for n in range(len(visited)) if n not in sankey_link['source'] and n not in sankey_link['target']]:
# sankey_link['source'].append(i)
# sankey_link['target'].append(i)
# sankey_link['value'].append(0.1)
single_nodes.append(i)
# create the sankey graph
dag_flow = go.Sankey(
node=sankey_node, link=sankey_link, textfont=dict(color='rgba(0,0,0,0)', size=1)
)
# create the detailed parameter table
table_values = [["Pipeline Step", "Task ID", "Parameters"]]
table_values += [
[v, self._nodes[v].executed or (self._nodes[v].job.task_id() if self._nodes[v].job else ''), str(n)]
for v, n in zip(visited, node_params)]
dag_flow = go.Sankey(
node=sankey_node, link=sankey_link, textfont=dict(color='rgba(0,0,0,0)', size=1)
)
# hack, show single node sankey
if single_nodes:
singles_flow = go.Scatter(
@ -516,6 +515,7 @@ class PipelineController(object):
if len(single_nodes) == len(sankey_node['label']):
fig = go.Figure(singles_flow)
else:
# both single nodes and DAG
fig = make_subplots(
rows=2, cols=1,
row_heights=[4, 1],
@ -527,17 +527,20 @@ class PipelineController(object):
fig.add_trace(dag_flow, row=1, col=1)
fig.add_trace(singles_flow, row=2, col=1)
else:
# create the sankey plot
fig = go.Figure(dag_flow)
# remove background and axis (for scatter)
fig.layout.template.layout.plot_bgcolor = None
fig.layout.xaxis.visible = False
fig.layout.yaxis.visible = False
# report DAG
self._task.get_logger().report_plotly(
title='Pipeline', series='Execution Flow', iteration=0, figure=fig)
table_plot = create_plotly_table(title='Pipeline Details', series='Execution Details', table_plot=table_values)
self._task.get_logger().report_plotly(
title='Pipeline Details', series='Execution Details', iteration=0, figure=table_plot)
# report detailed table
self._task.get_logger().report_table(
title='Pipeline Details', series='Execution Details', iteration=0, table_plot=table_values)
def _force_task_configuration_update(self):
pipeline_dag = self._serialize()