Add support for multiple event writers in the same session

allegroai 2019-07-20 22:02:19 +03:00
parent 51cc50e239
commit 00f873081a


@@ -1,4 +1,5 @@
 import base64
+import os
 import sys
 import threading
 from collections import defaultdict
@@ -44,9 +45,17 @@ class EventTrainsWriter(object):
     TF SummaryWriter implementation that converts the tensorboard's summary into
     Trains events and reports the events (metrics) for a Trains task (logger).
     """
-    _add_lock = threading.Lock()
+    _add_lock = threading.RLock()
     _series_name_lookup = {}
+
+    # store all the created tensorboard writers in the system
+    # this allows us to check whether a certain title/series already exists on some EventWriter
+    # and, if it does, append the last token of the logdir to the series name
+    # (so we can differentiate between the two)
+    # key=hash((title, series)), value=EventTrainsWriter._id
+    _title_series_writers_lookup = {}
+    _event_writers_id_to_logdir = {}
+
     @property
     def variants(self):
         return self._variants
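The class-level registry above is what lets several TensorBoard writers coexist in one session without clobbering each other's plots. A minimal sketch of the idea, with illustrative names that are not part of this commit:

    # stand-ins for _title_series_writers_lookup / _event_writers_id_to_logdir
    registry = {}
    id_to_logdir = {}

    def register(writer_id, logdir, title, series):
        id_to_logdir[writer_id] = logdir
        key = hash((title, series))
        owner = registry.setdefault(key, writer_id)
        if owner != writer_id:
            # a different writer already owns this name: append the unique
            # part of our logdir to the series so both stay visible
            series = series + ' :' + logdir.split('/')[-1]
            registry[hash((title, series))] = writer_id
        return title, series

    register(1, 'runs/train', 'loss', 'total')   # -> ('loss', 'total')
    register(2, 'runs/val', 'loss', 'total')     # -> ('loss', 'total :val')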
@@ -54,8 +63,8 @@ class EventTrainsWriter(object):
     def prepare_report(self):
         return self.variants.copy()

-    @staticmethod
-    def tag_splitter(tag, num_split_parts, split_char='/', join_char='_', default_title='variant'):
+    def tag_splitter(self, tag, num_split_parts, split_char='/', join_char='_', default_title='variant',
+                     logdir_header='series'):
         """
         Split a tf.summary tag line to variant and metric.
         Variant is the first part of the split tag, metric is the second.
@@ -64,15 +73,64 @@
         :param str split_char: a character to split the tag on
         :param str join_char: a character to join the splits
         :param str default_title: variant to use in case no variant can be inferred automatically
+        :param str logdir_header: if 'series_last' then series='header: series'; if 'series' then
+            series='series :header'; if 'title_last' then title='header title'; if 'title' then title='title header'
         :return: (str, str) variant and metric
         """
         splitted_tag = tag.split(split_char)
         series = join_char.join(splitted_tag[-num_split_parts:])
         title = join_char.join(splitted_tag[:-num_split_parts]) or default_title
+
+        # check if we already decided that we need to change the title/series
+        graph_id = hash((title, series))
+        if graph_id in self._graph_name_lookup:
+            return self._graph_name_lookup[graph_id]
+
+        # check if someone other than us already used this combination
+        with self._add_lock:
+            event_writer_id = self._title_series_writers_lookup.get(graph_id, None)
+            if not event_writer_id:
+                # put us there
+                self._title_series_writers_lookup[graph_id] = self._id
+            elif event_writer_id != self._id:
+                # someone else is already using this title/series, so rename ours and store it
+                org_series = series
+                org_title = title
+                other_logdir = self._event_writers_id_to_logdir[event_writer_id]
+                split_logdir = self._logdir.split(os.path.sep)
+                unique_logdir = set(split_logdir) - set(other_logdir.split(os.path.sep))
+                header = '/'.join(s for s in split_logdir if s in unique_logdir)
+                if logdir_header == 'series_last':
+                    series = header + ': ' + series
+                elif logdir_header == 'series':
+                    series = series + ' :' + header
+                elif logdir_header == 'title':
+                    title = title + ' ' + header
+                else:  # logdir_header == 'title_last'
+                    title = header + ' ' + title
+                graph_id = hash((title, series))
+                # check if for some reason the new title/series is already occupied
+                new_event_writer_id = self._title_series_writers_lookup.get(graph_id)
+                if new_event_writer_id is not None and new_event_writer_id != self._id:
+                    # last resort: use the full logdir as the differentiating header
+                    if logdir_header == 'series_last':
+                        series = str(self._logdir) + ': ' + org_series
+                    elif logdir_header == 'series':
+                        series = org_series + ' :' + str(self._logdir)
+                    elif logdir_header == 'title':
+                        title = org_title + ' ' + str(self._logdir)
+                    else:  # logdir_header == 'title_last'
+                        title = str(self._logdir) + ' ' + org_title
+                    graph_id = hash((title, series))
+                self._title_series_writers_lookup[graph_id] = self._id
+
+        # store for next time
+        self._graph_name_lookup[graph_id] = (title, series)
         return title, series
-    def __init__(self, logger, report_freq=100, image_report_freq=None, histogram_update_freq_multiplier=10,
-                 histogram_granularity=50, max_keep_images=None):
+    def __init__(self, logger, logdir=None, report_freq=100, image_report_freq=None,
+                 histogram_update_freq_multiplier=10, histogram_granularity=50, max_keep_images=None):
         """
         Create a compatible Trains backend to the TensorFlow SummaryToEventTransformer
         Everything will be serialized directly to the Trains backend, instead of to the standard TF FileWriter
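The tag_splitter above still performs the same basic split; the signature changes are losing @staticmethod (it now needs per-instance state) and gaining the logdir_header hint, while __init__ gains the logdir the renaming is based on. A worked example of the plain split, assuming a typical scalar tag:

    tag = 'train/epoch/loss'
    parts = tag.split('/')            # ['train', 'epoch', 'loss']
    series = '_'.join(parts[-1:])     # 'loss'   (num_split_parts=1)
    title = '_'.join(parts[:-1])      # 'train_epoch'
    # with num_split_parts=3 and a short tag, parts[:-3] is empty and the
    # title falls back to default_title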
@@ -87,6 +145,9 @@ class EventTrainsWriter(object):
         """
         # We are the events_writer, so that's what we'll pass
         IsTensorboardInit.set_tensorboard_used()
+        self._logdir = logdir or ('unknown %d' % len(self._event_writers_id_to_logdir))
+        self._id = hash(self._logdir)
+        self._event_writers_id_to_logdir[self._id] = self._logdir
         self.max_keep_images = max_keep_images
         self.report_freq = report_freq
         self.image_report_freq = image_report_freq if image_report_freq else report_freq
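Because the writer id is simply the hash of its logdir, re-creating a writer over the same directory maps back to the same registry entry, so only writers with genuinely different logdirs ever trigger the renaming path. A small illustration:

    writer_a = hash('runs/train')
    writer_b = hash('runs/train')   # same logdir -> same id, treated as one writer
    writer_c = hash('runs/val')     # different logdir -> different id
                                    # (barring an unlikely hash collision)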
@@ -99,6 +160,7 @@ class EventTrainsWriter(object):
         self._hist_report_cache = {}
         self._hist_x_granularity = 50
         self._max_step = 0
+        self._graph_name_lookup = {}

     def _decode_image(self, img_str, width, height, color_channels):
         # noinspection PyBroadException
@@ -131,7 +193,7 @@
         if img_data_np is None:
             return

-        title, series = self.tag_splitter(tag, num_split_parts=3, default_title='Images')
+        title, series = self.tag_splitter(tag, num_split_parts=3, default_title='Images', logdir_header='title')
         if img_data_np.dtype != np.uint8:
             # assume scale 0-1
             img_data_np = (img_data_np * 255).astype(np.uint8)
@@ -168,7 +230,7 @@
         return self._add_image_numpy(tag=tag, step=step, img_data_np=matrix)

     def _add_scalar(self, tag, step, scalar_data):
-        title, series = self.tag_splitter(tag, num_split_parts=1, default_title='Scalars')
+        title, series = self.tag_splitter(tag, num_split_parts=1, default_title='Scalars', logdir_header='series_last')

         # update scalar cache
         num, value = self._scalar_report_cache.get((title, series), (0, 0))
@@ -216,7 +278,8 @@
         # Y-axis (rows) is iteration (from 0 to current Step)
         # X-axis averaged bins (conformed sample 'bucketLimit')
         # Z-axis actual value (interpolated 'bucket')
-        title, series = self.tag_splitter(tag, num_split_parts=1, default_title='Histograms')
+        title, series = self.tag_splitter(tag, num_split_parts=1, default_title='Histograms',
+                                          logdir_header='series')

         # get histograms from cache
         hist_list, hist_iters, minmax = self._hist_report_cache.get((title, series), ([], np.array([]), None))
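Taken together, the three call sites pick different placements for the differentiating header: images move it into the title, scalars prepend it to the series, and histograms append it. Assuming the resolved header is 'val', the four modes produce:

    header, title, series = 'val', 'loss', 'total'
    placements = {
        'series_last': (title, header + ': ' + series),  # scalars    -> ('loss', 'val: total')
        'series':      (title, series + ' :' + header),  # histograms -> ('loss', 'total :val')
        'title':       (title + ' ' + header, series),   # images     -> ('loss val', 'total')
        'title_last':  (header + ' ' + title, series),
    }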
@@ -570,8 +633,12 @@ class PatchSummaryToEventTransformer(object):
         if not hasattr(self, 'trains') or not PatchSummaryToEventTransformer.__main_task:
             return PatchSummaryToEventTransformer._original_add_eventT(self, *args, **kwargs)
         if not self.trains:
+            try:
+                logdir = self.get_logdir()
+            except Exception:
+                logdir = None
             self.trains = EventTrainsWriter(PatchSummaryToEventTransformer.__main_task.get_logger(),
-                                            **PatchSummaryToEventTransformer.defaults_dict)
+                                            logdir=logdir, **PatchSummaryToEventTransformer.defaults_dict)
         # noinspection PyBroadException
         try:
             self.trains.add_event(*args, **kwargs)
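The writer is created lazily on the first event, and the logdir probe is deliberately defensive: not every patched object is guaranteed to expose get_logdir(). The pattern in isolation, as a sketch:

    def safe_logdir(writer_like):
        # ask the wrapped object for its logdir; fall back to None when the
        # method is missing or raises, leaving EventTrainsWriter to invent
        # an 'unknown N' placeholder
        try:
            return writer_like.get_logdir()
        except Exception:
            return None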
@@ -584,8 +651,12 @@ class PatchSummaryToEventTransformer(object):
         if not hasattr(self, 'trains') or not PatchSummaryToEventTransformer.__main_task:
             return PatchSummaryToEventTransformer._original_add_eventX(self, *args, **kwargs)
         if not self.trains:
+            try:
+                logdir = self.get_logdir()
+            except Exception:
+                logdir = None
             self.trains = EventTrainsWriter(PatchSummaryToEventTransformer.__main_task.get_logger(),
-                                            **PatchSummaryToEventTransformer.defaults_dict)
+                                            logdir=logdir, **PatchSummaryToEventTransformer.defaults_dict)
         # noinspection PyBroadException
         try:
             self.trains.add_event(*args, **kwargs)
@@ -617,8 +688,13 @@ class PatchSummaryToEventTransformer(object):
         # patch the events writer field, and add a double Event Logger (Trains and original)
         base_eventwriter = __dict__['event_writer']
+        try:
+            logdir = base_eventwriter.get_logdir()
+        except Exception:
+            logdir = None
         defaults_dict = __dict__.get('_trains_defaults') or PatchSummaryToEventTransformer.defaults_dict
-        trains_event = EventTrainsWriter(PatchSummaryToEventTransformer.__main_task.get_logger(), **defaults_dict)
+        trains_event = EventTrainsWriter(PatchSummaryToEventTransformer.__main_task.get_logger(),
+                                         logdir=logdir, **defaults_dict)

         # order is important, the return value of ProxyEventsWriter is the last object in the list
         __dict__['event_writer'] = ProxyEventsWriter([trains_event, base_eventwriter])
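The ordering comment matters because the proxy presumably forwards each call to every wrapped writer and hands back the last one's return value, keeping the original writer authoritative. A sketch of that fan-out behavior under that assumption (FanOutWriter is an illustrative name, not the actual ProxyEventsWriter implementation):

    class FanOutWriter(object):
        def __init__(self, writers):
            self._writers = writers

        def add_event(self, *args, **kwargs):
            ret = None
            for w in self._writers:
                ret = w.add_event(*args, **kwargs)  # last writer wins
            return ret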
@@ -798,12 +874,17 @@ class PatchTensorFlowEager(object):
             getLogger(TrainsFrameworkAdapter).debug(str(ex))

     @staticmethod
-    def _get_event_writer():
+    def _get_event_writer(writer):
         if not PatchTensorFlowEager.__main_task:
             return None
         if PatchTensorFlowEager.__trains_event_writer is None:
+            try:
+                logdir = writer.get_logdir()
+            except Exception:
+                logdir = None
             PatchTensorFlowEager.__trains_event_writer = EventTrainsWriter(
-                logger=PatchTensorFlowEager.__main_task.get_logger(), **PatchTensorFlowEager.defaults_dict)
+                logger=PatchTensorFlowEager.__main_task.get_logger(), logdir=logdir,
+                **PatchTensorFlowEager.defaults_dict)
         return PatchTensorFlowEager.__trains_event_writer

     @staticmethod
@@ -812,7 +893,7 @@ class PatchTensorFlowEager(object):
     @staticmethod
     def _write_scalar_summary(writer, step, tag, value, name=None, **kwargs):
-        event_writer = PatchTensorFlowEager._get_event_writer()
+        event_writer = PatchTensorFlowEager._get_event_writer(writer)
         if event_writer:
             try:
                 event_writer._add_scalar(tag=str(tag), step=int(step.numpy()), scalar_data=value.numpy())
@@ -822,7 +903,7 @@ class PatchTensorFlowEager(object):
     @staticmethod
     def _write_hist_summary(writer, step, tag, values, name, **kwargs):
-        event_writer = PatchTensorFlowEager._get_event_writer()
+        event_writer = PatchTensorFlowEager._get_event_writer(writer)
         if event_writer:
             try:
                 event_writer._add_histogram(tag=str(tag), step=int(step.numpy()), histo_data=values.numpy())
@@ -832,7 +913,7 @@ class PatchTensorFlowEager(object):
     @staticmethod
     def _write_image_summary(writer, step, tag, tensor, bad_color, max_images, name, **kwargs):
-        event_writer = PatchTensorFlowEager._get_event_writer()
+        event_writer = PatchTensorFlowEager._get_event_writer(writer)
         if event_writer:
             try:
                 event_writer._add_image_numpy(tag=str(tag), step=int(step.numpy()), img_data_np=tensor.numpy(),
@@ -1351,93 +1432,3 @@ class PatchTensorflowModelIO(object):
                 pass
         return model
-
-
-class PatchPyTorchModelIO(object):
-    __main_task = None
-    __patched = None
-
-    @staticmethod
-    def update_current_task(task, **kwargs):
-        PatchPyTorchModelIO.__main_task = task
-        PatchPyTorchModelIO._patch_model_io()
-        PostImportHookPatching.add_on_import('torch', PatchPyTorchModelIO._patch_model_io)
-
-    @staticmethod
-    def _patch_model_io():
-        if PatchPyTorchModelIO.__patched:
-            return
-
-        if 'torch' not in sys.modules:
-            return
-        PatchPyTorchModelIO.__patched = True
-        # noinspection PyBroadException
-        try:
-            # hack: make sure torch.__init__ is called
-            import torch
-            torch.save = _patched_call(torch.save, PatchPyTorchModelIO._save)
-            torch.load = _patched_call(torch.load, PatchPyTorchModelIO._load)
-        except ImportError:
-            pass
-        except Exception:
-            pass  # print('Failed patching pytorch')
-
-    @staticmethod
-    def _save(original_fn, obj, f, *args, **kwargs):
-        ret = original_fn(obj, f, *args, **kwargs)
-        if not PatchPyTorchModelIO.__main_task:
-            return ret
-
-        if isinstance(f, six.string_types):
-            filename = f
-        elif hasattr(f, 'name'):
-            filename = f.name
-            # noinspection PyBroadException
-            try:
-                f.flush()
-            except Exception:
-                pass
-        else:
-            filename = None
-
-        # give the model a descriptive name based on the file name
-        # noinspection PyBroadException
-        try:
-            model_name = Path(filename).stem
-        except Exception:
-            model_name = None
-        WeightsFileHandler.create_output_model(obj, filename, Framework.pytorch, PatchPyTorchModelIO.__main_task,
-                                               singlefile=True, model_name=model_name)
-        return ret
-
-    @staticmethod
-    def _load(original_fn, f, *args, **kwargs):
-        if isinstance(f, six.string_types):
-            filename = f
-        elif hasattr(f, 'name'):
-            filename = f.name
-        else:
-            filename = None
-
-        if not PatchPyTorchModelIO.__main_task:
-            return original_fn(f, *args, **kwargs)
-
-        # register input model
-        empty = _Empty()
-        if running_remotely():
-            filename = WeightsFileHandler.restore_weights_file(empty, filename, Framework.pytorch,
-                                                               PatchPyTorchModelIO.__main_task)
-            model = original_fn(filename or f, *args, **kwargs)
-        else:
-            # try to load the model before registering it, in case loading fails
-            model = original_fn(filename or f, *args, **kwargs)
-            WeightsFileHandler.restore_weights_file(empty, filename, Framework.pytorch,
-                                                    PatchPyTorchModelIO.__main_task)
-
-        if empty.trains_in_model:
-            # noinspection PyBroadException
-            try:
-                model.trains_in_model = empty.trains_in_model
-            except Exception:
-                pass
-        return model
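From the user's side nothing changes: once a Trains task is initialized, every TensorBoard writer is shadowed automatically. A hypothetical end-to-end flow with the TF 1.x API, assuming an initialized task:

    from trains import Task
    import tensorflow as tf

    task = Task.init(project_name='demo', task_name='two writers')

    train_writer = tf.summary.FileWriter('runs/train')
    val_writer = tf.summary.FileWriter('runs/val')

    # both writers may now emit a scalar tagged 'loss/total'; the second
    # one's series is reported as 'val: total' (series_last), so the two
    # curves remain distinguishable in the Trains UI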