From 17f9fa512fefccf29853b9e5008ee1d4b78953b9 Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Wed, 25 Nov 2020 10:50:57 +0200 Subject: [PATCH] Remove plotly usage from controller --- trains/automation/controller.py | 63 +++++++++++++++------------------ 1 file changed, 28 insertions(+), 35 deletions(-) diff --git a/trains/automation/controller.py b/trains/automation/controller.py index 3a0aafb9..0625b2b2 100644 --- a/trains/automation/controller.py +++ b/trains/automation/controller.py @@ -5,16 +5,12 @@ from logging import getLogger from threading import Thread, Event from time import time -from plotly import graph_objects as go -from plotly.subplots import make_subplots - from attr import attrib, attrs from typing import Sequence, Optional, Mapping, Callable, Any, Union from ..task import Task from ..automation import TrainsJob from ..model import BaseModel -from ..utilities.plotly_reporter import create_plotly_table class PipelineController(object): @@ -488,8 +484,12 @@ class PipelineController(object): single_nodes.append(i) # create the sankey graph - dag_flow = go.Sankey( - node=sankey_node, link=sankey_link, textfont=dict(color='rgba(0,0,0,0)', size=1) + dag_flow = dict( + link=sankey_link, + node=sankey_node, + textfont=dict(color='rgba(0,0,0,0)', size=1), + type='sankey', + orientation='h' ) # create the detailed parameter table @@ -500,40 +500,33 @@ class PipelineController(object): # hack, show single node sankey if single_nodes: - singles_flow = go.Scatter( - x=list(range(len(single_nodes))), y=[1]*len(single_nodes), - text=[v for i, v in enumerate(sankey_node['label']) if i in single_nodes], - mode='markers', - hovertemplate="%{text}", - marker=dict( - color=[v for i, v in enumerate(sankey_node['color']) if i in single_nodes], - size=[40]*len(single_nodes), - ), - showlegend=False, - ) + singles_flow = dict( + x=list(range(len(single_nodes))), y=[1] * len(single_nodes), + text=[v for i, v in enumerate(sankey_node['label']) if i in single_nodes], + mode='markers', + hovertemplate="%{text}", + marker=dict( + color=[v for i, v in enumerate(sankey_node['color']) if i in single_nodes], + size=[40] * len(single_nodes), + ), + showlegend=False, + type='scatter', + ) # only single nodes if len(single_nodes) == len(sankey_node['label']): - fig = go.Figure(singles_flow) + fig = dict(data=[singles_flow], layout={ + 'hovermode': 'closest', 'xaxis': {'visible': False}, 'yaxis': {'visible': False}}) else: - # both single nodes and DAG - fig = make_subplots( - rows=2, cols=1, - row_heights=[4, 1], - shared_xaxes=False, - vertical_spacing=0.03, - specs=[[{"type": "sankey"}], - [{"type": "xy"}]] - ) - fig.add_trace(dag_flow, row=1, col=1) - fig.add_trace(singles_flow, row=2, col=1) + dag_flow['domain'] = {'x': [0.0, 1.0], 'y': [0.2, 1.0]} + fig = dict(data=[dag_flow, singles_flow], + layout={'autosize': True, + 'hovermode': 'closest', + 'xaxis': {'anchor': 'y', 'domain': [0.0, 1.0], 'visible': False}, + 'yaxis': {'anchor': 'x', 'domain': [0.0, 0.15], 'visible': False} + }) else: # create the sankey plot - fig = go.Figure(dag_flow) - - # remove background and axis (for scatter) - fig.layout.template.layout.plot_bgcolor = None - fig.layout.xaxis.visible = False - fig.layout.yaxis.visible = False + fig = dict(data=[dag_flow], layout={'xaxis': {'visible': False}, 'yaxis': {'visible': False}}) # report DAG self._task.get_logger().report_plotly(