From adbe02b62bd4699a0b34ac9ec5d164475cd9bdd1 Mon Sep 17 00:00:00 2001
From: allegroai <>
Date: Fri, 25 Oct 2019 15:12:01 +0300
Subject: [PATCH] Optimize scatter 2d plot storage
---
trains/utilities/plotly_reporter.py | 35 ++++++++++++++++-------------
1 file changed, 20 insertions(+), 15 deletions(-)
diff --git a/trains/utilities/plotly_reporter.py b/trains/utilities/plotly_reporter.py
index d483f1aa..e27a61e2 100644
--- a/trains/utilities/plotly_reporter.py
+++ b/trains/utilities/plotly_reporter.py
@@ -61,7 +61,7 @@ class SeriesInfo(object):
)
-def create_line_plot(title, series, xtitle, ytitle, mode='lines', reverse_xaxis=False, comment=None):
+def create_line_plot(title, series, xtitle, ytitle, mode='lines', reverse_xaxis=False, comment=None, MAX_SIZE=None):
plotly_obj = _plotly_scatter_layout_dict(
title=title if not comment else (title + '
' + comment + ''),
xaxis_title=xtitle,
@@ -72,14 +72,14 @@ def create_line_plot(title, series, xtitle, ytitle, mode='lines', reverse_xaxis=
plotly_obj["layout"]["xaxis"]["autorange"] = "reversed"
# check maximum size of data
- _MAX_SIZE = 800000
+ MAX_SIZE = MAX_SIZE or 800000
series_sizes = [s.data.size for s in series]
total_size = sum(series_sizes)
- if total_size > _MAX_SIZE:
+ if total_size > MAX_SIZE:
# we need to downscale
- base_size = _MAX_SIZE / len(series_sizes)
+ base_size = MAX_SIZE / len(series_sizes)
baseused_size = sum([min(s, base_size) for s in series_sizes])
- leftover = _MAX_SIZE - baseused_size
+ leftover = MAX_SIZE - baseused_size
for s in series:
# if we need to down-sample, use low-pass average filter and sampling
if s.data.size >= base_size:
@@ -93,6 +93,7 @@ def create_line_plot(title, series, xtitle, ytitle, mode='lines', reverse_xaxis=
# decide on number of points between mean and max
s_max = np.max(np.abs(s.data), axis=0)
+ s_max = np.maximum(s_max, s_max * 0 + 0.01)
digits = np.maximum(np.array([1, 1]), np.array([6, 6]) - np.floor(np.abs(np.log10(s_max))))
s.data[:, 0] = np.round(s.data[:, 0] * (10 ** digits[0])) / (10 ** digits[0])
s.data[:, 1] = np.round(s.data[:, 1] * (10 ** digits[1])) / (10 ** digits[1])
@@ -128,16 +129,20 @@ def create_2d_scatter_series(np_row_wise, title="Scatter", series_name="Series",
assert np_row_wise.ndim == 2, "Expected a 2D numpy array"
assert np_row_wise.shape[1] == 2, "Expected two columns X/Y e.g. [(x0,y0), (x1,y1) ...]"
- this_scatter_data = {
- "name": series_name,
- "x": np_row_wise[:, 0].tolist(),
- "y": np_row_wise[:, 1].tolist(),
- "mode": mode,
- "text": labels,
- "type": "scatter"
- }
- plotly_obj["data"].append(this_scatter_data)
- return plotly_obj
+ # this_scatter_data = {
+ # "name": series_name,
+ # "x": np_row_wise[:, 0].tolist(),
+ # "y": np_row_wise[:, 1].tolist(),
+ # "mode": mode,
+ # "text": labels,
+ # "type": "scatter"
+ # }
+ # plotly_obj["data"].append(this_scatter_data)
+ # return plotly_obj
+ series = SeriesInfo(name=series_name, data=np_row_wise, labels=labels)
+
+ return create_line_plot(title=title, series=[series], xtitle=xtitle, ytitle=ytitle, mode=mode,
+ comment=comment, MAX_SIZE=100000)
def create_3d_scatter_series(np_row_wise, title="Scatter", series_name="Series", xtitle="x", ytitle="y", ztitle="z",