Fix NaN in plotly plots (matplotlib conversion)

Add sdk.metrics.plot_max_num_digits to limit the number of digits in a plot (reduce plot sizes)
This commit is contained in:
allegroai 2020-09-05 16:28:24 +03:00
parent c30b6d27e8
commit fae11edf1b
3 changed files with 43 additions and 5 deletions

View File

@ -36,6 +36,9 @@ sdk {
# X images are stored in the upload destination for each matplotlib plot title. # X images are stored in the upload destination for each matplotlib plot title.
matplotlib_untitled_history_size: 100 matplotlib_untitled_history_size: 100
# Limit the number of digits after the dot in plot reporting (reducing plot report size)
# plot_max_num_digits: 5
# Settings for generated debug images # Settings for generated debug images
images { images {
format: JPEG format: JPEG

View File

@ -1,4 +1,5 @@
import json import json
import math
try: try:
from collections.abc import Iterable from collections.abc import Iterable
@ -17,6 +18,7 @@ from ...utilities.plotly_reporter import create_2d_histogram_plot, create_value_
create_image_plot, create_plotly_table create_image_plot, create_plotly_table
from ...utilities.py3_interop import AbstractContextManager from ...utilities.py3_interop import AbstractContextManager
from .events import ScalarEvent, VectorEvent, ImageEvent, PlotEvent, ImageEventNoUpload, UploadEvent, MediaEvent from .events import ScalarEvent, VectorEvent, ImageEvent, PlotEvent, ImageEventNoUpload, UploadEvent, MediaEvent
from ...config import config
class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncManagerMixin): class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncManagerMixin):
@ -66,6 +68,7 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
self._storage_uri = value self._storage_uri = value
storage_uri = property(None, _set_storage_uri) storage_uri = property(None, _set_storage_uri)
max_float_num_digits = config.get('metrics.plot_max_num_digits', None)
@property @property
def flush_threshold(self): def flush_threshold(self):
@ -164,7 +167,7 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
iter=iter) iter=iter)
self._report(ev) self._report(ev)
def report_plot(self, title, series, plot, iter): def report_plot(self, title, series, plot, iter, round_digits=None):
""" """
Report a Plotly chart Report a Plotly chart
:param title: Title (AKA metric) :param title: Title (AKA metric)
@ -174,8 +177,18 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
:param plot: A JSON describing a plotly chart (see https://help.plot.ly/json-chart-schema/) :param plot: A JSON describing a plotly chart (see https://help.plot.ly/json-chart-schema/)
:type plot: str or dict :type plot: str or dict
:param iter: Iteration number :param iter: Iteration number
:param round_digits: number of digits after the dot to leave
:type value: int :type value: int
""" """
def floatstr(o):
if o != o:
return 'nan'
elif o == math.inf:
return 'inf'
elif o == -math.inf:
return '-inf'
return round(o, ndigits=round_digits) if round_digits is not None else o
# noinspection PyBroadException # noinspection PyBroadException
try: try:
# Special json encoder for numpy types # Special json encoder for numpy types
@ -183,18 +196,36 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
if isinstance(obj, (np.integer, np.int64)): if isinstance(obj, (np.integer, np.int64)):
return int(obj) return int(obj)
elif isinstance(obj, np.floating): elif isinstance(obj, np.floating):
return float(obj) return float(round(obj, ndigits=round_digits) if round_digits is not None else obj)
elif isinstance(obj, np.ndarray): elif isinstance(obj, np.ndarray):
return obj.tolist() return obj.round(round_digits).tolist() if round_digits is not None else obj.tolist()
except Exception: except Exception:
default = None default = None
if round_digits is None:
round_digits = self.max_float_num_digits
if round_digits is False:
round_digits = None
if isinstance(plot, dict): if isinstance(plot, dict):
if 'data' in plot:
for d in plot['data']:
if not isinstance(d, dict):
continue
for k, v in d.items():
if isinstance(v, list):
d[k] = list(floatstr(s) if isinstance(s, float) else s for s in v)
elif isinstance(v, tuple):
d[k] = tuple(floatstr(s) if isinstance(s, float) else s for s in v)
elif isinstance(v, float):
d[k] = floatstr(v)
plot = json.dumps(plot, default=default) plot = json.dumps(plot, default=default)
elif not isinstance(plot, six.string_types): elif not isinstance(plot, six.string_types):
raise ValueError('Plot should be a string or a dict') raise ValueError('Plot should be a string or a dict')
ev = PlotEvent(metric=self._normalize_name(title), variant=self._normalize_name(series), plot_str=plot, ev = PlotEvent(metric=self._normalize_name(title), variant=self._normalize_name(series),
iter=iter) plot_str=plot, iter=iter)
self._report(ev) self._report(ev)
def report_image(self, title, series, src, iter): def report_image(self, title, series, src, iter):
@ -367,6 +398,7 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
series=self._normalize_name(series), series=self._normalize_name(series),
plot=table_output, plot=table_output,
iter=iteration, iter=iteration,
round_digits=False,
) )
def report_line_plot(self, title, series, iter, xtitle, ytitle, mode='lines', reverse_xaxis=False, def report_line_plot(self, title, series, iter, xtitle, ytitle, mode='lines', reverse_xaxis=False,

View File

@ -26,6 +26,9 @@
# X images are stored in the upload destination for each matplotlib plot title. # X images are stored in the upload destination for each matplotlib plot title.
matplotlib_untitled_history_size: 100 matplotlib_untitled_history_size: 100
# Limit the number of digits after the dot in plot reporting (reducing plot report size)
# plot_max_num_digits: 5
# Settings for generated debug images # Settings for generated debug images
images { images {
format: JPEG format: JPEG