Fix matplotlib support

This commit is contained in:
allegroai 2019-08-07 00:05:24 +03:00
parent fe938ca21f
commit 3ab1c261e7
4 changed files with 96 additions and 25 deletions

View File

@ -6,7 +6,9 @@
# 2 seconds per epoch on a K520 GPU. # 2 seconds per epoch on a K520 GPU.
from __future__ import print_function from __future__ import print_function
import io import argparse
import tempfile
import os
import numpy as np import numpy as np
from keras.callbacks import TensorBoard, ModelCheckpoint from keras.callbacks import TensorBoard, ModelCheckpoint
@ -15,8 +17,6 @@ from keras.models import Sequential
from keras.layers.core import Dense, Activation from keras.layers.core import Dense, Activation
from keras.optimizers import RMSprop from keras.optimizers import RMSprop
from keras.utils import np_utils from keras.utils import np_utils
# TODO: test these methods binding
from keras.models import load_model, save_model, model_from_json
import tensorflow as tf import tensorflow as tf
from trains import Task from trains import Task
@ -25,6 +25,7 @@ class TensorBoardImage(TensorBoard):
@staticmethod @staticmethod
def make_image(tensor): def make_image(tensor):
from PIL import Image from PIL import Image
import io
tensor = np.stack((tensor, tensor, tensor), axis=2) tensor = np.stack((tensor, tensor, tensor), axis=2)
height, width, channels = tensor.shape height, width, channels = tensor.shape
image = Image.fromarray(tensor) image = Image.fromarray(tensor)
@ -49,19 +50,17 @@ class TensorBoardImage(TensorBoard):
self.writer.add_summary(summary, epoch) self.writer.add_summary(summary, epoch)
batch_size = 128 parser = argparse.ArgumentParser(description='Keras MNIST Example')
nb_classes = 10 parser.add_argument('--batch-size', type=int, default=128, help='input batch size for training (default: 64)')
nb_epoch = 6 parser.add_argument('--epochs', type=int, default=6, help='number of epochs to train (default: 10)')
args = parser.parse_args()
# the data, shuffled and split between train and test sets # the data, shuffled and split between train and test sets
nb_classes = 10
(X_train, y_train), (X_test, y_test) = mnist.load_data() (X_train, y_train), (X_test, y_test) = mnist.load_data()
X_train = X_train.reshape(60000, 784) X_train = X_train.reshape(60000, 784).astype('float32')/255.
X_test = X_test.reshape(10000, 784) X_test = X_test.reshape(10000, 784).astype('float32')/255.
X_train = X_train.astype('float32')
X_test = X_test.astype('float32')
X_train /= 255.
X_test /= 255.
print(X_train.shape[0], 'train samples') print(X_train.shape[0], 'train samples')
print(X_test.shape[0], 'test samples') print(X_test.shape[0], 'test samples')
@ -91,21 +90,23 @@ model.compile(loss='categorical_crossentropy',
# Connecting TRAINS # Connecting TRAINS
task = Task.init(project_name='examples', task_name='Keras with TensorBoard example') task = Task.init(project_name='examples', task_name='Keras with TensorBoard example')
# setting model outputs # Advanced: setting model class enumeration
labels = dict(('digit_%d' % i, i) for i in range(10)) labels = dict(('digit_%d' % i, i) for i in range(10))
task.set_model_label_enumeration(labels) task.set_model_label_enumeration(labels)
board = TensorBoard(histogram_freq=1, log_dir='/tmp/histogram_example', write_images=False) output_folder = os.path.join(tempfile.gettempdir(), 'keras_example')
model_store = ModelCheckpoint(filepath='/tmp/histogram_example/weight.{epoch}.hdf5')
board = TensorBoard(histogram_freq=1, log_dir=output_folder, write_images=False)
model_store = ModelCheckpoint(filepath=os.path.join(output_folder, 'weight.{epoch}.hdf5'))
# load previous model, if it is there # load previous model, if it is there
try: try:
model.load_weights('/tmp/histogram_example/weight.1.hdf5') model.load_weights(os.path.join(output_folder, 'weight.1.hdf5'))
except: except:
pass pass
history = model.fit(X_train, Y_train, history = model.fit(X_train, Y_train,
batch_size=batch_size, epochs=nb_epoch, batch_size=args.batch_size, epochs=args.epochs,
callbacks=[board, model_store], callbacks=[board, model_store],
verbose=1, validation_data=(X_test, Y_test)) verbose=1, validation_data=(X_test, Y_test))
score = model.evaluate(X_test, Y_test, verbose=0) score = model.evaluate(X_test, Y_test, verbose=0)

View File

@ -7,7 +7,7 @@ from trains import Task
task = Task.init(project_name='examples', task_name='Manual reporting') task = Task.init(project_name='examples', task_name='Manual reporting')
# example python logger # standard python logging
logging.getLogger().setLevel('DEBUG') logging.getLogger().setLevel('DEBUG')
logging.debug('This is a debug message') logging.debug('This is a debug message')
logging.info('This is an info message') logging.info('This is an info message')
@ -23,7 +23,7 @@ except ImportError:
pass pass
# get TRAINS logger object for any metrics / reports # get TRAINS logger object for any metrics / reports
logger = task.get_logger() logger = Task.current_task().get_logger()
# log text # log text
logger.console("hello") logger.console("hello")
@ -34,7 +34,7 @@ logger.report_scalar("example_scalar", "series A", iteration=1, value=200)
# report histogram # report histogram
histogram = np.random.randint(10, size=10) histogram = np.random.randint(10, size=10)
logger.report_vector("example_histogram", "random histogram", iteration=1, values=histogram) logger.report_histogram("example_histogram", "random histogram", iteration=1, values=histogram)
# report confusion matrix # report confusion matrix
confusion = np.random.randint(10, size=(10, 10)) confusion = np.random.randint(10, size=(10, 10))

View File

@ -2,6 +2,7 @@ import collections
import json import json
import six import six
import numpy as np
from threading import Thread, Event from threading import Thread, Event
from ..base import InterfaceBase from ..base import InterfaceBase
@ -157,8 +158,15 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
:param iter: Iteration number :param iter: Iteration number
:type value: int :type value: int
""" """
try:
def default(o):
if isinstance(o, np.int64):
return int(o)
except Exception:
default = None
if isinstance(plot, dict): if isinstance(plot, dict):
plot = json.dumps(plot) plot = json.dumps(plot, default=default)
elif not isinstance(plot, six.string_types): elif not isinstance(plot, six.string_types):
raise ValueError('Plot should be a string or a dict') raise ValueError('Plot should be a string or a dict')
ev = PlotEvent(metric=self._normalize_name(title), ev = PlotEvent(metric=self._normalize_name(title),

View File

@ -1,3 +1,5 @@
# -*- coding: utf-8 -*-
import os import os
import sys import sys
from tempfile import mkstemp from tempfile import mkstemp
@ -41,7 +43,8 @@ class PatchedMatplotlib:
try: try:
# we support matplotlib version 2.0.0 and above # we support matplotlib version 2.0.0 and above
import matplotlib import matplotlib
if int(matplotlib.__version__.split('.')[0]) < 2: matplot_major_version = int(matplotlib.__version__.split('.')[0])
if matplot_major_version < 2:
LoggerRoot.get_base_logger().warning( LoggerRoot.get_base_logger().warning(
'matplotlib binding supports version 2.0 and above, found version {}'.format( 'matplotlib binding supports version 2.0 and above, found version {}'.format(
matplotlib.__version__)) matplotlib.__version__))
@ -63,6 +66,7 @@ class PatchedMatplotlib:
PatchedMatplotlib._patched_original_plot = plt.show PatchedMatplotlib._patched_original_plot = plt.show
PatchedMatplotlib._patched_original_imshow = plt.imshow PatchedMatplotlib._patched_original_imshow = plt.imshow
PatchedMatplotlib._patched_original_figure = figure.Figure.show PatchedMatplotlib._patched_original_figure = figure.Figure.show
plt.show = PatchedMatplotlib.patched_show plt.show = PatchedMatplotlib.patched_show
figure.Figure.show = PatchedMatplotlib.patched_figure_show figure.Figure.show = PatchedMatplotlib.patched_figure_show
sys.modules['matplotlib'].pyplot.imshow = PatchedMatplotlib.patched_imshow sys.modules['matplotlib'].pyplot.imshow = PatchedMatplotlib.patched_imshow
@ -111,13 +115,26 @@ class PatchedMatplotlib:
@staticmethod @staticmethod
def patched_figure_show(self, *args, **kw): def patched_figure_show(self, *args, **kw):
if hasattr(self, '_trains_show'):
# flag will be cleared when calling clf() (object will be replaced)
return PatchedMatplotlib._patched_original_figure(self, *args, **kw)
try:
self._trains_show = True
except Exception:
pass
PatchedMatplotlib._report_figure(set_active=False, specific_fig=self) PatchedMatplotlib._report_figure(set_active=False, specific_fig=self)
ret = PatchedMatplotlib._patched_original_figure(self, *args, **kw) ret = PatchedMatplotlib._patched_original_figure(self, *args, **kw)
return ret return ret
@staticmethod @staticmethod
def patched_show(*args, **kw): def patched_show(*args, **kw):
PatchedMatplotlib._report_figure() # noinspection PyBroadException
try:
figures = PatchedMatplotlib._get_output_figures(None, all_figures=True)
for figure in figures:
PatchedMatplotlib._report_figure(stored_figure=figure)
except Exception:
pass
ret = PatchedMatplotlib._patched_original_plot(*args, **kw) ret = PatchedMatplotlib._patched_original_plot(*args, **kw)
if PatchedMatplotlib._current_task and running_remotely(): if PatchedMatplotlib._current_task and running_remotely():
# clear the current plot, because no one else will # clear the current plot, because no one else will
@ -155,10 +172,17 @@ class PatchedMatplotlib:
else: else:
mpl_fig = specific_fig mpl_fig = specific_fig
# mark as processed, so nested calls to figure.show will do nothing
try:
mpl_fig._trains_show = True
except Exception:
pass
# convert to plotly # convert to plotly
image = None image = None
plotly_fig = None plotly_fig = None
image_format = 'jpeg' image_format = 'jpeg'
fig_dpi = 300
if not force_save_as_image: if not force_save_as_image:
image_format = 'svg' image_format = 'svg'
# noinspection PyBroadException # noinspection PyBroadException
@ -168,12 +192,39 @@ class PatchedMatplotlib:
if matplotlylib: if matplotlylib:
renderer = matplotlylib.PlotlyRenderer() renderer = matplotlylib.PlotlyRenderer()
matplotlylib.Exporter(renderer, close_mpl=False).run(fig) matplotlylib.Exporter(renderer, close_mpl=False).run(fig)
x_ticks = list(renderer.current_mpl_ax.get_xticklabels())
if x_ticks:
try:
# check if all values can be cast to float
values = [float(t.get_text().replace('', '-')) for t in x_ticks]
except:
try:
renderer.plotly_fig['layout']['xaxis1'].update({
'ticktext': [t.get_text() for t in x_ticks],
'tickvals': [t.get_position()[0] for t in x_ticks],
})
except:
pass
y_ticks = list(renderer.current_mpl_ax.get_yticklabels())
if y_ticks:
try:
# check if all values can be cast to float
values = [float(t.get_text().replace('', '-')) for t in y_ticks]
except:
try:
renderer.plotly_fig['layout']['yaxis1'].update({
'ticktext': [t.get_text() for t in y_ticks],
'tickvals': [t.get_position()[1] for t in y_ticks],
})
except:
pass
return renderer.plotly_fig return renderer.plotly_fig
plotly_fig = our_mpl_to_plotly(mpl_fig) plotly_fig = our_mpl_to_plotly(mpl_fig)
except Exception as ex: except Exception as ex:
# this was an image, change format to png # this was an image, change format to png
image_format = 'jpeg' if 'selfie' in str(ex) else 'png' image_format = 'jpeg' if 'selfie' in str(ex) else 'png'
fig_dpi = 300
# plotly could not serialize the plot, we should convert to image # plotly could not serialize the plot, we should convert to image
if not plotly_fig: if not plotly_fig:
@ -183,13 +234,13 @@ class PatchedMatplotlib:
# first try SVG if we fail then fallback to png # first try SVG if we fail then fallback to png
buffer_ = BytesIO() buffer_ = BytesIO()
a_plt = specific_fig if specific_fig is not None else plt a_plt = specific_fig if specific_fig is not None else plt
a_plt.savefig(buffer_, format=image_format, bbox_inches='tight', pad_inches=0, frameon=False) a_plt.savefig(buffer_, dpi=fig_dpi, format=image_format, bbox_inches='tight', pad_inches=0, frameon=False)
buffer_.seek(0) buffer_.seek(0)
except Exception: except Exception:
image_format = 'png' image_format = 'png'
buffer_ = BytesIO() buffer_ = BytesIO()
a_plt = specific_fig if specific_fig is not None else plt a_plt = specific_fig if specific_fig is not None else plt
a_plt.savefig(buffer_, format=image_format, bbox_inches='tight', pad_inches=0, frameon=False) a_plt.savefig(buffer_, dpi=fig_dpi, format=image_format, bbox_inches='tight', pad_inches=0, frameon=False)
buffer_.seek(0) buffer_.seek(0)
fd, image = mkstemp(suffix='.'+image_format) fd, image = mkstemp(suffix='.'+image_format)
os.write(fd, buffer_.read()) os.write(fd, buffer_.read())
@ -253,6 +304,17 @@ class PatchedMatplotlib:
return return
@staticmethod
def _get_output_figures(stored_figure, all_figures):
try:
from matplotlib import _pylab_helpers
if all_figures:
return list(_pylab_helpers.Gcf.figs.values())
else:
return [stored_figure] or [_pylab_helpers.Gcf.get_active()]
except Exception:
return []
@staticmethod @staticmethod
def __patched_draw_all(*args, **kwargs): def __patched_draw_all(*args, **kwargs):
recursion_guard = PatchedMatplotlib.__patched_draw_all_recursion_guard recursion_guard = PatchedMatplotlib.__patched_draw_all_recursion_guard