matplotlib with no logger should not use the last iteration reported by the resource monitor

This commit is contained in:
allegroai 2020-04-16 16:48:19 +03:00
parent c06f72ae3a
commit f7b80a0da2
2 changed files with 42 additions and 15 deletions

View File

@ -2,16 +2,17 @@
import os import os
import sys import sys
import threading
from copy import deepcopy from copy import deepcopy
from tempfile import mkstemp from tempfile import mkstemp
import six import six
from six import BytesIO from six import BytesIO
import threading
from ..debugging.log import LoggerRoot
from ..config import running_remotely
from .import_bind import PostImportHookPatching from .import_bind import PostImportHookPatching
from ..config import running_remotely
from ..debugging.log import LoggerRoot
from ..utilities.resource_monitor import ResourceMonitor
class PatchedMatplotlib: class PatchedMatplotlib:
@ -32,6 +33,8 @@ class PatchedMatplotlib:
_lock_renderer = threading.RLock() _lock_renderer = threading.RLock()
_recursion_guard = {} _recursion_guard = {}
_matplot_major_version = 2 _matplot_major_version = 2
_logger_started_reporting = False
_matplotlib_reported_titles = set()
class _PatchWarnings(object): class _PatchWarnings(object):
def __init__(self): def __init__(self):
@ -357,7 +360,7 @@ class PatchedMatplotlib:
# remove borders and size, we should let the web take care of that # remove borders and size, we should let the web take care of that
if plotly_fig: if plotly_fig:
last_iteration = PatchedMatplotlib._current_task.get_last_iteration() last_iteration = PatchedMatplotlib._get_last_iteration()
if plot_title: if plot_title:
title = PatchedMatplotlib._enforce_unique_title_per_iteration(plot_title, last_iteration) title = PatchedMatplotlib._enforce_unique_title_per_iteration(plot_title, last_iteration)
else: else:
@ -373,6 +376,8 @@ class PatchedMatplotlib:
if not plotly_dict.get('layout'): if not plotly_dict.get('layout'):
plotly_dict['layout'] = {} plotly_dict['layout'] = {}
plotly_dict['layout']['title'] = title plotly_dict['layout']['title'] = title
PatchedMatplotlib._matplotlib_reported_titles.add(title)
reporter.report_plot(title=title, series='plot', plot=plotly_dict, iter=last_iteration) reporter.report_plot(title=title, series='plot', plot=plotly_dict, iter=last_iteration)
else: else:
logger = PatchedMatplotlib._current_task.get_logger() logger = PatchedMatplotlib._current_task.get_logger()
@ -380,7 +385,7 @@ class PatchedMatplotlib:
# this is actually a failed plot, we should put it under plots: # this is actually a failed plot, we should put it under plots:
# currently disabled # currently disabled
if force_save_as_image or not PatchedMatplotlib._support_image_plot: if force_save_as_image or not PatchedMatplotlib._support_image_plot:
last_iteration = PatchedMatplotlib._current_task.get_last_iteration() last_iteration = PatchedMatplotlib._get_last_iteration()
# send the plot as image # send the plot as image
if plot_title: if plot_title:
title = PatchedMatplotlib._enforce_unique_title_per_iteration(plot_title, last_iteration) title = PatchedMatplotlib._enforce_unique_title_per_iteration(plot_title, last_iteration)
@ -389,11 +394,12 @@ class PatchedMatplotlib:
title = 'untitled %d' % (PatchedMatplotlib._global_image_counter % title = 'untitled %d' % (PatchedMatplotlib._global_image_counter %
PatchedMatplotlib._global_image_counter_limit) PatchedMatplotlib._global_image_counter_limit)
PatchedMatplotlib._matplotlib_reported_titles.add(title)
logger.report_image(title=title, series='plot image', local_path=image, logger.report_image(title=title, series='plot image', local_path=image,
delete_after_upload=True, iteration=last_iteration) delete_after_upload=True, iteration=last_iteration)
else: else:
# send the plot as plotly with embedded image # send the plot as plotly with embedded image
last_iteration = PatchedMatplotlib._current_task.get_last_iteration() last_iteration = PatchedMatplotlib._get_last_iteration()
if plot_title: if plot_title:
title = PatchedMatplotlib._enforce_unique_title_per_iteration(plot_title, last_iteration) title = PatchedMatplotlib._enforce_unique_title_per_iteration(plot_title, last_iteration)
else: else:
@ -401,6 +407,7 @@ class PatchedMatplotlib:
title = 'untitled %d' % (PatchedMatplotlib._global_plot_counter % title = 'untitled %d' % (PatchedMatplotlib._global_plot_counter %
PatchedMatplotlib._global_image_counter_limit) PatchedMatplotlib._global_image_counter_limit)
PatchedMatplotlib._matplotlib_reported_titles.add(title)
logger._report_image_plot_and_upload(title=title, series='plot image', path=image, logger._report_image_plot_and_upload(title=title, series='plot image', path=image,
delete_after_upload=True, iteration=last_iteration) delete_after_upload=True, iteration=last_iteration)
except Exception: except Exception:
@ -450,6 +457,21 @@ class PatchedMatplotlib:
PatchedMatplotlib.__patched_draw_all_recursion_guard = False PatchedMatplotlib.__patched_draw_all_recursion_guard = False
return ret return ret
@staticmethod
def _get_last_iteration():
if PatchedMatplotlib._logger_started_reporting:
return PatchedMatplotlib._current_task.get_last_iteration()
# get the reported plot titles (exclude us)
reported_titles = ResourceMonitor.get_logger_reported_titles(PatchedMatplotlib._current_task)
if not reported_titles:
return 0
# check that this is not only us
if not (set(reported_titles) - PatchedMatplotlib._matplotlib_reported_titles):
return 0
# mark reporting started
PatchedMatplotlib._logger_started_reporting = True
return PatchedMatplotlib._current_task.get_last_iteration()
@staticmethod @staticmethod
def ipython_post_execute_hook(): def ipython_post_execute_hook():
# noinspection PyBroadException # noinspection PyBroadException

View File

@ -242,13 +242,18 @@ class ResourceMonitor(object):
return stats return stats
def _check_logger_reported(self): def _check_logger_reported(self):
titles = list(self._task.get_logger()._get_used_title_series().keys()) titles = self.get_logger_reported_titles(self._task)
try:
titles.remove(self._title_machine)
except ValueError:
pass
try:
titles.remove(self._title_gpu)
except ValueError:
pass
return len(titles) > 0 return len(titles) > 0
@classmethod
def get_logger_reported_titles(cls, task):
titles = list(task.get_logger()._get_used_title_series().keys())
try:
titles.remove(cls._title_machine)
except ValueError:
pass
try:
titles.remove(cls._title_gpu)
except ValueError:
pass
return titles