Report TensorBoard text logging as debug samples (.txt files) instead of as console output.

This commit is contained in:
allegroai 2020-07-04 22:55:29 +03:00
parent 934771184d
commit 2f5b519cd8
4 changed files with 65 additions and 11 deletions

View File

@ -192,7 +192,8 @@ class UploadEvent(MetricsEventAdapter):
file_history_size=None, delete_after_upload=False, **kwargs):
# param override_filename: override uploaded file name (notice extension will be added from local path
# param override_filename_ext: override uploaded file extension
if image_data is not None and (not hasattr(image_data, 'shape') and not isinstance(image_data, six.BytesIO)):
if image_data is not None and (
not hasattr(image_data, 'shape') and not isinstance(image_data, (six.StringIO, six.BytesIO))):
raise ValueError('Image must have a shape attribute')
self._image_data = image_data
self._local_image_path = local_image_path
@ -263,7 +264,7 @@ class UploadEvent(MetricsEventAdapter):
last_count = self._get_metric_count(self.metric, self.variant, next=False)
if abs(self._count - last_count) > self._file_history_size:
output = None
elif isinstance(self._image_data, six.BytesIO):
elif isinstance(self._image_data, (six.StringIO, six.BytesIO)):
output = self._image_data
elif self._image_data is not None:
image_data = self._image_data

View File

@ -277,7 +277,7 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
:type iter: int
:param path: A path to an image file. Required unless matrix is provided.
:type path: str
:param stream: File stream
:param stream: File/String stream
:param file_extension: file extension to use when stream is passed
:param max_history: maximum number of files to store per metric/variant combination
use negative value for unlimited. default is set in global configuration (default=5)
@ -288,6 +288,9 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
raise ValueError('Upload configuration is required (use setup_upload())')
if len([x for x in (path, stream) if x is not None]) != 1:
raise ValueError('Expected only one of [filename, stream]')
if isinstance(stream, six.string_types):
stream = six.StringIO(stream)
kwargs = dict(metric=self._normalize_name(title), variant=self._normalize_name(series), iter=iter,
file_history_size=max_history)
ev = MediaEvent(stream=stream, upload_uri=upload_uri, local_image_path=path,

View File

@ -51,6 +51,7 @@ class IsTensorboardInit(object):
return original_init(self, *args, **kwargs)
# noinspection PyProtectedMember
class WeightsGradientHistHelper(object):
def __init__(self, logger, report_freq=100, histogram_update_freq_multiplier=10, histogram_granularity=50):
self._logger = logger
@ -138,6 +139,7 @@ class WeightsGradientHistHelper(object):
if minmax is None:
minmax = hist[:, 0].min(), hist[:, 0].max()
else:
# noinspection PyUnresolvedReferences
minmax = min(minmax[0], hist[:, 0].min()), max(minmax[1], hist[:, 0].max())
# update the cache
@ -185,6 +187,7 @@ class WeightsGradientHistHelper(object):
camera=(-0.1, +1.3, 1.4))
# noinspection PyMethodMayBeStatic,PyProtectedMember,SpellCheckingInspection
class EventTrainsWriter(object):
"""
TF SummaryWriter implementation that converts the tensorboard's summary into
@ -347,6 +350,7 @@ class EventTrainsWriter(object):
image = np.asarray(im)
output.close()
if height > 0 and width > 0:
# noinspection PyArgumentList
val = image.reshape(height, width, -1).astype(np.uint8)
else:
val = image.astype(np.uint8)
@ -369,6 +373,7 @@ class EventTrainsWriter(object):
return val
def _add_image_numpy(self, tag, step, img_data_np, max_keep_images=None):
# type: (str, int, np.ndarray, int) -> ()
# only report images every specific interval
if step % self.image_report_freq != 0:
return None
@ -390,6 +395,7 @@ class EventTrainsWriter(object):
if img_data_np.ndim == 4:
dims = img_data_np.shape
stack_dim = int(np.sqrt(dims[0]))
# noinspection PyArgumentList
res = img_data_np.reshape(stack_dim, stack_dim, *dims[1:]).transpose((0, 2, 1, 3, 4))
tile_size = res.shape[0] * res.shape[1]
img_data_np = res.reshape(tile_size, tile_size, -1)
@ -555,6 +561,23 @@ class EventTrainsWriter(object):
max_history=self.max_keep_images,
)
def _add_text(self, tag, step, tensor_bytes):
    """
    Report a TensorBoard text summary as a '.txt' media sample.

    :param tag: TensorBoard tag; split into a title and a series by ``tag_splitter``
        (the title falls back to 'Text').
    :param step: Summary step; adjusted by ``_fix_step_counter`` before it is reported.
    :param tensor_bytes: Text payload. ``bytes`` are decoded as UTF-8 (undecodable bytes are
        replaced), text strings are used as is, and array/sequence payloads (for example the
        ``numpy()`` value of a non-scalar string tensor) are flattened and joined with newlines.
    """
    def _to_text(value):
        # Plain bytes is the common case and keeps the original decoding behaviour.
        if isinstance(value, (bytes, bytearray)):
            return value.decode('utf-8', errors='replace')
        if isinstance(value, six.string_types):
            return value
        # Array-like payloads have no decode(); convert them to plain python objects first.
        if hasattr(value, 'tolist'):
            return _to_text(value.tolist())
        if isinstance(value, (list, tuple)):
            return '\n'.join(_to_text(item) for item in value)
        return str(value)

    # noinspection PyProtectedMember
    title, series = self.tag_splitter(tag, num_split_parts=3, default_title='Text', logdir_header='title',
                                      auto_reduce_num_split=True,
                                      force_add_prefix=self._logger._get_tensorboard_series_prefix())
    step = self._fix_step_counter(title, series, step)

    self._logger.report_media(
        title=title,
        series=series,
        iteration=step,
        stream=six.StringIO(_to_text(tensor_bytes)),
        file_extension='.txt',
        max_history=self.max_keep_images,
    )
@staticmethod
def _fix_step_counter(title, series, step):
key = (title, series)
@ -573,7 +596,7 @@ class EventTrainsWriter(object):
wraparound_counter['last_step'] = step
return step + wraparound_counter['adjust_counter']
def add_event(self, event, step=None, walltime=None, **kwargs):
def add_event(self, event, step=None, walltime=None, **_):
supported_metrics = {
'simpleValue', 'image', 'histo', 'tensor', 'audio'
}
@ -603,6 +626,7 @@ class EventTrainsWriter(object):
'event summary not found, message type unsupported: %s' % keys_list)
return
value_dicts = summary.get('value')
# noinspection PyUnusedLocal
walltime = walltime or msg_dict.get('step')
step = step or msg_dict.get('step')
if step is None:
@ -646,9 +670,8 @@ class EventTrainsWriter(object):
self._generic_tensor_type_name_lookup[tag] = plugin_type
self._add_audio(tag, step, None, tensor_bytes)
elif plugin_type == 'text':
# text, just print to console
text = tensor_bytes.decode('utf-8', errors='replace')
self._logger.report_text(msg='SUMMARY LOG: {} {}'.format(tag, text), print_console=False)
self._generic_tensor_type_name_lookup[tag] = plugin_type
self._add_text(tag, step, tensor_bytes)
else:
# we do not support it
pass
@ -700,6 +723,7 @@ class EventTrainsWriter(object):
# ~/torch/utils/tensorboard/summary.py
def _clean_tag(name):
import re as _re
# noinspection RegExpRedundantEscape
_INVALID_TAG_CHARACTERS = _re.compile(r'[^-/\w\.]')
if name is not None:
new_name = _INVALID_TAG_CHARACTERS.sub('_', name)
@ -711,6 +735,7 @@ class EventTrainsWriter(object):
return name
main_path = self._logdir
# noinspection PyBroadException
try:
main_path = _clean_tag(main_path)
origin_tag = main_path.rpartition("/")[2].replace(title_prefix, "", 1)
@ -723,6 +748,7 @@ class EventTrainsWriter(object):
return origin_tag
# noinspection PyCallingNonCallable
class ProxyEventsWriter(object):
def __init__(self, events):
IsTensorboardInit.set_tensorboard_used()
@ -771,6 +797,7 @@ class ProxyEventsWriter(object):
return ret
# noinspection PyPep8Naming
class PatchSummaryToEventTransformer(object):
__main_task = None
__original_getattribute = None
@ -785,6 +812,7 @@ class PatchSummaryToEventTransformer(object):
@staticmethod
def trains_object(self):
if isinstance(self.event_writer, ProxyEventsWriter):
# noinspection PyProtectedMember
trains_writer = [e for e in self.event_writer._events if isinstance(e, EventTrainsWriter)]
return trains_writer[0] if trains_writer else None
elif isinstance(self.event_writer, EventTrainsWriter):
@ -824,6 +852,7 @@ class PatchSummaryToEventTransformer(object):
try:
# only patch once
if PatchSummaryToEventTransformer._original_add_eventT is None:
# noinspection PyUnresolvedReferences
from torch.utils.tensorboard.writer import FileWriter as FileWriterT
PatchSummaryToEventTransformer._original_add_eventT = FileWriterT.add_event
FileWriterT.add_event = PatchSummaryToEventTransformer._patched_add_eventT
@ -838,6 +867,7 @@ class PatchSummaryToEventTransformer(object):
try:
# only patch once
if PatchSummaryToEventTransformer.__original_getattributeX is None:
# noinspection PyUnresolvedReferences
from tensorboardX.writer import SummaryToEventTransformer as SummaryToEventTransformerX
PatchSummaryToEventTransformer.__original_getattributeX = \
SummaryToEventTransformerX.__getattribute__
@ -869,6 +899,7 @@ class PatchSummaryToEventTransformer(object):
if not hasattr(self, 'trains') or not PatchSummaryToEventTransformer.__main_task:
return PatchSummaryToEventTransformer._original_add_eventT(self, *args, **kwargs)
if not self.trains:
# noinspection PyBroadException
try:
logdir = self.get_logdir()
except Exception:
@ -887,6 +918,7 @@ class PatchSummaryToEventTransformer(object):
if not hasattr(self, 'trains') or not PatchSummaryToEventTransformer.__main_task:
return PatchSummaryToEventTransformer._original_add_eventX(self, *args, **kwargs)
if not self.trains:
# noinspection PyBroadException
try:
logdir = self.get_logdir()
except Exception:
@ -924,6 +956,7 @@ class PatchSummaryToEventTransformer(object):
# patch the events writer field, and add a double Event Logger (Trains and original)
base_eventwriter = __dict__['event_writer']
# noinspection PyBroadException
try:
logdir = base_eventwriter.get_logdir()
except Exception:
@ -984,6 +1017,7 @@ class PatchModelCheckPointCallback(object):
@staticmethod
def trains_object(self):
if isinstance(self.model, _ModelAdapter):
# noinspection PyProtectedMember
return self.model._output_model
if not self.__dict__.get('_trains_defaults'):
self.__dict__['_trains_defaults'] = {}
@ -1067,6 +1101,7 @@ class PatchModelCheckPointCallback(object):
return get_base(self, attr)
# noinspection PyProtectedMember,PyUnresolvedReferences
class PatchTensorFlowEager(object):
__main_task = None
__original_fn_scalar = None
@ -1171,6 +1206,11 @@ class PatchTensorFlowEager(object):
tag=str(tag), step=int(step.numpy()) if not isinstance(step, int) else step,
hist_data=tensor.numpy()
)
elif plugin_type.endswith('text'):
event_writer._add_text(
tag=str(tag), step=int(step.numpy()) if not isinstance(step, int) else step,
tensor_bytes=tensor.numpy()
)
elif 'audio' in plugin_type:
audio_bytes_list = [a for a in tensor.numpy().flatten() if a]
for i, audio_bytes in enumerate(audio_bytes_list):
@ -1241,13 +1281,14 @@ class PatchTensorFlowEager(object):
max_keep_images=kwargs.get('max_images'))
# noinspection PyPep8Naming,SpellCheckingInspection
class PatchKerasModelIO(object):
__main_task = None
__patched_keras = None
__patched_tensorflow = None
@staticmethod
def update_current_task(task, **kwargs):
def update_current_task(task, **_):
PatchKerasModelIO.__main_task = task
PatchKerasModelIO._patch_model_checkpoint()
PostImportHookPatching.add_on_import('tensorflow', PatchKerasModelIO._patch_model_checkpoint)
@ -1314,6 +1355,7 @@ class PatchKerasModelIO(object):
Sequential._updated_config = _patched_call(Sequential._updated_config,
PatchKerasModelIO._updated_config)
if hasattr(Sequential.from_config, '__func__'):
# noinspection PyUnresolvedReferences
Sequential.from_config = classmethod(_patched_call(Sequential.from_config.__func__,
PatchKerasModelIO._from_config))
else:
@ -1322,6 +1364,7 @@ class PatchKerasModelIO(object):
if Network is not None:
Network._updated_config = _patched_call(Network._updated_config, PatchKerasModelIO._updated_config)
if hasattr(Sequential.from_config, '__func__'):
# noinspection PyUnresolvedReferences
Network.from_config = classmethod(_patched_call(Network.from_config.__func__,
PatchKerasModelIO._from_config))
else:
@ -1539,7 +1582,7 @@ class PatchTensorflowModelIO(object):
__patched = None
@staticmethod
def update_current_task(task, **kwargs):
def update_current_task(task, **_):
PatchTensorflowModelIO.__main_task = task
PatchTensorflowModelIO._patch_model_checkpoint()
PostImportHookPatching.add_on_import('tensorflow', PatchTensorflowModelIO._patch_model_checkpoint)
@ -1557,6 +1600,7 @@ class PatchTensorflowModelIO(object):
try:
# hack: make sure tensorflow.__init__ is called
import tensorflow
# noinspection PyUnresolvedReferences
from tensorflow.python.training.saver import Saver
# noinspection PyBroadException
try:
@ -1605,6 +1649,7 @@ class PatchTensorflowModelIO(object):
import tensorflow # noqa: F811
# actual import
from tensorflow.saved_model import load # noqa: F401
# noinspection PyUnresolvedReferences
import tensorflow.saved_model as saved_model_load
saved_model_load.load = _patched_call(saved_model_load.load, PatchTensorflowModelIO._load)
except ImportError:
@ -1617,6 +1662,7 @@ class PatchTensorflowModelIO(object):
# make sure we import the correct version of save
import tensorflow # noqa: F811
# actual import
# noinspection PyUnresolvedReferences
from tensorflow.saved_model import loader as loader1
loader1.load = _patched_call(loader1.load, PatchTensorflowModelIO._load)
except ImportError:
@ -1629,6 +1675,7 @@ class PatchTensorflowModelIO(object):
# make sure we import the correct version of save
import tensorflow # noqa: F811
# actual import
# noinspection PyUnresolvedReferences
from tensorflow.compat.v1.saved_model import loader as loader2
loader2.load = _patched_call(loader2.load, PatchTensorflowModelIO._load)
except ImportError:
@ -1772,7 +1819,7 @@ class PatchTensorflow2ModelIO(object):
__patched = None
@staticmethod
def update_current_task(task, **kwargs):
def update_current_task(task, **_):
PatchTensorflow2ModelIO.__main_task = task
PatchTensorflow2ModelIO._patch_model_checkpoint()
PostImportHookPatching.add_on_import('tensorflow', PatchTensorflow2ModelIO._patch_model_checkpoint)
@ -1812,6 +1859,7 @@ class PatchTensorflow2ModelIO(object):
def _save(original_fn, self, file_prefix, *args, **kwargs):
model = original_fn(self, file_prefix, *args, **kwargs)
# store output Model
# noinspection PyBroadException
try:
WeightsFileHandler.create_output_model(self, file_prefix, Framework.tensorflow,
PatchTensorflow2ModelIO.__main_task)
@ -1827,6 +1875,7 @@ class PatchTensorflow2ModelIO(object):
# Hack: disabled
if False and running_remotely():
# register/load model weights
# noinspection PyBroadException
try:
save_path = WeightsFileHandler.restore_weights_file(self, save_path, Framework.tensorflow,
PatchTensorflow2ModelIO.__main_task)
@ -1838,6 +1887,7 @@ class PatchTensorflow2ModelIO(object):
# load model, if something is wrong, exception will be raised before we register the input model
model = original_fn(self, save_path, *args, **kwargs)
# register/load model weights
# noinspection PyBroadException
try:
WeightsFileHandler.restore_weights_file(self, save_path, Framework.tensorflow,
PatchTensorflow2ModelIO.__main_task)

View File

@ -828,7 +828,7 @@ class Logger(object):
series, # type: str
iteration, # type: int
local_path=None, # type: Optional[str]
stream=None, # type: Optional[six.BytesIO]
stream=None, # type: Optional[Union[six.BytesIO, six.StringIO]]
file_extension=None, # type: Optional[str]
max_history=None, # type: Optional[int]
delete_after_upload=False, # type: bool