Fix matplotlib support

This commit is contained in:
allegroai 2019-08-07 00:05:24 +03:00
parent fe938ca21f
commit 3ab1c261e7
4 changed files with 96 additions and 25 deletions

View File

@ -6,7 +6,9 @@
# 2 seconds per epoch on a K520 GPU.
from __future__ import print_function
import io
import argparse
import tempfile
import os
import numpy as np
from keras.callbacks import TensorBoard, ModelCheckpoint
@ -15,8 +17,6 @@ from keras.models import Sequential
from keras.layers.core import Dense, Activation
from keras.optimizers import RMSprop
from keras.utils import np_utils
# TODO: test these methods binding
from keras.models import load_model, save_model, model_from_json
import tensorflow as tf
from trains import Task
@ -25,6 +25,7 @@ class TensorBoardImage(TensorBoard):
@staticmethod
def make_image(tensor):
from PIL import Image
import io
tensor = np.stack((tensor, tensor, tensor), axis=2)
height, width, channels = tensor.shape
image = Image.fromarray(tensor)
@ -49,19 +50,17 @@ class TensorBoardImage(TensorBoard):
self.writer.add_summary(summary, epoch)
batch_size = 128
nb_classes = 10
nb_epoch = 6
parser = argparse.ArgumentParser(description='Keras MNIST Example')
parser.add_argument('--batch-size', type=int, default=128, help='input batch size for training (default: 64)')
parser.add_argument('--epochs', type=int, default=6, help='number of epochs to train (default: 10)')
args = parser.parse_args()
# the data, shuffled and split between train and test sets
nb_classes = 10
(X_train, y_train), (X_test, y_test) = mnist.load_data()
X_train = X_train.reshape(60000, 784)
X_test = X_test.reshape(10000, 784)
X_train = X_train.astype('float32')
X_test = X_test.astype('float32')
X_train /= 255.
X_test /= 255.
X_train = X_train.reshape(60000, 784).astype('float32')/255.
X_test = X_test.reshape(10000, 784).astype('float32')/255.
print(X_train.shape[0], 'train samples')
print(X_test.shape[0], 'test samples')
@ -91,21 +90,23 @@ model.compile(loss='categorical_crossentropy',
# Connecting TRAINS
task = Task.init(project_name='examples', task_name='Keras with TensorBoard example')
# setting model outputs
# Advanced: setting model class enumeration
labels = dict(('digit_%d' % i, i) for i in range(10))
task.set_model_label_enumeration(labels)
board = TensorBoard(histogram_freq=1, log_dir='/tmp/histogram_example', write_images=False)
model_store = ModelCheckpoint(filepath='/tmp/histogram_example/weight.{epoch}.hdf5')
output_folder = os.path.join(tempfile.gettempdir(), 'keras_example')
board = TensorBoard(histogram_freq=1, log_dir=output_folder, write_images=False)
model_store = ModelCheckpoint(filepath=os.path.join(output_folder, 'weight.{epoch}.hdf5'))
# load previous model, if it is there
try:
model.load_weights('/tmp/histogram_example/weight.1.hdf5')
model.load_weights(os.path.join(output_folder, 'weight.1.hdf5'))
except:
pass
history = model.fit(X_train, Y_train,
batch_size=batch_size, epochs=nb_epoch,
batch_size=args.batch_size, epochs=args.epochs,
callbacks=[board, model_store],
verbose=1, validation_data=(X_test, Y_test))
score = model.evaluate(X_test, Y_test, verbose=0)

View File

@ -7,7 +7,7 @@ from trains import Task
task = Task.init(project_name='examples', task_name='Manual reporting')
# example python logger
# standard python logging
logging.getLogger().setLevel('DEBUG')
logging.debug('This is a debug message')
logging.info('This is an info message')
@ -23,7 +23,7 @@ except ImportError:
pass
# get TRAINS logger object for any metrics / reports
logger = task.get_logger()
logger = Task.current_task().get_logger()
# log text
logger.console("hello")
@ -34,7 +34,7 @@ logger.report_scalar("example_scalar", "series A", iteration=1, value=200)
# report histogram
histogram = np.random.randint(10, size=10)
logger.report_vector("example_histogram", "random histogram", iteration=1, values=histogram)
logger.report_histogram("example_histogram", "random histogram", iteration=1, values=histogram)
# report confusion matrix
confusion = np.random.randint(10, size=(10, 10))

View File

@ -2,6 +2,7 @@ import collections
import json
import six
import numpy as np
from threading import Thread, Event
from ..base import InterfaceBase
@ -157,8 +158,15 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
:param iter: Iteration number
:type value: int
"""
try:
def default(o):
if isinstance(o, np.int64):
return int(o)
except Exception:
default = None
if isinstance(plot, dict):
plot = json.dumps(plot)
plot = json.dumps(plot, default=default)
elif not isinstance(plot, six.string_types):
raise ValueError('Plot should be a string or a dict')
ev = PlotEvent(metric=self._normalize_name(title),

View File

@ -1,3 +1,5 @@
# -*- coding: utf-8 -*-
import os
import sys
from tempfile import mkstemp
@ -41,7 +43,8 @@ class PatchedMatplotlib:
try:
# we support matplotlib version 2.0.0 and above
import matplotlib
if int(matplotlib.__version__.split('.')[0]) < 2:
matplot_major_version = int(matplotlib.__version__.split('.')[0])
if matplot_major_version < 2:
LoggerRoot.get_base_logger().warning(
'matplotlib binding supports version 2.0 and above, found version {}'.format(
matplotlib.__version__))
@ -63,6 +66,7 @@ class PatchedMatplotlib:
PatchedMatplotlib._patched_original_plot = plt.show
PatchedMatplotlib._patched_original_imshow = plt.imshow
PatchedMatplotlib._patched_original_figure = figure.Figure.show
plt.show = PatchedMatplotlib.patched_show
figure.Figure.show = PatchedMatplotlib.patched_figure_show
sys.modules['matplotlib'].pyplot.imshow = PatchedMatplotlib.patched_imshow
@ -111,13 +115,26 @@ class PatchedMatplotlib:
@staticmethod
def patched_figure_show(self, *args, **kw):
if hasattr(self, '_trains_show'):
# flag will be cleared when calling clf() (object will be replaced)
return PatchedMatplotlib._patched_original_figure(self, *args, **kw)
try:
self._trains_show = True
except Exception:
pass
PatchedMatplotlib._report_figure(set_active=False, specific_fig=self)
ret = PatchedMatplotlib._patched_original_figure(self, *args, **kw)
return ret
@staticmethod
def patched_show(*args, **kw):
PatchedMatplotlib._report_figure()
# noinspection PyBroadException
try:
figures = PatchedMatplotlib._get_output_figures(None, all_figures=True)
for figure in figures:
PatchedMatplotlib._report_figure(stored_figure=figure)
except Exception:
pass
ret = PatchedMatplotlib._patched_original_plot(*args, **kw)
if PatchedMatplotlib._current_task and running_remotely():
# clear the current plot, because no one else will
@ -155,10 +172,17 @@ class PatchedMatplotlib:
else:
mpl_fig = specific_fig
# mark as processed, so nested calls to figure.show will do nothing
try:
mpl_fig._trains_show = True
except Exception:
pass
# convert to plotly
image = None
plotly_fig = None
image_format = 'jpeg'
fig_dpi = 300
if not force_save_as_image:
image_format = 'svg'
# noinspection PyBroadException
@ -168,12 +192,39 @@ class PatchedMatplotlib:
if matplotlylib:
renderer = matplotlylib.PlotlyRenderer()
matplotlylib.Exporter(renderer, close_mpl=False).run(fig)
x_ticks = list(renderer.current_mpl_ax.get_xticklabels())
if x_ticks:
try:
# check if all values can be cast to float
values = [float(t.get_text().replace('', '-')) for t in x_ticks]
except:
try:
renderer.plotly_fig['layout']['xaxis1'].update({
'ticktext': [t.get_text() for t in x_ticks],
'tickvals': [t.get_position()[0] for t in x_ticks],
})
except:
pass
y_ticks = list(renderer.current_mpl_ax.get_yticklabels())
if y_ticks:
try:
# check if all values can be cast to float
values = [float(t.get_text().replace('', '-')) for t in y_ticks]
except:
try:
renderer.plotly_fig['layout']['yaxis1'].update({
'ticktext': [t.get_text() for t in y_ticks],
'tickvals': [t.get_position()[1] for t in y_ticks],
})
except:
pass
return renderer.plotly_fig
plotly_fig = our_mpl_to_plotly(mpl_fig)
except Exception as ex:
# this was an image, change format to png
image_format = 'jpeg' if 'selfie' in str(ex) else 'png'
fig_dpi = 300
# plotly could not serialize the plot, we should convert to image
if not plotly_fig:
@ -183,13 +234,13 @@ class PatchedMatplotlib:
# first try SVG if we fail then fallback to png
buffer_ = BytesIO()
a_plt = specific_fig if specific_fig is not None else plt
a_plt.savefig(buffer_, format=image_format, bbox_inches='tight', pad_inches=0, frameon=False)
a_plt.savefig(buffer_, dpi=fig_dpi, format=image_format, bbox_inches='tight', pad_inches=0, frameon=False)
buffer_.seek(0)
except Exception:
image_format = 'png'
buffer_ = BytesIO()
a_plt = specific_fig if specific_fig is not None else plt
a_plt.savefig(buffer_, format=image_format, bbox_inches='tight', pad_inches=0, frameon=False)
a_plt.savefig(buffer_, dpi=fig_dpi, format=image_format, bbox_inches='tight', pad_inches=0, frameon=False)
buffer_.seek(0)
fd, image = mkstemp(suffix='.'+image_format)
os.write(fd, buffer_.read())
@ -253,6 +304,17 @@ class PatchedMatplotlib:
return
@staticmethod
def _get_output_figures(stored_figure, all_figures):
try:
from matplotlib import _pylab_helpers
if all_figures:
return list(_pylab_helpers.Gcf.figs.values())
else:
return [stored_figure] or [_pylab_helpers.Gcf.get_active()]
except Exception:
return []
@staticmethod
def __patched_draw_all(*args, **kwargs):
recursion_guard = PatchedMatplotlib.__patched_draw_all_recursion_guard