mirror of
https://github.com/clearml/clearml
synced 2025-04-08 06:34:37 +00:00
Fix NaN in plotly plots (matplotlib conversion)
Add sdk.metrics.plot_max_num_digits to limit the number of digits in a plot (reduce plot sizes)
This commit is contained in:
parent
c30b6d27e8
commit
fae11edf1b
@ -36,6 +36,9 @@ sdk {
|
|||||||
# X images are stored in the upload destination for each matplotlib plot title.
|
# X images are stored in the upload destination for each matplotlib plot title.
|
||||||
matplotlib_untitled_history_size: 100
|
matplotlib_untitled_history_size: 100
|
||||||
|
|
||||||
|
# Limit the number of digits after the dot in plot reporting (reducing plot report size)
|
||||||
|
# plot_max_num_digits: 5
|
||||||
|
|
||||||
# Settings for generated debug images
|
# Settings for generated debug images
|
||||||
images {
|
images {
|
||||||
format: JPEG
|
format: JPEG
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
import json
|
import json
|
||||||
|
import math
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from collections.abc import Iterable
|
from collections.abc import Iterable
|
||||||
@ -17,6 +18,7 @@ from ...utilities.plotly_reporter import create_2d_histogram_plot, create_value_
|
|||||||
create_image_plot, create_plotly_table
|
create_image_plot, create_plotly_table
|
||||||
from ...utilities.py3_interop import AbstractContextManager
|
from ...utilities.py3_interop import AbstractContextManager
|
||||||
from .events import ScalarEvent, VectorEvent, ImageEvent, PlotEvent, ImageEventNoUpload, UploadEvent, MediaEvent
|
from .events import ScalarEvent, VectorEvent, ImageEvent, PlotEvent, ImageEventNoUpload, UploadEvent, MediaEvent
|
||||||
|
from ...config import config
|
||||||
|
|
||||||
|
|
||||||
class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncManagerMixin):
|
class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncManagerMixin):
|
||||||
@ -66,6 +68,7 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
|
|||||||
self._storage_uri = value
|
self._storage_uri = value
|
||||||
|
|
||||||
storage_uri = property(None, _set_storage_uri)
|
storage_uri = property(None, _set_storage_uri)
|
||||||
|
max_float_num_digits = config.get('metrics.plot_max_num_digits', None)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def flush_threshold(self):
|
def flush_threshold(self):
|
||||||
@ -164,7 +167,7 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
|
|||||||
iter=iter)
|
iter=iter)
|
||||||
self._report(ev)
|
self._report(ev)
|
||||||
|
|
||||||
def report_plot(self, title, series, plot, iter):
|
def report_plot(self, title, series, plot, iter, round_digits=None):
|
||||||
"""
|
"""
|
||||||
Report a Plotly chart
|
Report a Plotly chart
|
||||||
:param title: Title (AKA metric)
|
:param title: Title (AKA metric)
|
||||||
@ -174,8 +177,18 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
|
|||||||
:param plot: A JSON describing a plotly chart (see https://help.plot.ly/json-chart-schema/)
|
:param plot: A JSON describing a plotly chart (see https://help.plot.ly/json-chart-schema/)
|
||||||
:type plot: str or dict
|
:type plot: str or dict
|
||||||
:param iter: Iteration number
|
:param iter: Iteration number
|
||||||
|
:param round_digits: number of digits after the dot to leave
|
||||||
:type value: int
|
:type value: int
|
||||||
"""
|
"""
|
||||||
|
def floatstr(o):
|
||||||
|
if o != o:
|
||||||
|
return 'nan'
|
||||||
|
elif o == math.inf:
|
||||||
|
return 'inf'
|
||||||
|
elif o == -math.inf:
|
||||||
|
return '-inf'
|
||||||
|
return round(o, ndigits=round_digits) if round_digits is not None else o
|
||||||
|
|
||||||
# noinspection PyBroadException
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
# Special json encoder for numpy types
|
# Special json encoder for numpy types
|
||||||
@ -183,18 +196,36 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
|
|||||||
if isinstance(obj, (np.integer, np.int64)):
|
if isinstance(obj, (np.integer, np.int64)):
|
||||||
return int(obj)
|
return int(obj)
|
||||||
elif isinstance(obj, np.floating):
|
elif isinstance(obj, np.floating):
|
||||||
return float(obj)
|
return float(round(obj, ndigits=round_digits) if round_digits is not None else obj)
|
||||||
elif isinstance(obj, np.ndarray):
|
elif isinstance(obj, np.ndarray):
|
||||||
return obj.tolist()
|
return obj.round(round_digits).tolist() if round_digits is not None else obj.tolist()
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
default = None
|
default = None
|
||||||
|
|
||||||
|
if round_digits is None:
|
||||||
|
round_digits = self.max_float_num_digits
|
||||||
|
|
||||||
|
if round_digits is False:
|
||||||
|
round_digits = None
|
||||||
|
|
||||||
if isinstance(plot, dict):
|
if isinstance(plot, dict):
|
||||||
|
if 'data' in plot:
|
||||||
|
for d in plot['data']:
|
||||||
|
if not isinstance(d, dict):
|
||||||
|
continue
|
||||||
|
for k, v in d.items():
|
||||||
|
if isinstance(v, list):
|
||||||
|
d[k] = list(floatstr(s) if isinstance(s, float) else s for s in v)
|
||||||
|
elif isinstance(v, tuple):
|
||||||
|
d[k] = tuple(floatstr(s) if isinstance(s, float) else s for s in v)
|
||||||
|
elif isinstance(v, float):
|
||||||
|
d[k] = floatstr(v)
|
||||||
plot = json.dumps(plot, default=default)
|
plot = json.dumps(plot, default=default)
|
||||||
elif not isinstance(plot, six.string_types):
|
elif not isinstance(plot, six.string_types):
|
||||||
raise ValueError('Plot should be a string or a dict')
|
raise ValueError('Plot should be a string or a dict')
|
||||||
ev = PlotEvent(metric=self._normalize_name(title), variant=self._normalize_name(series), plot_str=plot,
|
ev = PlotEvent(metric=self._normalize_name(title), variant=self._normalize_name(series),
|
||||||
iter=iter)
|
plot_str=plot, iter=iter)
|
||||||
self._report(ev)
|
self._report(ev)
|
||||||
|
|
||||||
def report_image(self, title, series, src, iter):
|
def report_image(self, title, series, src, iter):
|
||||||
@ -367,6 +398,7 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
|
|||||||
series=self._normalize_name(series),
|
series=self._normalize_name(series),
|
||||||
plot=table_output,
|
plot=table_output,
|
||||||
iter=iteration,
|
iter=iteration,
|
||||||
|
round_digits=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
def report_line_plot(self, title, series, iter, xtitle, ytitle, mode='lines', reverse_xaxis=False,
|
def report_line_plot(self, title, series, iter, xtitle, ytitle, mode='lines', reverse_xaxis=False,
|
||||||
|
@ -26,6 +26,9 @@
|
|||||||
# X images are stored in the upload destination for each matplotlib plot title.
|
# X images are stored in the upload destination for each matplotlib plot title.
|
||||||
matplotlib_untitled_history_size: 100
|
matplotlib_untitled_history_size: 100
|
||||||
|
|
||||||
|
# Limit the number of digits after the dot in plot reporting (reducing plot report size)
|
||||||
|
# plot_max_num_digits: 5
|
||||||
|
|
||||||
# Settings for generated debug images
|
# Settings for generated debug images
|
||||||
images {
|
images {
|
||||||
format: JPEG
|
format: JPEG
|
||||||
|
Loading…
Reference in New Issue
Block a user