From fae11edf1b2d736871f08e2bd53776a9a731a1b7 Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Sat, 5 Sep 2020 16:28:24 +0300 Subject: [PATCH] Fix NaN in plotly plots (matplotlib conversion) Add sdk.metrics.plot_max_num_digits to limit the number of digits in a plot (reduce plot sizes) --- docs/trains.conf | 3 ++ trains/backend_interface/metrics/reporter.py | 42 +++++++++++++++++--- trains/config/default/sdk.conf | 3 ++ 3 files changed, 43 insertions(+), 5 deletions(-) diff --git a/docs/trains.conf b/docs/trains.conf index 67ceb04e..ba83a3ae 100644 --- a/docs/trains.conf +++ b/docs/trains.conf @@ -36,6 +36,9 @@ sdk { # X images are stored in the upload destination for each matplotlib plot title. matplotlib_untitled_history_size: 100 + # Limit the number of digits after the dot in plot reporting (reducing plot report size) + # plot_max_num_digits: 5 + # Settings for generated debug images images { format: JPEG diff --git a/trains/backend_interface/metrics/reporter.py b/trains/backend_interface/metrics/reporter.py index 09774c4c..9b3d26ce 100644 --- a/trains/backend_interface/metrics/reporter.py +++ b/trains/backend_interface/metrics/reporter.py @@ -1,4 +1,5 @@ import json +import math try: from collections.abc import Iterable @@ -17,6 +18,7 @@ from ...utilities.plotly_reporter import create_2d_histogram_plot, create_value_ create_image_plot, create_plotly_table from ...utilities.py3_interop import AbstractContextManager from .events import ScalarEvent, VectorEvent, ImageEvent, PlotEvent, ImageEventNoUpload, UploadEvent, MediaEvent +from ...config import config class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncManagerMixin): @@ -66,6 +68,7 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan self._storage_uri = value storage_uri = property(None, _set_storage_uri) + max_float_num_digits = config.get('metrics.plot_max_num_digits', None) @property def flush_threshold(self): @@ -164,7 +167,7 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan iter=iter) self._report(ev) - def report_plot(self, title, series, plot, iter): + def report_plot(self, title, series, plot, iter, round_digits=None): """ Report a Plotly chart :param title: Title (AKA metric) @@ -174,8 +177,18 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan :param plot: A JSON describing a plotly chart (see https://help.plot.ly/json-chart-schema/) :type plot: str or dict :param iter: Iteration number + :param round_digits: number of digits after the dot to leave :type value: int """ + def floatstr(o): + if o != o: + return 'nan' + elif o == math.inf: + return 'inf' + elif o == -math.inf: + return '-inf' + return round(o, ndigits=round_digits) if round_digits is not None else o + # noinspection PyBroadException try: # Special json encoder for numpy types @@ -183,18 +196,36 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan if isinstance(obj, (np.integer, np.int64)): return int(obj) elif isinstance(obj, np.floating): - return float(obj) + return float(round(obj, ndigits=round_digits) if round_digits is not None else obj) elif isinstance(obj, np.ndarray): - return obj.tolist() + return obj.round(round_digits).tolist() if round_digits is not None else obj.tolist() + except Exception: default = None + if round_digits is None: + round_digits = self.max_float_num_digits + + if round_digits is False: + round_digits = None + if isinstance(plot, dict): + if 'data' in plot: + for d in plot['data']: + if not isinstance(d, dict): + continue + for k, v in d.items(): + if isinstance(v, list): + d[k] = list(floatstr(s) if isinstance(s, float) else s for s in v) + elif isinstance(v, tuple): + d[k] = tuple(floatstr(s) if isinstance(s, float) else s for s in v) + elif isinstance(v, float): + d[k] = floatstr(v) plot = json.dumps(plot, default=default) elif not isinstance(plot, six.string_types): raise ValueError('Plot should be a string or a dict') - ev = PlotEvent(metric=self._normalize_name(title), variant=self._normalize_name(series), plot_str=plot, - iter=iter) + ev = PlotEvent(metric=self._normalize_name(title), variant=self._normalize_name(series), + plot_str=plot, iter=iter) self._report(ev) def report_image(self, title, series, src, iter): @@ -367,6 +398,7 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan series=self._normalize_name(series), plot=table_output, iter=iteration, + round_digits=False, ) def report_line_plot(self, title, series, iter, xtitle, ytitle, mode='lines', reverse_xaxis=False, diff --git a/trains/config/default/sdk.conf b/trains/config/default/sdk.conf index 17255551..a47adaaf 100644 --- a/trains/config/default/sdk.conf +++ b/trains/config/default/sdk.conf @@ -26,6 +26,9 @@ # X images are stored in the upload destination for each matplotlib plot title. matplotlib_untitled_history_size: 100 + # Limit the number of digits after the dot in plot reporting (reducing plot report size) + # plot_max_num_digits: 5 + # Settings for generated debug images images { format: JPEG