Add support for multiple event writers in the same session

parent 51cc50e239
commit 00f873081a
@@ -1,4 +1,5 @@
 import base64
+import os
 import sys
 import threading
 from collections import defaultdict
@@ -44,9 +45,17 @@ class EventTrainsWriter(object):
     TF SummaryWriter implementation that converts the tensorboard's summary into
     Trains events and reports the events (metrics) for a Trains task (logger).
     """
-    _add_lock = threading.Lock()
+    _add_lock = threading.RLock()
     _series_name_lookup = {}
 
+    # store all the created tensorboard writers in the system
+    # this allows us to ask whether a certain title/series already exists on some EventWriter,
+    # and if it does, we add the last token from the logdir to the series name
+    # (so we can differentiate between the two)
+    # key, value: key=hash(title, graph), value=EventTrainsWriter._id
+    _title_series_writers_lookup = {}
+    _event_writers_id_to_logdir = {}
+
     @property
     def variants(self):
         return self._variants
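For orientation, a minimal standalone sketch (illustrative names, not code from this commit) of the registry idea behind these new class attributes: a process-wide map keyed by hash((title, series)) records which writer first claimed a title/series pair, so a second writer can detect the clash and rename its series:

    import threading

    _lock = threading.RLock()
    _claims = {}  # hash((title, series)) -> id of the writer that first claimed the pair

    def claim(writer_id, title, series):
        # return True if writer_id owns (title, series), False on a clash
        graph_id = hash((title, series))
        with _lock:
            owner = _claims.setdefault(graph_id, writer_id)
        return owner == writer_id

    assert claim('writer-a', 'Scalars', 'loss')       # first writer claims the pair
    assert not claim('writer-b', 'Scalars', 'loss')   # second writer sees the clash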
@@ -54,8 +63,8 @@ class EventTrainsWriter(object):
     def prepare_report(self):
         return self.variants.copy()
 
-    @staticmethod
-    def tag_splitter(tag, num_split_parts, split_char='/', join_char='_', default_title='variant'):
+    def tag_splitter(self, tag, num_split_parts, split_char='/', join_char='_', default_title='variant',
+                     logdir_header='series'):
         """
         Split a tf.summary tag line to variant and metric.
         Variant is the first part of the split tag, metric is the second.
@@ -64,15 +73,64 @@ class EventTrainsWriter(object):
         :param str split_char: a character to split the tag on
         :param str join_char: a character to join the splits
         :param str default_title: variant to use in case no variant can be inferred automatically
+        :param str logdir_header: if 'series_last' then series='header: series'; if 'series' then series='series :header';
+            if 'title_last' then title='header title'; if 'title' then title='title header'
         :return: (str, str) variant and metric
         """
         splitted_tag = tag.split(split_char)
         series = join_char.join(splitted_tag[-num_split_parts:])
         title = join_char.join(splitted_tag[:-num_split_parts]) or default_title
+
+        # check if we already decided that we need to change the title/series
+        graph_id = hash((title, series))
+        if graph_id in self._graph_name_lookup:
+            return self._graph_name_lookup[graph_id]
+
+        # check if someone other than us used this combination
+        with self._add_lock:
+            event_writer_id = self._title_series_writers_lookup.get(graph_id, None)
+            if not event_writer_id:
+                # put us there
+                self._title_series_writers_lookup[graph_id] = self._id
+            elif event_writer_id != self._id:
+                # if there is someone else, change our series name and store us
+                org_series = series
+                org_title = title
+                other_logdir = self._event_writers_id_to_logdir[event_writer_id]
+                split_logddir = self._logdir.split(os.path.sep)
+                unique_logdir = set(split_logddir) - set(other_logdir.split(os.path.sep))
+                header = '/'.join(s for s in split_logddir if s in unique_logdir)
+                if logdir_header == 'series_last':
+                    series = header + ': ' + series
+                elif logdir_header == 'series':
+                    series = series + ' :' + header
+                elif logdir_header == 'title':
+                    title = title + ' ' + header
+                else:  # logdir_header == 'title_last'
+                    title = header + ' ' + title
+                graph_id = hash((title, series))
+                # check if for some reason the new series is already occupied
+                new_event_writer_id = self._title_series_writers_lookup.get(graph_id)
+                if new_event_writer_id is not None and new_event_writer_id != self._id:
+                    # well that's about it, nothing else we could do
+                    if logdir_header == 'series_last':
+                        series = str(self._logdir) + ': ' + org_series
+                    elif logdir_header == 'series':
+                        series = org_series + ' :' + str(self._logdir)
+                    elif logdir_header == 'title':
+                        title = org_title + ' ' + str(self._logdir)
+                    else:  # logdir_header == 'title_last'
+                        title = str(self._logdir) + ' ' + org_title
+                    graph_id = hash((title, series))
+
+                self._title_series_writers_lookup[graph_id] = self._id
+
+        # store for next time
+        self._graph_name_lookup[graph_id] = (title, series)
         return title, series
 
-    def __init__(self, logger, report_freq=100, image_report_freq=None, histogram_update_freq_multiplier=10,
-                 histogram_granularity=50, max_keep_images=None):
+    def __init__(self, logger, logdir=None, report_freq=100, image_report_freq=None,
+                 histogram_update_freq_multiplier=10, histogram_granularity=50, max_keep_images=None):
         """
         Create a compatible Trains backend to the TensorFlow SummaryToEventTransformer
         Everything will be serialized directly to the Trains backend, instead of to the standard TF FileWriter
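A worked illustration of the renaming path above, as a hypothetical helper assuming two writers with logdirs runs/exp1 and runs/exp2: the header is built from the path tokens unique to this writer's logdir, and logdir_header decides where it is attached:

    import os

    def disambiguate(series, my_logdir, other_logdir, logdir_header='series'):
        # header = path tokens of my_logdir that the other logdir does not share
        mine = my_logdir.split(os.path.sep)
        unique = set(mine) - set(other_logdir.split(os.path.sep))
        header = '/'.join(s for s in mine if s in unique)
        if logdir_header == 'series_last':
            return header + ': ' + series
        return series + ' :' + header  # logdir_header == 'series'

    print(disambiguate('loss', os.path.join('runs', 'exp2'), os.path.join('runs', 'exp1')))
    # -> 'loss :exp2'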
@@ -87,6 +145,9 @@ class EventTrainsWriter(object):
         """
         # We are the events_writer, so that's what we'll pass
         IsTensorboardInit.set_tensorboard_used()
+        self._logdir = logdir or ('unknown %d' % len(self._event_writers_id_to_logdir))
+        self._id = hash(self._logdir)
+        self._event_writers_id_to_logdir[self._id] = self._logdir
         self.max_keep_images = max_keep_images
         self.report_freq = report_freq
         self.image_report_freq = image_report_freq if image_report_freq else report_freq
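Note how the writer identity is derived: the id is simply the hash of the logdir, and writers created without a logdir get a generated 'unknown N' name, keeping ids distinct per writer. A tiny sketch of that invariant (illustrative only):

    writers = {}
    for logdir in ('runs/exp1', 'runs/exp2', None):
        logdir = logdir or ('unknown %d' % len(writers))  # fallback naming as above
        writers[hash(logdir)] = logdir
    assert len(writers) == 3  # three distinct writer ids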
@@ -99,6 +160,7 @@ class EventTrainsWriter(object):
         self._hist_report_cache = {}
         self._hist_x_granularity = 50
         self._max_step = 0
+        self._graph_name_lookup = {}
 
     def _decode_image(self, img_str, width, height, color_channels):
         # noinspection PyBroadException
@@ -131,7 +193,7 @@ class EventTrainsWriter(object):
         if img_data_np is None:
             return
 
-        title, series = self.tag_splitter(tag, num_split_parts=3, default_title='Images')
+        title, series = self.tag_splitter(tag, num_split_parts=3, default_title='Images', logdir_header='title')
         if img_data_np.dtype != np.uint8:
             # assume scale 0-1
             img_data_np = (img_data_np * 255).astype(np.uint8)
@@ -168,7 +230,7 @@ class EventTrainsWriter(object):
         return self._add_image_numpy(tag=tag, step=step, img_data_np=matrix)
 
     def _add_scalar(self, tag, step, scalar_data):
-        title, series = self.tag_splitter(tag, num_split_parts=1, default_title='Scalars')
+        title, series = self.tag_splitter(tag, num_split_parts=1, default_title='Scalars', logdir_header='series_last')
 
         # update scalar cache
         num, value = self._scalar_report_cache.get((title, series), (0, 0))
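For a sense of what the splitter produces here: with num_split_parts=1, the last path token of the tag becomes the series and the rest becomes the title, falling back to the default title for bare tags. A minimal sketch of just the split step (without the registry logic above):

    def split_tag(tag, num_split_parts=1, split_char='/', join_char='_', default_title='Scalars'):
        parts = tag.split(split_char)
        series = join_char.join(parts[-num_split_parts:])
        title = join_char.join(parts[:-num_split_parts]) or default_title
        return title, series

    assert split_tag('loss') == ('Scalars', 'loss')
    assert split_tag('train/loss') == ('train', 'loss')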
@@ -216,7 +278,8 @@ class EventTrainsWriter(object):
         # Y-axis (rows) is iteration (from 0 to current Step)
         # X-axis averaged bins (conformed sample 'bucketLimit')
         # Z-axis actual value (interpolated 'bucket')
-        title, series = self.tag_splitter(tag, num_split_parts=1, default_title='Histograms')
+        title, series = self.tag_splitter(tag, num_split_parts=1, default_title='Histograms',
+                                          logdir_header='series')
 
         # get histograms from cache
         hist_list, hist_iters, minmax = self._hist_report_cache.get((title, series), ([], np.array([]), None))
@@ -570,8 +633,12 @@ class PatchSummaryToEventTransformer(object):
         if not hasattr(self, 'trains') or not PatchSummaryToEventTransformer.__main_task:
             return PatchSummaryToEventTransformer._original_add_eventT(self, *args, **kwargs)
         if not self.trains:
+            try:
+                logdir = self.get_logdir()
+            except Exception:
+                logdir = None
             self.trains = EventTrainsWriter(PatchSummaryToEventTransformer.__main_task.get_logger(),
-                                            **PatchSummaryToEventTransformer.defaults_dict)
+                                            logdir=logdir, **PatchSummaryToEventTransformer.defaults_dict)
         # noinspection PyBroadException
         try:
             self.trains.add_event(*args, **kwargs)
@@ -584,8 +651,12 @@ class PatchSummaryToEventTransformer(object):
         if not hasattr(self, 'trains') or not PatchSummaryToEventTransformer.__main_task:
             return PatchSummaryToEventTransformer._original_add_eventX(self, *args, **kwargs)
         if not self.trains:
+            try:
+                logdir = self.get_logdir()
+            except Exception:
+                logdir = None
             self.trains = EventTrainsWriter(PatchSummaryToEventTransformer.__main_task.get_logger(),
-                                            **PatchSummaryToEventTransformer.defaults_dict)
+                                            logdir=logdir, **PatchSummaryToEventTransformer.defaults_dict)
         # noinspection PyBroadException
         try:
             self.trains.add_event(*args, **kwargs)
@@ -617,8 +688,13 @@ class PatchSummaryToEventTransformer(object):
 
         # patch the events writer field, and add a double Event Logger (Trains and original)
         base_eventwriter = __dict__['event_writer']
+        try:
+            logdir = base_eventwriter.get_logdir()
+        except Exception:
+            logdir = None
         defaults_dict = __dict__.get('_trains_defaults') or PatchSummaryToEventTransformer.defaults_dict
-        trains_event = EventTrainsWriter(PatchSummaryToEventTransformer.__main_task.get_logger(), **defaults_dict)
+        trains_event = EventTrainsWriter(PatchSummaryToEventTransformer.__main_task.get_logger(),
+                                         logdir=logdir, **defaults_dict)
 
         # order is important, the return value of ProxyEventsWriter is the last object in the list
         __dict__['event_writer'] = ProxyEventsWriter([trains_event, base_eventwriter])
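The ordering comment matters because the proxy returns the last writer's result, so placing base_eventwriter last keeps the TensorFlow caller seeing its original return values. A minimal stand-in sketch (FanOutWriter is illustrative, not the actual ProxyEventsWriter):

    class FanOutWriter(object):
        # fan every call out to all writers; the last writer's return value wins
        def __init__(self, writers):
            self._writers = writers

        def add_event(self, *args, **kwargs):
            ret = None
            for w in self._writers:
                ret = w.add_event(*args, **kwargs)
            return ret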
@@ -798,12 +874,17 @@ class PatchTensorFlowEager(object):
             getLogger(TrainsFrameworkAdapter).debug(str(ex))
 
     @staticmethod
-    def _get_event_writer():
+    def _get_event_writer(writer):
         if not PatchTensorFlowEager.__main_task:
             return None
         if PatchTensorFlowEager.__trains_event_writer is None:
+            try:
+                logdir = writer.get_logdir()
+            except Exception:
+                logdir = None
             PatchTensorFlowEager.__trains_event_writer = EventTrainsWriter(
-                logger=PatchTensorFlowEager.__main_task.get_logger(), **PatchTensorFlowEager.defaults_dict)
+                logger=PatchTensorFlowEager.__main_task.get_logger(), logdir=logdir,
+                **PatchTensorFlowEager.defaults_dict)
         return PatchTensorFlowEager.__trains_event_writer
 
     @staticmethod
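The eager-mode writer is created lazily, once, from whichever summary writer triggers it first; get_logdir() is guarded because not every writer object is guaranteed to expose it. A sketch of that lazy-singleton pattern (hypothetical names):

    _singleton = None

    def get_or_create(writer, factory):
        # first caller's logdir wins, mirroring the patch above
        global _singleton
        if _singleton is None:
            try:
                logdir = writer.get_logdir()
            except Exception:
                logdir = None
            _singleton = factory(logdir=logdir)
        return _singleton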
@@ -812,7 +893,7 @@ class PatchTensorFlowEager(object):
 
     @staticmethod
     def _write_scalar_summary(writer, step, tag, value, name=None, **kwargs):
-        event_writer = PatchTensorFlowEager._get_event_writer()
+        event_writer = PatchTensorFlowEager._get_event_writer(writer)
         if event_writer:
             try:
                 event_writer._add_scalar(tag=str(tag), step=int(step.numpy()), scalar_data=value.numpy())
@@ -822,7 +903,7 @@ class PatchTensorFlowEager(object):
 
     @staticmethod
     def _write_hist_summary(writer, step, tag, values, name, **kwargs):
-        event_writer = PatchTensorFlowEager._get_event_writer()
+        event_writer = PatchTensorFlowEager._get_event_writer(writer)
         if event_writer:
             try:
                 event_writer._add_histogram(tag=str(tag), step=int(step.numpy()), histo_data=values.numpy())
@@ -832,7 +913,7 @@ class PatchTensorFlowEager(object):
 
     @staticmethod
     def _write_image_summary(writer, step, tag, tensor, bad_color, max_images, name, **kwargs):
-        event_writer = PatchTensorFlowEager._get_event_writer()
+        event_writer = PatchTensorFlowEager._get_event_writer(writer)
         if event_writer:
             try:
                 event_writer._add_image_numpy(tag=str(tag), step=int(step.numpy()), img_data_np=tensor.numpy(),
@@ -1351,93 +1432,3 @@ class PatchTensorflowModelIO(object):
                 pass
         return model
 
-
-class PatchPyTorchModelIO(object):
-    __main_task = None
-    __patched = None
-
-    @staticmethod
-    def update_current_task(task, **kwargs):
-        PatchPyTorchModelIO.__main_task = task
-        PatchPyTorchModelIO._patch_model_io()
-        PostImportHookPatching.add_on_import('torch', PatchPyTorchModelIO._patch_model_io)
-
-    @staticmethod
-    def _patch_model_io():
-        if PatchPyTorchModelIO.__patched:
-            return
-
-        if 'torch' not in sys.modules:
-            return
-
-        PatchPyTorchModelIO.__patched = True
-        # noinspection PyBroadException
-        try:
-            # hack: make sure torch.__init__ is called
-            import torch
-            torch.save = _patched_call(torch.save, PatchPyTorchModelIO._save)
-            torch.load = _patched_call(torch.load, PatchPyTorchModelIO._load)
-        except ImportError:
-            pass
-        except Exception:
-            pass  # print('Failed patching pytorch')
-
-    @staticmethod
-    def _save(original_fn, obj, f, *args, **kwargs):
-        ret = original_fn(obj, f, *args, **kwargs)
-        if not PatchPyTorchModelIO.__main_task:
-            return ret
-
-        if isinstance(f, six.string_types):
-            filename = f
-        elif hasattr(f, 'name'):
-            filename = f.name
-            # noinspection PyBroadException
-            try:
-                f.flush()
-            except Exception:
-                pass
-        else:
-            filename = None
-
-        # give the model a descriptive name based on the file name
-        # noinspection PyBroadException
-        try:
-            model_name = Path(filename).stem
-        except Exception:
-            model_name = None
-        WeightsFileHandler.create_output_model(obj, filename, Framework.pytorch, PatchPyTorchModelIO.__main_task,
-                                               singlefile=True, model_name=model_name)
-        return ret
-
-    @staticmethod
-    def _load(original_fn, f, *args, **kwargs):
-        if isinstance(f, six.string_types):
-            filename = f
-        elif hasattr(f, 'name'):
-            filename = f.name
-        else:
-            filename = None
-
-        if not PatchPyTorchModelIO.__main_task:
-            return original_fn(f, *args, **kwargs)
-
-        # register input model
-        empty = _Empty()
-        if running_remotely():
-            filename = WeightsFileHandler.restore_weights_file(empty, filename, Framework.pytorch,
-                                                               PatchPyTorchModelIO.__main_task)
-            model = original_fn(filename or f, *args, **kwargs)
-        else:
-            # try to load model before registering, in case we fail
-            model = original_fn(filename or f, *args, **kwargs)
-            WeightsFileHandler.restore_weights_file(empty, filename, Framework.pytorch,
-                                                    PatchPyTorchModelIO.__main_task)
-
-        if empty.trains_in_model:
-            # noinspection PyBroadException
-            try:
-                model.trains_in_model = empty.trains_in_model
-            except Exception:
-                pass
-        return model