Mirror of https://github.com/clearml/clearml (synced 2025-04-03 12:31:11 +00:00)
Add support for multiple event writers in the same session
This commit is contained in:
parent 51cc50e239
commit 00f873081a
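The diff below registers one EventTrainsWriter per TensorFlow event writer, keyed by its logdir, so several writers created in the same session no longer collapse onto the same Trains title/series. As a rough illustration of the scenario this enables (a minimal sketch, assuming the TensorFlow 1.x summary API and the trains SDK; the project/task names and logdirs are placeholders, the patching itself happens automatically once Task.init is called):

# Two FileWriters in one session, each with its own logdir, both logging the
# same 'loss' tag. With this commit each writer is tracked by its logdir, so
# the two series can be told apart in the Trains UI.
import tensorflow as tf
from trains import Task

task = Task.init(project_name='examples', task_name='multi writer')  # placeholder names

train_writer = tf.summary.FileWriter('./logs/train')
val_writer = tf.summary.FileWriter('./logs/val')

for step in range(10):
    summary = tf.Summary(value=[tf.Summary.Value(tag='loss', simple_value=1.0 / (step + 1))])
    train_writer.add_summary(summary, global_step=step)
    val_writer.add_summary(summary, global_step=step)

train_writer.close()
val_writer.close()

With a single writer nothing changes; with two writers reporting the same tag, the part of each logdir that is unique ('train' / 'val' here) is folded into the reported title or series, as implemented in tag_splitter below.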
@@ -1,4 +1,5 @@
 import base64
+import os
 import sys
 import threading
 from collections import defaultdict
@@ -44,9 +45,17 @@ class EventTrainsWriter(object):
     TF SummaryWriter implementation that converts the tensorboard's summary into
     Trains events and reports the events (metrics) for a Trains task (logger).
     """
-    _add_lock = threading.Lock()
+    _add_lock = threading.RLock()
     _series_name_lookup = {}
+
+    # store all the created tensorboard writers in the system
+    # this allows us to check whether a certain title/series already exists on some EventWriter,
+    # and if it does, we add the last token of the logdir to the series name
+    # (so we can differentiate between the two)
+    # key, value: key=hash(title, series), value=EventTrainsWriter._id
+    _title_series_writers_lookup = {}
+    _event_writers_id_to_logdir = {}
 
     @property
     def variants(self):
         return self._variants
@@ -54,8 +63,8 @@ class EventTrainsWriter(object):
     def prepare_report(self):
         return self.variants.copy()
 
-    @staticmethod
-    def tag_splitter(tag, num_split_parts, split_char='/', join_char='_', default_title='variant'):
+    def tag_splitter(self, tag, num_split_parts, split_char='/', join_char='_', default_title='variant',
+                     logdir_header='series'):
         """
         Split a tf.summary tag line to variant and metric.
         Variant is the first part of the split tag, metric is the second.
@@ -64,15 +73,64 @@ class EventTrainsWriter(object):
         :param str split_char: a character to split the tag on
         :param str join_char: a character to join the splits with
         :param str default_title: variant to use in case no variant can be inferred automatically
+        :param str logdir_header: if 'series_last' then series='header: series', if 'series' then series='series :header',
+            if 'title_last' then title='header title', if 'title' then title='title header'
         :return: (str, str) variant and metric
         """
         splitted_tag = tag.split(split_char)
         series = join_char.join(splitted_tag[-num_split_parts:])
         title = join_char.join(splitted_tag[:-num_split_parts]) or default_title
+
+        # check if we already decided that we need to change the title/series
+        graph_id = hash((title, series))
+        if graph_id in self._graph_name_lookup:
+            return self._graph_name_lookup[graph_id]
+
+        # check if someone other than us used this combination
+        with self._add_lock:
+            event_writer_id = self._title_series_writers_lookup.get(graph_id, None)
+            if not event_writer_id:
+                # put us there
+                self._title_series_writers_lookup[graph_id] = self._id
+            elif event_writer_id != self._id:
+                # if there is someone else, change our series name and store us
+                org_series = series
+                org_title = title
+                other_logdir = self._event_writers_id_to_logdir[event_writer_id]
+                split_logddir = self._logdir.split(os.path.sep)
+                unique_logdir = set(split_logddir) - set(other_logdir.split(os.path.sep))
+                header = '/'.join(s for s in split_logddir if s in unique_logdir)
+                if logdir_header == 'series_last':
+                    series = header + ': ' + series
+                elif logdir_header == 'series':
+                    series = series + ' :' + header
+                elif logdir_header == 'title':
+                    title = title + ' ' + header
+                else:  # logdir_header == 'title_last':
+                    title = header + ' ' + title
+                graph_id = hash((title, series))
+                # check if for some reason the new series is already occupied
+                new_event_writer_id = self._title_series_writers_lookup.get(graph_id)
+                if new_event_writer_id is not None and new_event_writer_id != self._id:
+                    # well that's about it, nothing else we could do
+                    if logdir_header == 'series_last':
+                        series = str(self._logdir) + ': ' + org_series
+                    elif logdir_header == 'series':
+                        series = org_series + ' :' + str(self._logdir)
+                    elif logdir_header == 'title':
+                        title = org_title + ' ' + str(self._logdir)
+                    else:  # logdir_header == 'title_last':
+                        title = str(self._logdir) + ' ' + org_title
+                    graph_id = hash((title, series))
+
+                self._title_series_writers_lookup[graph_id] = self._id
+
+        # store for next time
+        self._graph_name_lookup[graph_id] = (title, series)
+        return title, series
 
-    def __init__(self, logger, report_freq=100, image_report_freq=None, histogram_update_freq_multiplier=10,
-                 histogram_granularity=50, max_keep_images=None):
+    def __init__(self, logger, logdir=None, report_freq=100, image_report_freq=None,
+                 histogram_update_freq_multiplier=10, histogram_granularity=50, max_keep_images=None):
         """
         Create a compatible Trains backend to the TensorFlow SummaryToEventTransformer
         Everything will be serialized directly to the Trains backend, instead of to the standard TF FileWriter
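To make the renaming rule above concrete, here is a standalone illustration (not the library code): only the path tokens that are unique to this writer's logdir survive and become the header that tag_splitter prepends or appends, assuming a POSIX path separator and the hypothetical logdirs 'logs/train' and 'logs/val':

# Illustration only: derive the disambiguating header from two logdirs.
import os

def logdir_suffix(my_logdir, other_logdir):
    # keep only the path tokens that are unique to this writer's logdir
    mine = my_logdir.split(os.path.sep)
    unique = set(mine) - set(other_logdir.split(os.path.sep))
    return '/'.join(s for s in mine if s in unique)

header = logdir_suffix('logs/train', 'logs/val')  # -> 'train' (POSIX separator)
print(header + ': ' + 'loss')                     # logdir_header='series_last' -> 'train: loss'
print('loss' + ' :' + header)                     # logdir_header='series'      -> 'loss :train'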
@@ -87,6 +145,9 @@ class EventTrainsWriter(object):
         """
         # We are the events_writer, so that's what we'll pass
         IsTensorboardInit.set_tensorboard_used()
+        self._logdir = logdir or ('unknown %d' % len(self._event_writers_id_to_logdir))
+        self._id = hash(self._logdir)
+        self._event_writers_id_to_logdir[self._id] = self._logdir
         self.max_keep_images = max_keep_images
         self.report_freq = report_freq
         self.image_report_freq = image_report_freq if image_report_freq else report_freq
@@ -99,6 +160,7 @@ class EventTrainsWriter(object):
         self._hist_report_cache = {}
         self._hist_x_granularity = 50
         self._max_step = 0
+        self._graph_name_lookup = {}
 
     def _decode_image(self, img_str, width, height, color_channels):
         # noinspection PyBroadException
@@ -131,7 +193,7 @@ class EventTrainsWriter(object):
         if img_data_np is None:
             return
 
-        title, series = self.tag_splitter(tag, num_split_parts=3, default_title='Images')
+        title, series = self.tag_splitter(tag, num_split_parts=3, default_title='Images', logdir_header='title')
         if img_data_np.dtype != np.uint8:
             # assume scale 0-1
             img_data_np = (img_data_np * 255).astype(np.uint8)
@@ -168,7 +230,7 @@ class EventTrainsWriter(object):
         return self._add_image_numpy(tag=tag, step=step, img_data_np=matrix)
 
     def _add_scalar(self, tag, step, scalar_data):
-        title, series = self.tag_splitter(tag, num_split_parts=1, default_title='Scalars')
+        title, series = self.tag_splitter(tag, num_split_parts=1, default_title='Scalars', logdir_header='series_last')
 
         # update scalar cache
         num, value = self._scalar_report_cache.get((title, series), (0, 0))
@@ -216,7 +278,8 @@ class EventTrainsWriter(object):
         # Y-axis (rows) is iteration (from 0 to current Step)
         # X-axis averaged bins (conformed sample 'bucketLimit')
         # Z-axis actual value (interpolated 'bucket')
-        title, series = self.tag_splitter(tag, num_split_parts=1, default_title='Histograms')
+        title, series = self.tag_splitter(tag, num_split_parts=1, default_title='Histograms',
+                                          logdir_header='series')
 
         # get histograms from cache
         hist_list, hist_iters, minmax = self._hist_report_cache.get((title, series), ([], np.array([]), None))
@@ -570,8 +633,12 @@ class PatchSummaryToEventTransformer(object):
         if not hasattr(self, 'trains') or not PatchSummaryToEventTransformer.__main_task:
             return PatchSummaryToEventTransformer._original_add_eventT(self, *args, **kwargs)
         if not self.trains:
+            try:
+                logdir = self.get_logdir()
+            except Exception:
+                logdir = None
             self.trains = EventTrainsWriter(PatchSummaryToEventTransformer.__main_task.get_logger(),
-                                            **PatchSummaryToEventTransformer.defaults_dict)
+                                            logdir=logdir, **PatchSummaryToEventTransformer.defaults_dict)
         # noinspection PyBroadException
         try:
             self.trains.add_event(*args, **kwargs)
@@ -584,8 +651,12 @@ class PatchSummaryToEventTransformer(object):
         if not hasattr(self, 'trains') or not PatchSummaryToEventTransformer.__main_task:
             return PatchSummaryToEventTransformer._original_add_eventX(self, *args, **kwargs)
         if not self.trains:
+            try:
+                logdir = self.get_logdir()
+            except Exception:
+                logdir = None
             self.trains = EventTrainsWriter(PatchSummaryToEventTransformer.__main_task.get_logger(),
-                                            **PatchSummaryToEventTransformer.defaults_dict)
+                                            logdir=logdir, **PatchSummaryToEventTransformer.defaults_dict)
         # noinspection PyBroadException
         try:
             self.trains.add_event(*args, **kwargs)
@@ -617,8 +688,13 @@ class PatchSummaryToEventTransformer(object):
 
         # patch the events writer field, and add a double Event Logger (Trains and original)
         base_eventwriter = __dict__['event_writer']
+        try:
+            logdir = base_eventwriter.get_logdir()
+        except Exception:
+            logdir = None
         defaults_dict = __dict__.get('_trains_defaults') or PatchSummaryToEventTransformer.defaults_dict
-        trains_event = EventTrainsWriter(PatchSummaryToEventTransformer.__main_task.get_logger(), **defaults_dict)
+        trains_event = EventTrainsWriter(PatchSummaryToEventTransformer.__main_task.get_logger(),
+                                         logdir=logdir, **defaults_dict)
 
         # order is important, the return value of ProxyEventsWriter is the last object in the list
         __dict__['event_writer'] = ProxyEventsWriter([trains_event, base_eventwriter])
@@ -798,12 +874,17 @@ class PatchTensorFlowEager(object):
             getLogger(TrainsFrameworkAdapter).debug(str(ex))
 
     @staticmethod
-    def _get_event_writer():
+    def _get_event_writer(writer):
         if not PatchTensorFlowEager.__main_task:
             return None
         if PatchTensorFlowEager.__trains_event_writer is None:
+            try:
+                logdir = writer.get_logdir()
+            except Exception:
+                logdir = None
             PatchTensorFlowEager.__trains_event_writer = EventTrainsWriter(
-                logger=PatchTensorFlowEager.__main_task.get_logger(), **PatchTensorFlowEager.defaults_dict)
+                logger=PatchTensorFlowEager.__main_task.get_logger(), logdir=logdir,
+                **PatchTensorFlowEager.defaults_dict)
         return PatchTensorFlowEager.__trains_event_writer
 
     @staticmethod
@@ -812,7 +893,7 @@ class PatchTensorFlowEager(object):
 
     @staticmethod
     def _write_scalar_summary(writer, step, tag, value, name=None, **kwargs):
-        event_writer = PatchTensorFlowEager._get_event_writer()
+        event_writer = PatchTensorFlowEager._get_event_writer(writer)
        if event_writer:
            try:
                event_writer._add_scalar(tag=str(tag), step=int(step.numpy()), scalar_data=value.numpy())
@@ -822,7 +903,7 @@ class PatchTensorFlowEager(object):
 
     @staticmethod
     def _write_hist_summary(writer, step, tag, values, name, **kwargs):
-        event_writer = PatchTensorFlowEager._get_event_writer()
+        event_writer = PatchTensorFlowEager._get_event_writer(writer)
         if event_writer:
             try:
                 event_writer._add_histogram(tag=str(tag), step=int(step.numpy()), histo_data=values.numpy())
@@ -832,7 +913,7 @@ class PatchTensorFlowEager(object):
 
     @staticmethod
     def _write_image_summary(writer, step, tag, tensor, bad_color, max_images, name, **kwargs):
-        event_writer = PatchTensorFlowEager._get_event_writer()
+        event_writer = PatchTensorFlowEager._get_event_writer(writer)
         if event_writer:
             try:
                 event_writer._add_image_numpy(tag=str(tag), step=int(step.numpy()), img_data_np=tensor.numpy(),
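In eager mode, _get_event_writer now receives the writer object that TensorFlow passes to the patched summary ops; since that object may or may not expose get_logdir(), the lookup is wrapped in a try/except above. A minimal usage sketch, assuming the TensorFlow 2.x summary API and the trains SDK (project/task names and the logdir are placeholders):

# TF2 eager summaries: the file writer's logdir is now forwarded to
# EventTrainsWriter, so its reports carry the logdir-derived header when needed.
import tensorflow as tf
from trains import Task

task = Task.init(project_name='examples', task_name='eager summaries')  # placeholder names

writer = tf.summary.create_file_writer('./logs/train')
with writer.as_default():
    for step in range(10):
        tf.summary.scalar('loss', 1.0 / (step + 1), step=step)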
@@ -1351,93 +1432,3 @@ class PatchTensorflowModelIO(object):
                 pass
         return model
 
-
-class PatchPyTorchModelIO(object):
-    __main_task = None
-    __patched = None
-
-    @staticmethod
-    def update_current_task(task, **kwargs):
-        PatchPyTorchModelIO.__main_task = task
-        PatchPyTorchModelIO._patch_model_io()
-        PostImportHookPatching.add_on_import('torch', PatchPyTorchModelIO._patch_model_io)
-
-    @staticmethod
-    def _patch_model_io():
-        if PatchPyTorchModelIO.__patched:
-            return
-
-        if 'torch' not in sys.modules:
-            return
-
-        PatchPyTorchModelIO.__patched = True
-        # noinspection PyBroadException
-        try:
-            # hack: make sure tensorflow.__init__ is called
-            import torch
-            torch.save = _patched_call(torch.save, PatchPyTorchModelIO._save)
-            torch.load = _patched_call(torch.load, PatchPyTorchModelIO._load)
-        except ImportError:
-            pass
-        except Exception:
-            pass  # print('Failed patching pytorch')
-
-    @staticmethod
-    def _save(original_fn, obj, f, *args, **kwargs):
-        ret = original_fn(obj, f, *args, **kwargs)
-        if not PatchPyTorchModelIO.__main_task:
-            return ret
-
-        if isinstance(f, six.string_types):
-            filename = f
-        elif hasattr(f, 'name'):
-            filename = f.name
-            # noinspection PyBroadException
-            try:
-                f.flush()
-            except Exception:
-                pass
-        else:
-            filename = None
-
-        # give the model a descriptive name based on the file name
-        # noinspection PyBroadException
-        try:
-            model_name = Path(filename).stem
-        except Exception:
-            model_name = None
-        WeightsFileHandler.create_output_model(obj, filename, Framework.pytorch, PatchPyTorchModelIO.__main_task,
-                                               singlefile=True, model_name=model_name)
-        return ret
-
-    @staticmethod
-    def _load(original_fn, f, *args, **kwargs):
-        if isinstance(f, six.string_types):
-            filename = f
-        elif hasattr(f, 'name'):
-            filename = f.name
-        else:
-            filename = None
-
-        if not PatchPyTorchModelIO.__main_task:
-            return original_fn(f, *args, **kwargs)
-
-        # register input model
-        empty = _Empty()
-        if running_remotely():
-            filename = WeightsFileHandler.restore_weights_file(empty, filename, Framework.pytorch,
-                                                               PatchPyTorchModelIO.__main_task)
-            model = original_fn(filename or f, *args, **kwargs)
-        else:
-            # try to load model before registering, in case we fail
-            model = original_fn(filename or f, *args, **kwargs)
-            WeightsFileHandler.restore_weights_file(empty, filename, Framework.pytorch,
-                                                    PatchPyTorchModelIO.__main_task)
-
-        if empty.trains_in_model:
-            # noinspection PyBroadException
-            try:
-                model.trains_in_model = empty.trains_in_model
-            except Exception:
-                pass
-        return model