Fix TensorBoard multiple Task.init()/close() calls in the same process (issue #312)

This commit is contained in:
allegroai 2021-02-21 14:58:40 +02:00
parent c6d22c2d0a
commit ee9dfb8b35

View File

@ -28,6 +28,9 @@ except ImportError:
class TensorflowBinding(object):
@classmethod
def update_current_task(cls, task):
if not task:
IsTensorboardInit.clear_tensorboard_used()
EventTrainsWriter.update_current_task(task)
PatchSummaryToEventTransformer.update_current_task(task)
PatchTensorFlowEager.update_current_task(task)
PatchKerasModelIO.update_current_task(task)
@ -46,6 +49,10 @@ class IsTensorboardInit(object):
def set_tensorboard_used(cls):
cls._tensorboard_initialized = True
@classmethod
def clear_tensorboard_used(cls):
cls._tensorboard_initialized = False
@staticmethod
def _patched_tb__init__(original_init, self, *args, **kwargs):
IsTensorboardInit._tensorboard_initialized = True
@ -194,6 +201,7 @@ class EventTrainsWriter(object):
TF SummaryWriter implementation that converts the tensorboard's summary into
ClearML events and reports the events (metrics) for an ClearML task (logger).
"""
__main_task = None
_add_lock = threading.RLock()
_series_name_lookup = {}
@ -750,6 +758,16 @@ class EventTrainsWriter(object):
origin_tag = ""
return origin_tag
@classmethod
def update_current_task(cls, task):
if cls.__main_task != task:
with cls._add_lock:
cls._series_name_lookup = {}
cls._title_series_writers_lookup = {}
cls._event_writers_id_to_logdir = {}
cls._title_series_wraparound_counter = {}
cls.__main_task = task
# noinspection PyCallingNonCallable
class ProxyEventsWriter(object):
@ -1118,6 +1136,9 @@ class PatchTensorFlowEager(object):
@staticmethod
def update_current_task(task, **kwargs):
if task != PatchTensorFlowEager.__main_task:
PatchTensorFlowEager.__trains_event_writer = {}
PatchTensorFlowEager.defaults_dict.update(kwargs)
PatchTensorFlowEager.__main_task = task
# make sure we patched the SummaryToEventTransformer