diff --git a/trains/binding/matplotlib_bind.py b/trains/binding/matplotlib_bind.py index f4e083c6..b57cf0ef 100644 --- a/trains/binding/matplotlib_bind.py +++ b/trains/binding/matplotlib_bind.py @@ -11,6 +11,7 @@ import threading from ..debugging.log import LoggerRoot from ..config import running_remotely +from .import_bind import PostImportHookPatching class PatchedMatplotlib: @@ -125,11 +126,20 @@ class PatchedMatplotlib: @staticmethod def update_current_task(task): - if PatchedMatplotlib.patch_matplotlib(): + # make sure we have a default value + if PatchedMatplotlib._global_image_counter_limit is None: + from ..config import config + PatchedMatplotlib._global_image_counter_limit = config.get('metric.matplotlib_untitled_history_size', 100) + + # if we already patched it, just update the current task + if PatchedMatplotlib._patched_original_plot is not None: + PatchedMatplotlib._current_task = task + # if matplotlib is not loaded yet, get a callback hook + elif not running_remotely() and 'matplotlib.pyplot' not in sys.modules: + PatchedMatplotlib._current_task = task + PostImportHookPatching.add_on_import('matplotlib.pyplot', PatchedMatplotlib.patch_matplotlib) + elif PatchedMatplotlib.patch_matplotlib(): PatchedMatplotlib._current_task = task - if PatchedMatplotlib._global_image_counter_limit is None: - from ..config import config - PatchedMatplotlib._global_image_counter_limit = config.get('metric.matplotlib_untitled_history_size', 100) @staticmethod def patched_imshow(*args, **kw): @@ -172,7 +182,7 @@ class PatchedMatplotlib: except Exception: pass ret = PatchedMatplotlib._patched_original_plot(*args, **kw) - if PatchedMatplotlib._current_task and running_remotely(): + if PatchedMatplotlib._current_task and sys.modules['matplotlib'].rcParams['backend'] == 'agg': # clear the current plot, because no one else will # noinspection PyBroadException try: