From ee9dfb8b358a0cce26e9d94210c4825e1495c9a5 Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Sun, 21 Feb 2021 14:58:40 +0200 Subject: [PATCH] Fix TensorBoard multiple Task.init()/close() calls in the same process (issue #312) --- clearml/binding/frameworks/tensorflow_bind.py | 21 +++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/clearml/binding/frameworks/tensorflow_bind.py b/clearml/binding/frameworks/tensorflow_bind.py index 037a9384..a00342b8 100644 --- a/clearml/binding/frameworks/tensorflow_bind.py +++ b/clearml/binding/frameworks/tensorflow_bind.py @@ -28,6 +28,9 @@ except ImportError: class TensorflowBinding(object): @classmethod def update_current_task(cls, task): + if not task: + IsTensorboardInit.clear_tensorboard_used() + EventTrainsWriter.update_current_task(task) PatchSummaryToEventTransformer.update_current_task(task) PatchTensorFlowEager.update_current_task(task) PatchKerasModelIO.update_current_task(task) @@ -46,6 +49,10 @@ class IsTensorboardInit(object): def set_tensorboard_used(cls): cls._tensorboard_initialized = True + @classmethod + def clear_tensorboard_used(cls): + cls._tensorboard_initialized = False + @staticmethod def _patched_tb__init__(original_init, self, *args, **kwargs): IsTensorboardInit._tensorboard_initialized = True @@ -194,6 +201,7 @@ class EventTrainsWriter(object): TF SummaryWriter implementation that converts the tensorboard's summary into ClearML events and reports the events (metrics) for an ClearML task (logger). """ + __main_task = None _add_lock = threading.RLock() _series_name_lookup = {} @@ -750,6 +758,16 @@ class EventTrainsWriter(object): origin_tag = "" return origin_tag + @classmethod + def update_current_task(cls, task): + if cls.__main_task != task: + with cls._add_lock: + cls._series_name_lookup = {} + cls._title_series_writers_lookup = {} + cls._event_writers_id_to_logdir = {} + cls._title_series_wraparound_counter = {} + cls.__main_task = task + # noinspection PyCallingNonCallable class ProxyEventsWriter(object): @@ -1118,6 +1136,9 @@ class PatchTensorFlowEager(object): @staticmethod def update_current_task(task, **kwargs): + if task != PatchTensorFlowEager.__main_task: + PatchTensorFlowEager.__trains_event_writer = {} + PatchTensorFlowEager.defaults_dict.update(kwargs) PatchTensorFlowEager.__main_task = task # make sure we patched the SummaryToEventTransformer