matplotlib with no logger should not use the last iteration reported by the resource monitor

This commit is contained in:
allegroai 2020-04-16 16:48:19 +03:00
parent c06f72ae3a
commit f7b80a0da2
2 changed files with 42 additions and 15 deletions

View File

@ -2,16 +2,17 @@
import os
import sys
import threading
from copy import deepcopy
from tempfile import mkstemp
import six
from six import BytesIO
import threading
from ..debugging.log import LoggerRoot
from ..config import running_remotely
from .import_bind import PostImportHookPatching
from ..config import running_remotely
from ..debugging.log import LoggerRoot
from ..utilities.resource_monitor import ResourceMonitor
class PatchedMatplotlib:
@ -32,6 +33,8 @@ class PatchedMatplotlib:
_lock_renderer = threading.RLock()
_recursion_guard = {}
_matplot_major_version = 2
_logger_started_reporting = False
_matplotlib_reported_titles = set()
class _PatchWarnings(object):
def __init__(self):
@ -357,7 +360,7 @@ class PatchedMatplotlib:
# remove borders and size, we should let the web take care of that
if plotly_fig:
last_iteration = PatchedMatplotlib._current_task.get_last_iteration()
last_iteration = PatchedMatplotlib._get_last_iteration()
if plot_title:
title = PatchedMatplotlib._enforce_unique_title_per_iteration(plot_title, last_iteration)
else:
@ -373,6 +376,8 @@ class PatchedMatplotlib:
if not plotly_dict.get('layout'):
plotly_dict['layout'] = {}
plotly_dict['layout']['title'] = title
PatchedMatplotlib._matplotlib_reported_titles.add(title)
reporter.report_plot(title=title, series='plot', plot=plotly_dict, iter=last_iteration)
else:
logger = PatchedMatplotlib._current_task.get_logger()
@ -380,7 +385,7 @@ class PatchedMatplotlib:
# this is actually a failed plot, we should put it under plots:
# currently disabled
if force_save_as_image or not PatchedMatplotlib._support_image_plot:
last_iteration = PatchedMatplotlib._current_task.get_last_iteration()
last_iteration = PatchedMatplotlib._get_last_iteration()
# send the plot as image
if plot_title:
title = PatchedMatplotlib._enforce_unique_title_per_iteration(plot_title, last_iteration)
@ -389,11 +394,12 @@ class PatchedMatplotlib:
title = 'untitled %d' % (PatchedMatplotlib._global_image_counter %
PatchedMatplotlib._global_image_counter_limit)
PatchedMatplotlib._matplotlib_reported_titles.add(title)
logger.report_image(title=title, series='plot image', local_path=image,
delete_after_upload=True, iteration=last_iteration)
else:
# send the plot as plotly with embedded image
last_iteration = PatchedMatplotlib._current_task.get_last_iteration()
last_iteration = PatchedMatplotlib._get_last_iteration()
if plot_title:
title = PatchedMatplotlib._enforce_unique_title_per_iteration(plot_title, last_iteration)
else:
@ -401,6 +407,7 @@ class PatchedMatplotlib:
title = 'untitled %d' % (PatchedMatplotlib._global_plot_counter %
PatchedMatplotlib._global_image_counter_limit)
PatchedMatplotlib._matplotlib_reported_titles.add(title)
logger._report_image_plot_and_upload(title=title, series='plot image', path=image,
delete_after_upload=True, iteration=last_iteration)
except Exception:
@ -450,6 +457,21 @@ class PatchedMatplotlib:
PatchedMatplotlib.__patched_draw_all_recursion_guard = False
return ret
@staticmethod
def _get_last_iteration():
if PatchedMatplotlib._logger_started_reporting:
return PatchedMatplotlib._current_task.get_last_iteration()
# get the reported plot titles (exclude us)
reported_titles = ResourceMonitor.get_logger_reported_titles(PatchedMatplotlib._current_task)
if not reported_titles:
return 0
# check that this is not only us
if not (set(reported_titles) - PatchedMatplotlib._matplotlib_reported_titles):
return 0
# mark reporting started
PatchedMatplotlib._logger_started_reporting = True
return PatchedMatplotlib._current_task.get_last_iteration()
@staticmethod
def ipython_post_execute_hook():
# noinspection PyBroadException

View File

@ -242,13 +242,18 @@ class ResourceMonitor(object):
return stats
def _check_logger_reported(self):
titles = list(self._task.get_logger()._get_used_title_series().keys())
try:
titles.remove(self._title_machine)
except ValueError:
pass
try:
titles.remove(self._title_gpu)
except ValueError:
pass
titles = self.get_logger_reported_titles(self._task)
return len(titles) > 0
@classmethod
def get_logger_reported_titles(cls, task):
titles = list(task.get_logger()._get_used_title_series().keys())
try:
titles.remove(cls._title_machine)
except ValueError:
pass
try:
titles.remove(cls._title_gpu)
except ValueError:
pass
return titles