mirror of
				https://github.com/clearml/clearml
				synced 2025-06-26 18:16:07 +00:00 
			
		
		
		
	Fix NaN in plotly plots (matplotlib conversion)
Add sdk.metrics.plot_max_num_digits to limit the number of digits in a plot (reduce plot sizes)
This commit is contained in:
		
							parent
							
								
									c30b6d27e8
								
							
						
					
					
						commit
						fae11edf1b
					
				| @ -36,6 +36,9 @@ sdk { | |||||||
|         # X images are stored in the upload destination for each matplotlib plot title. |         # X images are stored in the upload destination for each matplotlib plot title. | ||||||
|         matplotlib_untitled_history_size: 100 |         matplotlib_untitled_history_size: 100 | ||||||
| 
 | 
 | ||||||
|  |         # Limit the number of digits after the dot in plot reporting (reducing plot report size) | ||||||
|  |         # plot_max_num_digits: 5 | ||||||
|  | 
 | ||||||
|         # Settings for generated debug images |         # Settings for generated debug images | ||||||
|         images { |         images { | ||||||
|             format: JPEG |             format: JPEG | ||||||
|  | |||||||
| @ -1,4 +1,5 @@ | |||||||
| import json | import json | ||||||
|  | import math | ||||||
| 
 | 
 | ||||||
| try: | try: | ||||||
|     from collections.abc import Iterable |     from collections.abc import Iterable | ||||||
| @ -17,6 +18,7 @@ from ...utilities.plotly_reporter import create_2d_histogram_plot, create_value_ | |||||||
|     create_image_plot, create_plotly_table |     create_image_plot, create_plotly_table | ||||||
| from ...utilities.py3_interop import AbstractContextManager | from ...utilities.py3_interop import AbstractContextManager | ||||||
| from .events import ScalarEvent, VectorEvent, ImageEvent, PlotEvent, ImageEventNoUpload, UploadEvent, MediaEvent | from .events import ScalarEvent, VectorEvent, ImageEvent, PlotEvent, ImageEventNoUpload, UploadEvent, MediaEvent | ||||||
|  | from ...config import config | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncManagerMixin): | class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncManagerMixin): | ||||||
| @ -66,6 +68,7 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan | |||||||
|         self._storage_uri = value |         self._storage_uri = value | ||||||
| 
 | 
 | ||||||
|     storage_uri = property(None, _set_storage_uri) |     storage_uri = property(None, _set_storage_uri) | ||||||
|  |     max_float_num_digits = config.get('metrics.plot_max_num_digits', None) | ||||||
| 
 | 
 | ||||||
|     @property |     @property | ||||||
|     def flush_threshold(self): |     def flush_threshold(self): | ||||||
| @ -164,7 +167,7 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan | |||||||
|                          iter=iter) |                          iter=iter) | ||||||
|         self._report(ev) |         self._report(ev) | ||||||
| 
 | 
 | ||||||
|     def report_plot(self, title, series, plot, iter): |     def report_plot(self, title, series, plot, iter, round_digits=None): | ||||||
|         """ |         """ | ||||||
|         Report a Plotly chart |         Report a Plotly chart | ||||||
|         :param title: Title (AKA metric) |         :param title: Title (AKA metric) | ||||||
| @ -174,8 +177,18 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan | |||||||
|         :param plot: A JSON describing a plotly chart (see https://help.plot.ly/json-chart-schema/) |         :param plot: A JSON describing a plotly chart (see https://help.plot.ly/json-chart-schema/) | ||||||
|         :type plot: str or dict |         :type plot: str or dict | ||||||
|         :param iter: Iteration number |         :param iter: Iteration number | ||||||
|  |         :param round_digits: number of digits after the dot to leave | ||||||
|         :type value: int |         :type value: int | ||||||
|         """ |         """ | ||||||
|  |         def floatstr(o): | ||||||
|  |             if o != o: | ||||||
|  |                 return 'nan' | ||||||
|  |             elif o == math.inf: | ||||||
|  |                 return 'inf' | ||||||
|  |             elif o == -math.inf: | ||||||
|  |                 return '-inf' | ||||||
|  |             return round(o, ndigits=round_digits) if round_digits is not None else o | ||||||
|  | 
 | ||||||
|         # noinspection PyBroadException |         # noinspection PyBroadException | ||||||
|         try: |         try: | ||||||
|             # Special json encoder for numpy types |             # Special json encoder for numpy types | ||||||
| @ -183,18 +196,36 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan | |||||||
|                 if isinstance(obj, (np.integer, np.int64)): |                 if isinstance(obj, (np.integer, np.int64)): | ||||||
|                     return int(obj) |                     return int(obj) | ||||||
|                 elif isinstance(obj, np.floating): |                 elif isinstance(obj, np.floating): | ||||||
|                     return float(obj) |                     return float(round(obj, ndigits=round_digits) if round_digits is not None else obj) | ||||||
|                 elif isinstance(obj, np.ndarray): |                 elif isinstance(obj, np.ndarray): | ||||||
|                     return obj.tolist() |                     return obj.round(round_digits).tolist() if round_digits is not None else obj.tolist() | ||||||
|  | 
 | ||||||
|         except Exception: |         except Exception: | ||||||
|             default = None |             default = None | ||||||
| 
 | 
 | ||||||
|  |         if round_digits is None: | ||||||
|  |             round_digits = self.max_float_num_digits | ||||||
|  | 
 | ||||||
|  |         if round_digits is False: | ||||||
|  |             round_digits = None | ||||||
|  | 
 | ||||||
|         if isinstance(plot, dict): |         if isinstance(plot, dict): | ||||||
|  |             if 'data' in plot: | ||||||
|  |                 for d in plot['data']: | ||||||
|  |                     if not isinstance(d, dict): | ||||||
|  |                         continue | ||||||
|  |                     for k, v in d.items(): | ||||||
|  |                         if isinstance(v, list): | ||||||
|  |                             d[k] = list(floatstr(s) if isinstance(s, float) else s for s in v) | ||||||
|  |                         elif isinstance(v, tuple): | ||||||
|  |                             d[k] = tuple(floatstr(s) if isinstance(s, float) else s for s in v) | ||||||
|  |                         elif isinstance(v, float): | ||||||
|  |                             d[k] = floatstr(v) | ||||||
|             plot = json.dumps(plot, default=default) |             plot = json.dumps(plot, default=default) | ||||||
|         elif not isinstance(plot, six.string_types): |         elif not isinstance(plot, six.string_types): | ||||||
|             raise ValueError('Plot should be a string or a dict') |             raise ValueError('Plot should be a string or a dict') | ||||||
|         ev = PlotEvent(metric=self._normalize_name(title), variant=self._normalize_name(series), plot_str=plot, |         ev = PlotEvent(metric=self._normalize_name(title), variant=self._normalize_name(series), | ||||||
|                        iter=iter) |                        plot_str=plot, iter=iter) | ||||||
|         self._report(ev) |         self._report(ev) | ||||||
| 
 | 
 | ||||||
|     def report_image(self, title, series, src, iter): |     def report_image(self, title, series, src, iter): | ||||||
| @ -367,6 +398,7 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan | |||||||
|             series=self._normalize_name(series), |             series=self._normalize_name(series), | ||||||
|             plot=table_output, |             plot=table_output, | ||||||
|             iter=iteration, |             iter=iteration, | ||||||
|  |             round_digits=False, | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|     def report_line_plot(self, title, series, iter, xtitle, ytitle, mode='lines', reverse_xaxis=False, |     def report_line_plot(self, title, series, iter, xtitle, ytitle, mode='lines', reverse_xaxis=False, | ||||||
|  | |||||||
| @ -26,6 +26,9 @@ | |||||||
|         # X images are stored in the upload destination for each matplotlib plot title. |         # X images are stored in the upload destination for each matplotlib plot title. | ||||||
|         matplotlib_untitled_history_size: 100 |         matplotlib_untitled_history_size: 100 | ||||||
| 
 | 
 | ||||||
|  |         # Limit the number of digits after the dot in plot reporting (reducing plot report size) | ||||||
|  |         # plot_max_num_digits: 5 | ||||||
|  | 
 | ||||||
|         # Settings for generated debug images |         # Settings for generated debug images | ||||||
|         images { |         images { | ||||||
|             format: JPEG |             format: JPEG | ||||||
|  | |||||||
		Loading…
	
		Reference in New Issue
	
	Block a user
	 allegroai
						allegroai