Add support for multiple event writers in the same session

commit 00f873081a
parent 51cc50e239
Author: allegroai
Date:   2019-07-20 22:02:19 +03:00

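For context, a minimal sketch of the situation this commit addresses, assuming TensorFlow 1.x summary writers (the logdirs below are hypothetical): two writers in the same process reporting the same tag, which previously collided in the single shared title/series namespace.

import tensorflow as tf

# two TensorBoard writers in the same session, same tags, different logdirs
train_writer = tf.summary.FileWriter('/tmp/exp/train')
val_writer = tf.summary.FileWriter('/tmp/exp/validation')

for step in range(3):
    summ = tf.Summary(value=[tf.Summary.Value(tag='loss', simple_value=0.1 * step)])
    train_writer.add_summary(summ, step)
    val_writer.add_summary(summ, step)  # after this commit: reported under a logdir-derived series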

@@ -1,4 +1,5 @@
+import base64
 import os
 import sys
 import threading
 from collections import defaultdict
@@ -44,9 +45,17 @@ class EventTrainsWriter(object):
     TF SummaryWriter implementation that converts TensorBoard summaries into
     Trains events and reports the events (metrics) to a Trains task (logger).
     """
-    _add_lock = threading.Lock()
+    _add_lock = threading.RLock()
     _series_name_lookup = {}
+
+    # store all the created tensorboard writers in the system
+    # this allows us to ask whether a certain title/series already exists on some EventWriter,
+    # and if it does, we add the distinguishing part of the logdir to the series name
+    # (so we can differentiate between the two)
+    # key, value: key=hash(title, series), value=EventTrainsWriter._id
+    _title_series_writers_lookup = {}
+    _event_writers_id_to_logdir = {}
+
     @property
     def variants(self):
         return self._variants
@@ -54,8 +63,8 @@ class EventTrainsWriter(object):
     def prepare_report(self):
         return self.variants.copy()

-    @staticmethod
-    def tag_splitter(tag, num_split_parts, split_char='/', join_char='_', default_title='variant'):
+    def tag_splitter(self, tag, num_split_parts, split_char='/', join_char='_', default_title='variant',
+                     logdir_header='series'):
         """
         Split a tf.summary tag line to variant and metric.
         Variant is the first part of the split tag, metric is the second.
@@ -64,15 +73,64 @@ class EventTrainsWriter(object):
         :param str split_char: a character to split the tag on
         :param str join_char: a character to join the splits
         :param str default_title: variant to use in case no variant can be inferred automatically
+        :param str logdir_header: if 'series_last' then series='header: series'; if 'series' then series='series :header';
+            if 'title_last' then title='header title'; if 'title' then title='title header'
         :return: (str, str) variant and metric
         """
         splitted_tag = tag.split(split_char)
         series = join_char.join(splitted_tag[-num_split_parts:])
         title = join_char.join(splitted_tag[:-num_split_parts]) or default_title
+        # check if we already decided that we need to change the title/series
+        graph_id = hash((title, series))
+        if graph_id in self._graph_name_lookup:
+            return self._graph_name_lookup[graph_id]
+
+        # check if someone other than us used this combination
+        with self._add_lock:
+            event_writer_id = self._title_series_writers_lookup.get(graph_id, None)
+            if not event_writer_id:
+                # put us there
+                self._title_series_writers_lookup[graph_id] = self._id
+            elif event_writer_id != self._id:
+                # if there is someone else, change our series name and store us
+                org_series = series
+                org_title = title
+                other_logdir = self._event_writers_id_to_logdir[event_writer_id]
+                split_logdir = self._logdir.split(os.path.sep)
+                unique_logdir = set(split_logdir) - set(other_logdir.split(os.path.sep))
+                header = '/'.join(s for s in split_logdir if s in unique_logdir)
+                if logdir_header == 'series_last':
+                    series = header + ': ' + series
+                elif logdir_header == 'series':
+                    series = series + ' :' + header
+                elif logdir_header == 'title':
+                    title = title + ' ' + header
+                else:  # logdir_header == 'title_last':
+                    title = header + ' ' + title
+                graph_id = hash((title, series))
+                # check if for some reason the new series is already occupied
+                new_event_writer_id = self._title_series_writers_lookup.get(graph_id)
+                if new_event_writer_id is not None and new_event_writer_id != self._id:
+                    # well that's about it, nothing else we can do
+                    if logdir_header == 'series_last':
+                        series = str(self._logdir) + ': ' + org_series
+                    elif logdir_header == 'series':
+                        series = org_series + ' :' + str(self._logdir)
+                    elif logdir_header == 'title':
+                        title = org_title + ' ' + str(self._logdir)
+                    else:  # logdir_header == 'title_last':
+                        title = str(self._logdir) + ' ' + org_title
+                    graph_id = hash((title, series))
+                self._title_series_writers_lookup[graph_id] = self._id
+
+        # store for next time
+        self._graph_name_lookup[graph_id] = (title, series)
+        return title, series
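A toy walk-through of the splitting and renaming logic above, with hypothetical tags and logdirs; this just re-traces the code, it is not part of the commit:

# num_split_parts=1 (scalars): the last token is the series, the rest is the title
tag = 'xent/tower_0/loss'
parts = tag.split('/')
series = '_'.join(parts[-1:])              # -> 'loss'
title = '_'.join(parts[:-1]) or 'Scalars'  # -> 'xent_tower_0'

# on a collision, the tokens unique to this writer's logdir become the header,
# e.g. logdirs 'runs/train' vs 'runs/validation' -> header 'train', and then:
header = 'train'
print(header + ': ' + series)  # logdir_header='series_last' -> 'train: loss'
print(series + ' :' + header)  # logdir_header='series'      -> 'loss :train'
print(title + ' ' + header)    # logdir_header='title'       -> 'xent_tower_0 train'
print(header + ' ' + title)    # logdir_header='title_last'  -> 'train xent_tower_0'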
-    def __init__(self, logger, report_freq=100, image_report_freq=None, histogram_update_freq_multiplier=10,
-                 histogram_granularity=50, max_keep_images=None):
+    def __init__(self, logger, logdir=None, report_freq=100, image_report_freq=None,
+                 histogram_update_freq_multiplier=10, histogram_granularity=50, max_keep_images=None):
         """
         Create a Trains backend compatible with the TensorFlow SummaryToEventTransformer.
         Everything will be serialized directly to the Trains backend, instead of to the standard TF FileWriter
@@ -87,6 +145,9 @@ class EventTrainsWriter(object):
         """
         # We are the events_writer, so that's what we'll pass
         IsTensorboardInit.set_tensorboard_used()
+        self._logdir = logdir or ('unknown %d' % len(self._event_writers_id_to_logdir))
+        self._id = hash(self._logdir)
+        self._event_writers_id_to_logdir[self._id] = self._logdir
         self.max_keep_images = max_keep_images
         self.report_freq = report_freq
         self.image_report_freq = image_report_freq if image_report_freq else report_freq
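The three constructor lines added above amount to registering each writer in a class-level map keyed by a hash of its logdir; a toy sketch of that bookkeeping (hypothetical logdirs):

# each writer id is a hash of its logdir; the map lets tag_splitter fetch the
# *other* writer's logdir and diff its path tokens to build a distinguishing header
event_writers_id_to_logdir = {}
for logdir in ('runs/train', 'runs/validation'):
    event_writers_id_to_logdir[hash(logdir)] = logdir

fallback = 'unknown %d' % len(event_writers_id_to_logdir)  # as above -> 'unknown 2'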
@@ -99,6 +160,7 @@ class EventTrainsWriter(object):
         self._hist_report_cache = {}
         self._hist_x_granularity = 50
         self._max_step = 0
+        self._graph_name_lookup = {}

     def _decode_image(self, img_str, width, height, color_channels):
         # noinspection PyBroadException
@@ -131,7 +193,7 @@ class EventTrainsWriter(object):
         if img_data_np is None:
             return

-        title, series = self.tag_splitter(tag, num_split_parts=3, default_title='Images')
+        title, series = self.tag_splitter(tag, num_split_parts=3, default_title='Images', logdir_header='title')
         if img_data_np.dtype != np.uint8:
             # assume scale 0-1
             img_data_np = (img_data_np * 255).astype(np.uint8)
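The uint8 conversion above assumes float image data is scaled to [0, 1]; a quick self-contained check of the same arithmetic:

import numpy as np

img = np.random.rand(8, 8, 3)         # float image, values in [0, 1]
img8 = (img * 255).astype(np.uint8)   # same conversion as above
assert img8.dtype == np.uint8 and int(img8.max()) <= 255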
@@ -168,7 +230,7 @@ class EventTrainsWriter(object):
         return self._add_image_numpy(tag=tag, step=step, img_data_np=matrix)

     def _add_scalar(self, tag, step, scalar_data):
-        title, series = self.tag_splitter(tag, num_split_parts=1, default_title='Scalars')
+        title, series = self.tag_splitter(tag, num_split_parts=1, default_title='Scalars', logdir_header='series_last')

         # update scalar cache
         num, value = self._scalar_report_cache.get((title, series), (0, 0))
@@ -216,7 +278,8 @@ class EventTrainsWriter(object):
         # Y-axis (rows) is iteration (from 0 to current Step)
         # X-axis averaged bins (conformed sample 'bucketLimit')
         # Z-axis actual value (interpolated 'bucket')
-        title, series = self.tag_splitter(tag, num_split_parts=1, default_title='Histograms')
+        title, series = self.tag_splitter(tag, num_split_parts=1, default_title='Histograms',
+                                          logdir_header='series')

         # get histograms from cache
         hist_list, hist_iters, minmax = self._hist_report_cache.get((title, series), ([], np.array([]), None))
@@ -570,8 +633,12 @@ class PatchSummaryToEventTransformer(object):
         if not hasattr(self, 'trains') or not PatchSummaryToEventTransformer.__main_task:
             return PatchSummaryToEventTransformer._original_add_eventT(self, *args, **kwargs)
         if not self.trains:
+            try:
+                logdir = self.get_logdir()
+            except Exception:
+                logdir = None
             self.trains = EventTrainsWriter(PatchSummaryToEventTransformer.__main_task.get_logger(),
-                                            **PatchSummaryToEventTransformer.defaults_dict)
+                                            logdir=logdir, **PatchSummaryToEventTransformer.defaults_dict)
         # noinspection PyBroadException
         try:
             self.trains.add_event(*args, **kwargs)
@@ -584,8 +651,12 @@ class PatchSummaryToEventTransformer(object):
         if not hasattr(self, 'trains') or not PatchSummaryToEventTransformer.__main_task:
             return PatchSummaryToEventTransformer._original_add_eventX(self, *args, **kwargs)
         if not self.trains:
+            try:
+                logdir = self.get_logdir()
+            except Exception:
+                logdir = None
             self.trains = EventTrainsWriter(PatchSummaryToEventTransformer.__main_task.get_logger(),
-                                            **PatchSummaryToEventTransformer.defaults_dict)
+                                            logdir=logdir, **PatchSummaryToEventTransformer.defaults_dict)
         # noinspection PyBroadException
         try:
             self.trains.add_event(*args, **kwargs)
@@ -617,8 +688,13 @@ class PatchSummaryToEventTransformer(object):
         # patch the events writer field, and add a double Event Logger (Trains and original)
         base_eventwriter = __dict__['event_writer']
+        try:
+            logdir = base_eventwriter.get_logdir()
+        except Exception:
+            logdir = None
         defaults_dict = __dict__.get('_trains_defaults') or PatchSummaryToEventTransformer.defaults_dict
-        trains_event = EventTrainsWriter(PatchSummaryToEventTransformer.__main_task.get_logger(), **defaults_dict)
+        trains_event = EventTrainsWriter(PatchSummaryToEventTransformer.__main_task.get_logger(),
+                                         logdir=logdir, **defaults_dict)

         # order is important, the return value of ProxyEventsWriter is the last object in the list
         __dict__['event_writer'] = ProxyEventsWriter([trains_event, base_eventwriter])
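The ordering comment above matters because callers see only the last writer's return value; a schematic stand-in for ProxyEventsWriter (not the library's actual implementation):

class ProxyEventsWriterSketch(object):
    def __init__(self, events):
        self._events = events

    def add_event(self, *args, **kwargs):
        # fan out to every writer; the last writer's result is what the caller
        # gets back, hence base_eventwriter is placed last in the list above
        ret = None
        for ev in self._events:
            ret = ev.add_event(*args, **kwargs)
        return ret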
@@ -798,12 +874,17 @@ class PatchTensorFlowEager(object):
             getLogger(TrainsFrameworkAdapter).debug(str(ex))

     @staticmethod
-    def _get_event_writer():
+    def _get_event_writer(writer):
         if not PatchTensorFlowEager.__main_task:
             return None
         if PatchTensorFlowEager.__trains_event_writer is None:
+            try:
+                logdir = writer.get_logdir()
+            except Exception:
+                logdir = None
             PatchTensorFlowEager.__trains_event_writer = EventTrainsWriter(
-                logger=PatchTensorFlowEager.__main_task.get_logger(), **PatchTensorFlowEager.defaults_dict)
+                logger=PatchTensorFlowEager.__main_task.get_logger(), logdir=logdir,
+                **PatchTensorFlowEager.defaults_dict)
         return PatchTensorFlowEager.__trains_event_writer

     @staticmethod
@@ -812,7 +893,7 @@ class PatchTensorFlowEager(object):
     @staticmethod
     def _write_scalar_summary(writer, step, tag, value, name=None, **kwargs):
-        event_writer = PatchTensorFlowEager._get_event_writer()
+        event_writer = PatchTensorFlowEager._get_event_writer(writer)
         if event_writer:
             try:
                 event_writer._add_scalar(tag=str(tag), step=int(step.numpy()), scalar_data=value.numpy())
@@ -822,7 +903,7 @@ class PatchTensorFlowEager(object):
     @staticmethod
     def _write_hist_summary(writer, step, tag, values, name, **kwargs):
-        event_writer = PatchTensorFlowEager._get_event_writer()
+        event_writer = PatchTensorFlowEager._get_event_writer(writer)
         if event_writer:
             try:
                 event_writer._add_histogram(tag=str(tag), step=int(step.numpy()), histo_data=values.numpy())
@@ -832,7 +913,7 @@ class PatchTensorFlowEager(object):
     @staticmethod
     def _write_image_summary(writer, step, tag, tensor, bad_color, max_images, name, **kwargs):
-        event_writer = PatchTensorFlowEager._get_event_writer()
+        event_writer = PatchTensorFlowEager._get_event_writer(writer)
         if event_writer:
             try:
                 event_writer._add_image_numpy(tag=str(tag), step=int(step.numpy()), img_data_np=tensor.numpy(),
@@ -1351,93 +1432,3 @@ class PatchTensorflowModelIO(object):
             pass
         return model
-
-
-class PatchPyTorchModelIO(object):
-    __main_task = None
-    __patched = None
-
-    @staticmethod
-    def update_current_task(task, **kwargs):
-        PatchPyTorchModelIO.__main_task = task
-        PatchPyTorchModelIO._patch_model_io()
-        PostImportHookPatching.add_on_import('torch', PatchPyTorchModelIO._patch_model_io)
-
-    @staticmethod
-    def _patch_model_io():
-        if PatchPyTorchModelIO.__patched:
-            return
-        if 'torch' not in sys.modules:
-            return
-        PatchPyTorchModelIO.__patched = True
-        # noinspection PyBroadException
-        try:
-            # hack: make sure torch.__init__ is called
-            import torch
-            torch.save = _patched_call(torch.save, PatchPyTorchModelIO._save)
-            torch.load = _patched_call(torch.load, PatchPyTorchModelIO._load)
-        except ImportError:
-            pass
-        except Exception:
-            pass  # print('Failed patching pytorch')
-
-    @staticmethod
-    def _save(original_fn, obj, f, *args, **kwargs):
-        ret = original_fn(obj, f, *args, **kwargs)
-        if not PatchPyTorchModelIO.__main_task:
-            return ret
-
-        if isinstance(f, six.string_types):
-            filename = f
-        elif hasattr(f, 'name'):
-            filename = f.name
-            # noinspection PyBroadException
-            try:
-                f.flush()
-            except Exception:
-                pass
-        else:
-            filename = None
-
-        # give the model a descriptive name based on the file name
-        # noinspection PyBroadException
-        try:
-            model_name = Path(filename).stem
-        except Exception:
-            model_name = None
-        WeightsFileHandler.create_output_model(obj, filename, Framework.pytorch, PatchPyTorchModelIO.__main_task,
-                                               singlefile=True, model_name=model_name)
-        return ret
-
-    @staticmethod
-    def _load(original_fn, f, *args, **kwargs):
-        if isinstance(f, six.string_types):
-            filename = f
-        elif hasattr(f, 'name'):
-            filename = f.name
-        else:
-            filename = None
-
-        if not PatchPyTorchModelIO.__main_task:
-            return original_fn(f, *args, **kwargs)
-
-        # register input model
-        empty = _Empty()
-        if running_remotely():
-            filename = WeightsFileHandler.restore_weights_file(empty, filename, Framework.pytorch,
-                                                               PatchPyTorchModelIO.__main_task)
-            model = original_fn(filename or f, *args, **kwargs)
-        else:
-            # try to load model before registering, in case we fail
-            model = original_fn(filename or f, *args, **kwargs)
-            WeightsFileHandler.restore_weights_file(empty, filename, Framework.pytorch,
-                                                    PatchPyTorchModelIO.__main_task)
-
-        if empty.trains_in_model:
-            # noinspection PyBroadException
-            try:
-                model.trains_in_model = empty.trains_in_model
-            except Exception:
-                pass
-        return model
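For reference, the _patched_call helper used by the removed class above appears to be a thin interposer that hands the original function to the patch as its first argument (as the _save/_load signatures suggest); a minimal sketch under that assumption, not the library's exact code:

import functools

def patched_call_sketch(original_fn, patched_fn):
    @functools.wraps(original_fn)
    def _inner(*args, **kwargs):
        # the patch receives the original callable first, as _save/_load expect
        return patched_fn(original_fn, *args, **kwargs)
    return _inner

# usage sketch: torch.save = patched_call_sketch(torch.save, PatchPyTorchModelIO._save)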