Fix TensorFlow NaN/Inf values support

This commit is contained in:
allegroai 2020-03-20 10:26:43 +02:00
parent babaf9f1ce
commit b4050ecf25

View File

@ -287,7 +287,16 @@ class EventTrainsWriter(object):
series = possible_tag or series
# update scalar cache
num, value = self._scalar_report_cache.get((title, series), (0, 0))
self._scalar_report_cache[(title, series)] = (num + 1, value + scalar_data)
# nan outputs is a string, it's probably a NaN
if isinstance(scalar_data, six.string_types):
try:
scalar_data = float(scalar_data)
except:
scalar_data = float('nan')
# nan outputs nan
self._scalar_report_cache[(title, series)] = \
(num + 1,
(value + scalar_data) if scalar_data == scalar_data else scalar_data)
# only report images every specific interval
if step % self.report_freq != 0: