Remove plotly usage from controller

This commit is contained in:
allegroai 2020-11-25 10:50:57 +02:00
parent 6064080232
commit 17f9fa512f

View File

@ -5,16 +5,12 @@ from logging import getLogger
from threading import Thread, Event from threading import Thread, Event
from time import time from time import time
from plotly import graph_objects as go
from plotly.subplots import make_subplots
from attr import attrib, attrs from attr import attrib, attrs
from typing import Sequence, Optional, Mapping, Callable, Any, Union from typing import Sequence, Optional, Mapping, Callable, Any, Union
from ..task import Task from ..task import Task
from ..automation import TrainsJob from ..automation import TrainsJob
from ..model import BaseModel from ..model import BaseModel
from ..utilities.plotly_reporter import create_plotly_table
class PipelineController(object): class PipelineController(object):
@ -488,8 +484,12 @@ class PipelineController(object):
single_nodes.append(i) single_nodes.append(i)
# create the sankey graph # create the sankey graph
dag_flow = go.Sankey( dag_flow = dict(
node=sankey_node, link=sankey_link, textfont=dict(color='rgba(0,0,0,0)', size=1) link=sankey_link,
node=sankey_node,
textfont=dict(color='rgba(0,0,0,0)', size=1),
type='sankey',
orientation='h'
) )
# create the detailed parameter table # create the detailed parameter table
@ -500,7 +500,7 @@ class PipelineController(object):
# hack, show single node sankey # hack, show single node sankey
if single_nodes: if single_nodes:
singles_flow = go.Scatter( singles_flow = dict(
x=list(range(len(single_nodes))), y=[1] * len(single_nodes), x=list(range(len(single_nodes))), y=[1] * len(single_nodes),
text=[v for i, v in enumerate(sankey_node['label']) if i in single_nodes], text=[v for i, v in enumerate(sankey_node['label']) if i in single_nodes],
mode='markers', mode='markers',
@ -510,30 +510,23 @@ class PipelineController(object):
size=[40] * len(single_nodes), size=[40] * len(single_nodes),
), ),
showlegend=False, showlegend=False,
type='scatter',
) )
# only single nodes # only single nodes
if len(single_nodes) == len(sankey_node['label']): if len(single_nodes) == len(sankey_node['label']):
fig = go.Figure(singles_flow) fig = dict(data=[singles_flow], layout={
'hovermode': 'closest', 'xaxis': {'visible': False}, 'yaxis': {'visible': False}})
else: else:
# both single nodes and DAG dag_flow['domain'] = {'x': [0.0, 1.0], 'y': [0.2, 1.0]}
fig = make_subplots( fig = dict(data=[dag_flow, singles_flow],
rows=2, cols=1, layout={'autosize': True,
row_heights=[4, 1], 'hovermode': 'closest',
shared_xaxes=False, 'xaxis': {'anchor': 'y', 'domain': [0.0, 1.0], 'visible': False},
vertical_spacing=0.03, 'yaxis': {'anchor': 'x', 'domain': [0.0, 0.15], 'visible': False}
specs=[[{"type": "sankey"}], })
[{"type": "xy"}]]
)
fig.add_trace(dag_flow, row=1, col=1)
fig.add_trace(singles_flow, row=2, col=1)
else: else:
# create the sankey plot # create the sankey plot
fig = go.Figure(dag_flow) fig = dict(data=[dag_flow], layout={'xaxis': {'visible': False}, 'yaxis': {'visible': False}})
# remove background and axis (for scatter)
fig.layout.template.layout.plot_bgcolor = None
fig.layout.xaxis.visible = False
fig.layout.yaxis.visible = False
# report DAG # report DAG
self._task.get_logger().report_plotly( self._task.get_logger().report_plotly(