From 00f873081a2d4b83c64f45a1c8091227da8283a7 Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Sat, 20 Jul 2019 22:02:19 +0300 Subject: [PATCH] Add support for multiple event writers in the same session --- trains/binding/frameworks/tensorflow_bind.py | 203 +++++++++---------- 1 file changed, 97 insertions(+), 106 deletions(-) diff --git a/trains/binding/frameworks/tensorflow_bind.py b/trains/binding/frameworks/tensorflow_bind.py index 1f8697cd..55c6fc1b 100644 --- a/trains/binding/frameworks/tensorflow_bind.py +++ b/trains/binding/frameworks/tensorflow_bind.py @@ -1,4 +1,5 @@ import base64 +import os import sys import threading from collections import defaultdict @@ -44,9 +45,17 @@ class EventTrainsWriter(object): TF SummaryWriter implementation that converts the tensorboard's summary into Trains events and reports the events (metrics) for an Trains task (logger). """ - _add_lock = threading.Lock() + _add_lock = threading.RLock() _series_name_lookup = {} + # store all the created tensorboard writers in the system + # this allows us to ask whether a certain title/series already exists on some EventWriter + # and if it does, then we add to the series name the last token from the logdir + # (so we can differentiate between the two) + # key, value: key=hash((title, series)), value=EventTrainsWriter._id + _title_series_writers_lookup = {} + _event_writers_id_to_logdir = {} + @property def variants(self): return self._variants @@ -54,8 +63,8 @@ class EventTrainsWriter(object): def prepare_report(self): return self.variants.copy() - @staticmethod - def tag_splitter(tag, num_split_parts, split_char='/', join_char='_', default_title='variant'): + def tag_splitter(self, tag, num_split_parts, split_char='/', join_char='_', default_title='variant', + logdir_header='series'): """ Split a tf.summary tag line to variant and metric. Variant is the first part of the split tag, metric is the second. 
@@ -64,15 +73,64 @@ class EventTrainsWriter(object): :param str split_char: a character to split the tag on :param str join_char: a character to join the the splits :param str default_title: variant to use in case no variant can be inferred automatically + :param str logdir_header: if 'series_last' then series=header: series, if 'series' then series=series :header, + if 'title_last' then title=header title, if 'title' then title=title header :return: (str, str) variant and metric """ splitted_tag = tag.split(split_char) series = join_char.join(splitted_tag[-num_split_parts:]) title = join_char.join(splitted_tag[:-num_split_parts]) or default_title + + # check if we already decided that we need to change the title/series + graph_id = hash((title, series)) + if graph_id in self._graph_name_lookup: + return self._graph_name_lookup[graph_id] + + # check if someone other than us used this combination + with self._add_lock: + event_writer_id = self._title_series_writers_lookup.get(graph_id, None) + if not event_writer_id: + # put us there + self._title_series_writers_lookup[graph_id] = self._id + elif event_writer_id != self._id: + # if there is someone else, change our series name and store us + org_series = series + org_title = title + other_logdir = self._event_writers_id_to_logdir[event_writer_id] + split_logddir = self._logdir.split(os.path.sep) + unique_logdir = set(split_logddir) - set(other_logdir.split(os.path.sep)) + header = '/'.join(s for s in split_logddir if s in unique_logdir) + if logdir_header == 'series_last': + series = header + ': ' + series + elif logdir_header == 'series': + series = series + ' :' + header + elif logdir_header == 'title': + title = title + ' ' + header + else: # logdir_header == 'title_last': + title = header + ' ' + title + graph_id = hash((title, series)) + # check if for some reason the new series is already occupied + new_event_writer_id = self._title_series_writers_lookup.get(graph_id) + if new_event_writer_id is not None and 
new_event_writer_id != self._id: + # well that's about it, nothing else we could do + if logdir_header == 'series_last': + series = str(self._logdir) + ': ' + org_series + elif logdir_header == 'series': + series = org_series + ' :' + str(self._logdir) + elif logdir_header == 'title': + title = org_title + ' ' + str(self._logdir) + else: # logdir_header == 'title_last': + title = str(self._logdir) + ' ' + org_title + graph_id = hash((title, series)) + + self._title_series_writers_lookup[graph_id] = self._id + + # store for next time + self._graph_name_lookup[graph_id] = (title, series) return title, series - def __init__(self, logger, report_freq=100, image_report_freq=None, histogram_update_freq_multiplier=10, - histogram_granularity=50, max_keep_images=None): + def __init__(self, logger, logdir=None, report_freq=100, image_report_freq=None, + histogram_update_freq_multiplier=10, histogram_granularity=50, max_keep_images=None): """ Create a compatible Trains backend to the TensorFlow SummaryToEventTransformer Everything will be serialized directly to the Trains backend, instead of to the standard TF FileWriter @@ -87,6 +145,9 @@ class EventTrainsWriter(object): """ # We are the events_writer, so that's what we'll pass IsTensorboardInit.set_tensorboard_used() + self._logdir = logdir or ('unknown %d' % len(self._event_writers_id_to_logdir)) + self._id = hash(self._logdir) + self._event_writers_id_to_logdir[self._id] = self._logdir self.max_keep_images = max_keep_images self.report_freq = report_freq self.image_report_freq = image_report_freq if image_report_freq else report_freq @@ -99,6 +160,7 @@ class EventTrainsWriter(object): self._hist_report_cache = {} self._hist_x_granularity = 50 self._max_step = 0 + self._graph_name_lookup = {} def _decode_image(self, img_str, width, height, color_channels): # noinspection PyBroadException @@ -131,7 +193,7 @@ class EventTrainsWriter(object): if img_data_np is None: return - title, series = self.tag_splitter(tag, 
num_split_parts=3, default_title='Images') + title, series = self.tag_splitter(tag, num_split_parts=3, default_title='Images', logdir_header='title') if img_data_np.dtype != np.uint8: # assume scale 0-1 img_data_np = (img_data_np * 255).astype(np.uint8) @@ -168,7 +230,7 @@ class EventTrainsWriter(object): return self._add_image_numpy(tag=tag, step=step, img_data_np=matrix) def _add_scalar(self, tag, step, scalar_data): - title, series = self.tag_splitter(tag, num_split_parts=1, default_title='Scalars') + title, series = self.tag_splitter(tag, num_split_parts=1, default_title='Scalars', logdir_header='series_last') # update scalar cache num, value = self._scalar_report_cache.get((title, series), (0, 0)) @@ -216,7 +278,8 @@ class EventTrainsWriter(object): # Y-axis (rows) is iteration (from 0 to current Step) # X-axis averaged bins (conformed sample 'bucketLimit') # Z-axis actual value (interpolated 'bucket') - title, series = self.tag_splitter(tag, num_split_parts=1, default_title='Histograms') + title, series = self.tag_splitter(tag, num_split_parts=1, default_title='Histograms', + logdir_header='series') # get histograms from cache hist_list, hist_iters, minmax = self._hist_report_cache.get((title, series), ([], np.array([]), None)) @@ -570,8 +633,12 @@ class PatchSummaryToEventTransformer(object): if not hasattr(self, 'trains') or not PatchSummaryToEventTransformer.__main_task: return PatchSummaryToEventTransformer._original_add_eventT(self, *args, **kwargs) if not self.trains: + try: + logdir = self.get_logdir() + except Exception: + logdir = None self.trains = EventTrainsWriter(PatchSummaryToEventTransformer.__main_task.get_logger(), - **PatchSummaryToEventTransformer.defaults_dict) + logdir=logdir, **PatchSummaryToEventTransformer.defaults_dict) # noinspection PyBroadException try: self.trains.add_event(*args, **kwargs) @@ -584,8 +651,12 @@ class PatchSummaryToEventTransformer(object): if not hasattr(self, 'trains') or not 
PatchSummaryToEventTransformer.__main_task: return PatchSummaryToEventTransformer._original_add_eventX(self, *args, **kwargs) if not self.trains: + try: + logdir = self.get_logdir() + except Exception: + logdir = None self.trains = EventTrainsWriter(PatchSummaryToEventTransformer.__main_task.get_logger(), - **PatchSummaryToEventTransformer.defaults_dict) + logdir=logdir, **PatchSummaryToEventTransformer.defaults_dict) # noinspection PyBroadException try: self.trains.add_event(*args, **kwargs) @@ -617,8 +688,13 @@ class PatchSummaryToEventTransformer(object): # patch the events writer field, and add a double Event Logger (Trains and original) base_eventwriter = __dict__['event_writer'] + try: + logdir = base_eventwriter.get_logdir() + except Exception: + logdir = None defaults_dict = __dict__.get('_trains_defaults') or PatchSummaryToEventTransformer.defaults_dict - trains_event = EventTrainsWriter(PatchSummaryToEventTransformer.__main_task.get_logger(), **defaults_dict) + trains_event = EventTrainsWriter(PatchSummaryToEventTransformer.__main_task.get_logger(), + logdir=logdir, **defaults_dict) # order is important, the return value of ProxyEventsWriter is the last object in the list __dict__['event_writer'] = ProxyEventsWriter([trains_event, base_eventwriter]) @@ -798,12 +874,17 @@ class PatchTensorFlowEager(object): getLogger(TrainsFrameworkAdapter).debug(str(ex)) @staticmethod - def _get_event_writer(): + def _get_event_writer(writer): if not PatchTensorFlowEager.__main_task: return None if PatchTensorFlowEager.__trains_event_writer is None: + try: + logdir = writer.get_logdir() + except Exception: + logdir = None PatchTensorFlowEager.__trains_event_writer = EventTrainsWriter( - logger=PatchTensorFlowEager.__main_task.get_logger(), **PatchTensorFlowEager.defaults_dict) + logger=PatchTensorFlowEager.__main_task.get_logger(), logdir=logdir, + **PatchTensorFlowEager.defaults_dict) return PatchTensorFlowEager.__trains_event_writer @staticmethod @@ -812,7 +893,7 @@ 
class PatchTensorFlowEager(object): @staticmethod def _write_scalar_summary(writer, step, tag, value, name=None, **kwargs): - event_writer = PatchTensorFlowEager._get_event_writer() + event_writer = PatchTensorFlowEager._get_event_writer(writer) if event_writer: try: event_writer._add_scalar(tag=str(tag), step=int(step.numpy()), scalar_data=value.numpy()) @@ -822,7 +903,7 @@ class PatchTensorFlowEager(object): @staticmethod def _write_hist_summary(writer, step, tag, values, name, **kwargs): - event_writer = PatchTensorFlowEager._get_event_writer() + event_writer = PatchTensorFlowEager._get_event_writer(writer) if event_writer: try: event_writer._add_histogram(tag=str(tag), step=int(step.numpy()), histo_data=values.numpy()) @@ -832,7 +913,7 @@ class PatchTensorFlowEager(object): @staticmethod def _write_image_summary(writer, step, tag, tensor, bad_color, max_images, name, **kwargs): - event_writer = PatchTensorFlowEager._get_event_writer() + event_writer = PatchTensorFlowEager._get_event_writer(writer) if event_writer: try: event_writer._add_image_numpy(tag=str(tag), step=int(step.numpy()), img_data_np=tensor.numpy(), @@ -1351,93 +1432,3 @@ class PatchTensorflowModelIO(object): pass return model - -class PatchPyTorchModelIO(object): - __main_task = None - __patched = None - - @staticmethod - def update_current_task(task, **kwargs): - PatchPyTorchModelIO.__main_task = task - PatchPyTorchModelIO._patch_model_io() - PostImportHookPatching.add_on_import('torch', PatchPyTorchModelIO._patch_model_io) - - @staticmethod - def _patch_model_io(): - if PatchPyTorchModelIO.__patched: - return - - if 'torch' not in sys.modules: - return - - PatchPyTorchModelIO.__patched = True - # noinspection PyBroadException - try: - # hack: make sure tensorflow.__init__ is called - import torch - torch.save = _patched_call(torch.save, PatchPyTorchModelIO._save) - torch.load = _patched_call(torch.load, PatchPyTorchModelIO._load) - except ImportError: - pass - except Exception: - pass # 
print('Failed patching pytorch') - - @staticmethod - def _save(original_fn, obj, f, *args, **kwargs): - ret = original_fn(obj, f, *args, **kwargs) - if not PatchPyTorchModelIO.__main_task: - return ret - - if isinstance(f, six.string_types): - filename = f - elif hasattr(f, 'name'): - filename = f.name - # noinspection PyBroadException - try: - f.flush() - except Exception: - pass - else: - filename = None - - # give the model a descriptive name based on the file name - # noinspection PyBroadException - try: - model_name = Path(filename).stem - except Exception: - model_name = None - WeightsFileHandler.create_output_model(obj, filename, Framework.pytorch, PatchPyTorchModelIO.__main_task, - singlefile=True, model_name=model_name) - return ret - - @staticmethod - def _load(original_fn, f, *args, **kwargs): - if isinstance(f, six.string_types): - filename = f - elif hasattr(f, 'name'): - filename = f.name - else: - filename = None - - if not PatchPyTorchModelIO.__main_task: - return original_fn(f, *args, **kwargs) - - # register input model - empty = _Empty() - if running_remotely(): - filename = WeightsFileHandler.restore_weights_file(empty, filename, Framework.pytorch, - PatchPyTorchModelIO.__main_task) - model = original_fn(filename or f, *args, **kwargs) - else: - # try to load model before registering, in case we fail - model = original_fn(filename or f, *args, **kwargs) - WeightsFileHandler.restore_weights_file(empty, filename, Framework.pytorch, - PatchPyTorchModelIO.__main_task) - - if empty.trains_in_model: - # noinspection PyBroadException - try: - model.trains_in_model = empty.trains_in_model - except Exception: - pass - return model