Add support for splitting the DAG and the Table in the pipeline DAG plot. Single (unconnected) pipeline DAG nodes are now drawn as round circles below the DAG graph. Limit hover text size (long parameter values are truncated to 24 characters).

allegroai 2020-11-01 16:00:51 +02:00
parent 665bfc5ca8
commit 236e6adcfd
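
Below is a minimal, self-contained sketch (not the commit's code) of the plotting technique the diff moves to: the pipeline DAG is drawn as a plotly Sankey trace, any unconnected ("single") steps are drawn as round scatter markers in a smaller second subplot row, and the step table is reported separately. The step names, colors and the sample edge list are hypothetical.

# Illustrative sketch only -- not the commit diff.
import plotly.graph_objects as go
from plotly.subplots import make_subplots

labels = ['stage_data', 'stage_process', 'stage_standalone']   # hypothetical pipeline steps
colors = ['blue', 'green', 'lightsteelblue']
link = dict(source=[0], target=[1], value=[1], hovertemplate='<extra></extra>')  # one edge: 0 -> 1

# steps that appear in no edge are "single" (unconnected) nodes
connected = set(link['source']) | set(link['target'])
singles = [i for i in range(len(labels)) if i not in connected]

dag_flow = go.Sankey(
    node=dict(label=labels, color=colors, hovertemplate='%{label}<extra></extra>'),
    link=link,
)

if singles:
    # draw the unconnected steps as round markers in a smaller second row
    singles_flow = go.Scatter(
        x=list(range(len(singles))), y=[1] * len(singles),
        text=[labels[i] for i in singles],
        mode='markers',
        hovertemplate='%{text}<extra></extra>',
        marker=dict(color=[colors[i] for i in singles], size=[40] * len(singles)),
        showlegend=False,
    )
    fig = make_subplots(
        rows=2, cols=1, row_heights=[4, 1], vertical_spacing=0.03,
        specs=[[{'type': 'sankey'}], [{'type': 'xy'}]],
    )
    fig.add_trace(dag_flow, row=1, col=1)
    fig.add_trace(singles_flow, row=2, col=1)
    fig.layout.xaxis.visible = False
    fig.layout.yaxis.visible = False
else:
    fig = go.Figure(dag_flow)

fig.show()

The diff below applies the same split inside PipelineController, reporting the step table through create_plotly_table() as a separate 'Pipeline Details' plot.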


@@ -11,9 +11,10 @@ from plotly.subplots import make_subplots
 from attr import attrib, attrs
 from typing import Sequence, Optional, Mapping, Callable, Any, Union
-from trains import Task
-from trains.automation import TrainsJob
-from trains.model import BaseModel
+from ..task import Task
+from ..automation import TrainsJob
+from ..model import BaseModel
+from ..utilities.plotly_reporter import create_plotly_table
 class PipelineController(object):
@@ -447,7 +448,8 @@ class PipelineController(object):
             source=[],
             target=[],
             value=[],
-            hovertemplate='%{target.label}<extra></extra>',
+            # hovertemplate='%{target.label}<extra></extra>',
+            hovertemplate='<extra></extra>',
         )
         visited = []
         node_params = []
@@ -467,7 +469,8 @@ class PipelineController(object):
                 # '<br />'.join('{}: {}'.format(k, v) for k, v in (node.parameters or {}).items()))
                 sankey_node['label'].append(
                     '{}<br />'.format(node.name) +
-                    '<br />'.join('{}: {}'.format(k, v) for k, v in (node.parameters or {}).items()))
+                    '<br />'.join('{}: {}'.format(k, v if len(str(v)) < 24 else (str(v)[:24]+' ...'))
+                                  for k, v in (node.parameters or {}).items()))
                 sankey_node['color'].append(
                     ("blue" if not node.job or not node.job.is_failed() else "red")
                     if node.executed is not None else ("green" if node.job else "lightsteelblue"))
@@ -480,46 +483,61 @@ class PipelineController(object):
             nodes = next_nodes
         # make sure we have no independent (unconnected) nodes
+        single_nodes = []
         for i in [n for n in range(len(visited)) if n not in sankey_link['source'] and n not in sankey_link['target']]:
-            sankey_link['source'].append(i)
-            sankey_link['target'].append(i)
-            sankey_link['value'].append(0.1)
+            # sankey_link['source'].append(i)
+            # sankey_link['target'].append(i)
+            # sankey_link['value'].append(0.1)
+            single_nodes.append(i)
-        fig = make_subplots(
-            rows=2, cols=1,
-            shared_xaxes=True,
-            vertical_spacing=0.03,
-            specs=[[{"type": "table"}],
-                   [{"type": "sankey"}], ]
-        )
-        # noinspection PyUnresolvedReferences
-        fig.add_trace(
-            go.Sankey(
-                node=sankey_node, link=sankey_link, textfont=dict(color='rgba(0,0,0,0)', size=1)
-            ),
-            row=1, col=1
-        )
-        # noinspection PyUnresolvedReferences
-        fig.add_trace(
-            go.Table(
-                header=dict(
-                    values=["Pipeline Step", "Task ID", "Parameters"],
-                    align="left",
-                ),
-                cells=dict(
-                    values=[visited,
-                            [self._nodes[v].executed or (self._nodes[v].job.task_id() if self._nodes[v].job else '')
-                             for v in visited],
-                            [str(p) for p in node_params]],
-                    align="left")
-            ),
-            row=2, col=1
+        table_values = [["Pipeline Step", "Task ID", "Parameters"]]
+        table_values += [
+            [v, self._nodes[v].executed or (self._nodes[v].job.task_id() if self._nodes[v].job else ''), str(n)]
+            for v, n in zip(visited, node_params)]
+        dag_flow = go.Sankey(
+            node=sankey_node, link=sankey_link, textfont=dict(color='rgba(0,0,0,0)', size=1)
         )
+        # fig = go.Figure(data=[go.Sankey(
+        #     node=sankey_node, link=sankey_link, textfont=dict(color='rgba(0,0,0,0)', size=1))],)
+        # hack, show single node sankey
+        if single_nodes:
+            singles_flow = go.Scatter(
+                x=list(range(len(single_nodes))), y=[1]*len(single_nodes),
+                text=[v for i, v in enumerate(sankey_node['label']) if i in single_nodes],
+                mode='markers',
+                hovertemplate="%{text}<extra></extra>",
+                marker=dict(
+                    color=[v for i, v in enumerate(sankey_node['color']) if i in single_nodes],
+                    size=[40]*len(single_nodes),
+                ),
+                showlegend=False,
+            )
+            # only single nodes
+            if len(single_nodes) == len(sankey_node['label']):
+                fig = go.Figure(singles_flow)
+            else:
+                fig = make_subplots(
+                    rows=2, cols=1,
+                    row_heights=[4, 1],
+                    shared_xaxes=False,
+                    vertical_spacing=0.03,
+                    specs=[[{"type": "sankey"}],
+                           [{"type": "xy"}]]
+                )
+                fig.add_trace(dag_flow, row=1, col=1)
+                fig.add_trace(singles_flow, row=2, col=1)
+        else:
+            fig = go.Figure(dag_flow)
         fig.layout.template.layout.plot_bgcolor = None
         fig.layout.xaxis.visible = False
         fig.layout.yaxis.visible = False
         self._task.get_logger().report_plotly(
-            title='Pipeline', series='execution flow', iteration=0, figure=fig)
+            title='Pipeline', series='Execution Flow', iteration=0, figure=fig)
+        table_plot = create_plotly_table(title='Pipeline Details', series='Execution Details', table_plot=table_values)
+        self._task.get_logger().report_plotly(
+            title='Pipeline Details', series='Execution Details', iteration=0, figure=table_plot)
     def _force_task_configuration_update(self):
         pipeline_dag = self._serialize()
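
For context, a hypothetical usage sketch (not part of the commit) that would exercise the new plot: two connected steps plus one independent step, which after this change is rendered as a round marker below the DAG, while the step table goes to the separate 'Pipeline Details' plot. Project, queue and step names are placeholders, and the calls assume the trains PipelineController example API of this era.

from trains import Task
from trains.automation.controller import PipelineController

# hypothetical pipeline: two connected steps plus one independent step
task = Task.init(project_name='examples', task_name='pipeline demo')
pipe = PipelineController(default_execution_queue='default', add_pipeline_tags=False)
pipe.add_step(name='stage_data', base_task_project='examples', base_task_name='step 1 dataset')
pipe.add_step(name='stage_process', parents=['stage_data'],
              base_task_project='examples', base_task_name='step 2 process')
pipe.add_step(name='stage_standalone', base_task_project='examples', base_task_name='step 3 standalone')
pipe.start()   # launches the pipeline; the controller keeps updating the execution plot
pipe.wait()
pipe.stop()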