Fix TensorBoard handling of multiple Task.init()/close() calls in the same process (issue #312)

This commit is contained in:
allegroai 2021-02-21 14:58:40 +02:00
parent c6d22c2d0a
commit ee9dfb8b35

View File

@ -28,6 +28,9 @@ except ImportError:
class TensorflowBinding(object): class TensorflowBinding(object):
@classmethod @classmethod
def update_current_task(cls, task): def update_current_task(cls, task):
if not task:
IsTensorboardInit.clear_tensorboard_used()
EventTrainsWriter.update_current_task(task)
PatchSummaryToEventTransformer.update_current_task(task) PatchSummaryToEventTransformer.update_current_task(task)
PatchTensorFlowEager.update_current_task(task) PatchTensorFlowEager.update_current_task(task)
PatchKerasModelIO.update_current_task(task) PatchKerasModelIO.update_current_task(task)
@ -46,6 +49,10 @@ class IsTensorboardInit(object):
def set_tensorboard_used(cls): def set_tensorboard_used(cls):
cls._tensorboard_initialized = True cls._tensorboard_initialized = True
@classmethod
def clear_tensorboard_used(cls):
    """Reset the class-level flag recording that TensorBoard was initialized.

    Counterpart of ``set_tensorboard_used``: stores ``False`` in
    ``cls._tensorboard_initialized``.
    """
    cls._tensorboard_initialized = False
@staticmethod @staticmethod
def _patched_tb__init__(original_init, self, *args, **kwargs): def _patched_tb__init__(original_init, self, *args, **kwargs):
IsTensorboardInit._tensorboard_initialized = True IsTensorboardInit._tensorboard_initialized = True
@ -194,6 +201,7 @@ class EventTrainsWriter(object):
TF SummaryWriter implementation that converts the tensorboard's summary into TF SummaryWriter implementation that converts the tensorboard's summary into
ClearML events and reports the events (metrics) for a ClearML task (logger). ClearML events and reports the events (metrics) for a ClearML task (logger).
""" """
__main_task = None
_add_lock = threading.RLock() _add_lock = threading.RLock()
_series_name_lookup = {} _series_name_lookup = {}
@ -750,6 +758,16 @@ class EventTrainsWriter(object):
origin_tag = "" origin_tag = ""
return origin_tag return origin_tag
@classmethod
def update_current_task(cls, task):
"""Record ``task`` as the main task, resetting the class-level lookup tables when it differs from the stored one."""
# Only reset when the incoming task is not equal to the one currently stored.
if cls.__main_task != task:
# The lookups are rebound while holding the class-level _add_lock.
with cls._add_lock:
# Each lookup is rebound to a fresh empty dict (not cleared in place);
# presumably so state cached for a previous task is not reused -- TODO confirm.
cls._series_name_lookup = {}
cls._title_series_writers_lookup = {}
cls._event_writers_id_to_logdir = {}
cls._title_series_wraparound_counter = {}
# NOTE(review): indentation is lost in this view, so it is unclear whether the
# assignment below sits inside the `if`/`with` or at function level -- confirm
# against the real file.
cls.__main_task = task
# noinspection PyCallingNonCallable # noinspection PyCallingNonCallable
class ProxyEventsWriter(object): class ProxyEventsWriter(object):
@ -1118,6 +1136,9 @@ class PatchTensorFlowEager(object):
@staticmethod @staticmethod
def update_current_task(task, **kwargs): def update_current_task(task, **kwargs):
if task != PatchTensorFlowEager.__main_task:
PatchTensorFlowEager.__trains_event_writer = {}
PatchTensorFlowEager.defaults_dict.update(kwargs) PatchTensorFlowEager.defaults_dict.update(kwargs)
PatchTensorFlowEager.__main_task = task PatchTensorFlowEager.__main_task = task
# make sure we patched the SummaryToEventTransformer # make sure we patched the SummaryToEventTransformer