mirror of
https://github.com/clearml/clearml
synced 2025-04-03 04:21:03 +00:00
Fix matplotlib support
This commit is contained in:
parent
fe938ca21f
commit
3ab1c261e7
@ -6,7 +6,9 @@
|
||||
# 2 seconds per epoch on a K520 GPU.
|
||||
from __future__ import print_function
|
||||
|
||||
import io
|
||||
import argparse
|
||||
import tempfile
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
from keras.callbacks import TensorBoard, ModelCheckpoint
|
||||
@ -15,8 +17,6 @@ from keras.models import Sequential
|
||||
from keras.layers.core import Dense, Activation
|
||||
from keras.optimizers import RMSprop
|
||||
from keras.utils import np_utils
|
||||
# TODO: test these methods binding
|
||||
from keras.models import load_model, save_model, model_from_json
|
||||
import tensorflow as tf
|
||||
from trains import Task
|
||||
|
||||
@ -25,6 +25,7 @@ class TensorBoardImage(TensorBoard):
|
||||
@staticmethod
|
||||
def make_image(tensor):
|
||||
from PIL import Image
|
||||
import io
|
||||
tensor = np.stack((tensor, tensor, tensor), axis=2)
|
||||
height, width, channels = tensor.shape
|
||||
image = Image.fromarray(tensor)
|
||||
@ -49,19 +50,17 @@ class TensorBoardImage(TensorBoard):
|
||||
self.writer.add_summary(summary, epoch)
|
||||
|
||||
|
||||
batch_size = 128
|
||||
nb_classes = 10
|
||||
nb_epoch = 6
|
||||
parser = argparse.ArgumentParser(description='Keras MNIST Example')
|
||||
parser.add_argument('--batch-size', type=int, default=128, help='input batch size for training (default: 64)')
|
||||
parser.add_argument('--epochs', type=int, default=6, help='number of epochs to train (default: 10)')
|
||||
args = parser.parse_args()
|
||||
|
||||
# the data, shuffled and split between train and test sets
|
||||
nb_classes = 10
|
||||
(X_train, y_train), (X_test, y_test) = mnist.load_data()
|
||||
|
||||
X_train = X_train.reshape(60000, 784)
|
||||
X_test = X_test.reshape(10000, 784)
|
||||
X_train = X_train.astype('float32')
|
||||
X_test = X_test.astype('float32')
|
||||
X_train /= 255.
|
||||
X_test /= 255.
|
||||
X_train = X_train.reshape(60000, 784).astype('float32')/255.
|
||||
X_test = X_test.reshape(10000, 784).astype('float32')/255.
|
||||
print(X_train.shape[0], 'train samples')
|
||||
print(X_test.shape[0], 'test samples')
|
||||
|
||||
@ -91,21 +90,23 @@ model.compile(loss='categorical_crossentropy',
|
||||
|
||||
# Connecting TRAINS
|
||||
task = Task.init(project_name='examples', task_name='Keras with TensorBoard example')
|
||||
# setting model outputs
|
||||
# Advanced: setting model class enumeration
|
||||
labels = dict(('digit_%d' % i, i) for i in range(10))
|
||||
task.set_model_label_enumeration(labels)
|
||||
|
||||
board = TensorBoard(histogram_freq=1, log_dir='/tmp/histogram_example', write_images=False)
|
||||
model_store = ModelCheckpoint(filepath='/tmp/histogram_example/weight.{epoch}.hdf5')
|
||||
output_folder = os.path.join(tempfile.gettempdir(), 'keras_example')
|
||||
|
||||
board = TensorBoard(histogram_freq=1, log_dir=output_folder, write_images=False)
|
||||
model_store = ModelCheckpoint(filepath=os.path.join(output_folder, 'weight.{epoch}.hdf5'))
|
||||
|
||||
# load previous model, if it is there
|
||||
try:
|
||||
model.load_weights('/tmp/histogram_example/weight.1.hdf5')
|
||||
model.load_weights(os.path.join(output_folder, 'weight.1.hdf5'))
|
||||
except:
|
||||
pass
|
||||
|
||||
history = model.fit(X_train, Y_train,
|
||||
batch_size=batch_size, epochs=nb_epoch,
|
||||
batch_size=args.batch_size, epochs=args.epochs,
|
||||
callbacks=[board, model_store],
|
||||
verbose=1, validation_data=(X_test, Y_test))
|
||||
score = model.evaluate(X_test, Y_test, verbose=0)
|
||||
|
@ -7,7 +7,7 @@ from trains import Task
|
||||
|
||||
task = Task.init(project_name='examples', task_name='Manual reporting')
|
||||
|
||||
# example python logger
|
||||
# standard python logging
|
||||
logging.getLogger().setLevel('DEBUG')
|
||||
logging.debug('This is a debug message')
|
||||
logging.info('This is an info message')
|
||||
@ -23,7 +23,7 @@ except ImportError:
|
||||
pass
|
||||
|
||||
# get TRAINS logger object for any metrics / reports
|
||||
logger = task.get_logger()
|
||||
logger = Task.current_task().get_logger()
|
||||
|
||||
# log text
|
||||
logger.console("hello")
|
||||
@ -34,7 +34,7 @@ logger.report_scalar("example_scalar", "series A", iteration=1, value=200)
|
||||
|
||||
# report histogram
|
||||
histogram = np.random.randint(10, size=10)
|
||||
logger.report_vector("example_histogram", "random histogram", iteration=1, values=histogram)
|
||||
logger.report_histogram("example_histogram", "random histogram", iteration=1, values=histogram)
|
||||
|
||||
# report confusion matrix
|
||||
confusion = np.random.randint(10, size=(10, 10))
|
||||
|
@ -2,6 +2,7 @@ import collections
|
||||
import json
|
||||
|
||||
import six
|
||||
import numpy as np
|
||||
from threading import Thread, Event
|
||||
|
||||
from ..base import InterfaceBase
|
||||
@ -157,8 +158,15 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
|
||||
:param iter: Iteration number
|
||||
:type value: int
|
||||
"""
|
||||
try:
|
||||
def default(o):
|
||||
if isinstance(o, np.int64):
|
||||
return int(o)
|
||||
except Exception:
|
||||
default = None
|
||||
|
||||
if isinstance(plot, dict):
|
||||
plot = json.dumps(plot)
|
||||
plot = json.dumps(plot, default=default)
|
||||
elif not isinstance(plot, six.string_types):
|
||||
raise ValueError('Plot should be a string or a dict')
|
||||
ev = PlotEvent(metric=self._normalize_name(title),
|
||||
|
@ -1,3 +1,5 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import os
|
||||
import sys
|
||||
from tempfile import mkstemp
|
||||
@ -41,7 +43,8 @@ class PatchedMatplotlib:
|
||||
try:
|
||||
# we support matplotlib version 2.0.0 and above
|
||||
import matplotlib
|
||||
if int(matplotlib.__version__.split('.')[0]) < 2:
|
||||
matplot_major_version = int(matplotlib.__version__.split('.')[0])
|
||||
if matplot_major_version < 2:
|
||||
LoggerRoot.get_base_logger().warning(
|
||||
'matplotlib binding supports version 2.0 and above, found version {}'.format(
|
||||
matplotlib.__version__))
|
||||
@ -63,6 +66,7 @@ class PatchedMatplotlib:
|
||||
PatchedMatplotlib._patched_original_plot = plt.show
|
||||
PatchedMatplotlib._patched_original_imshow = plt.imshow
|
||||
PatchedMatplotlib._patched_original_figure = figure.Figure.show
|
||||
|
||||
plt.show = PatchedMatplotlib.patched_show
|
||||
figure.Figure.show = PatchedMatplotlib.patched_figure_show
|
||||
sys.modules['matplotlib'].pyplot.imshow = PatchedMatplotlib.patched_imshow
|
||||
@ -111,13 +115,26 @@ class PatchedMatplotlib:
|
||||
|
||||
@staticmethod
|
||||
def patched_figure_show(self, *args, **kw):
|
||||
if hasattr(self, '_trains_show'):
|
||||
# flag will be cleared when calling clf() (object will be replaced)
|
||||
return PatchedMatplotlib._patched_original_figure(self, *args, **kw)
|
||||
try:
|
||||
self._trains_show = True
|
||||
except Exception:
|
||||
pass
|
||||
PatchedMatplotlib._report_figure(set_active=False, specific_fig=self)
|
||||
ret = PatchedMatplotlib._patched_original_figure(self, *args, **kw)
|
||||
return ret
|
||||
|
||||
@staticmethod
|
||||
def patched_show(*args, **kw):
|
||||
PatchedMatplotlib._report_figure()
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
figures = PatchedMatplotlib._get_output_figures(None, all_figures=True)
|
||||
for figure in figures:
|
||||
PatchedMatplotlib._report_figure(stored_figure=figure)
|
||||
except Exception:
|
||||
pass
|
||||
ret = PatchedMatplotlib._patched_original_plot(*args, **kw)
|
||||
if PatchedMatplotlib._current_task and running_remotely():
|
||||
# clear the current plot, because no one else will
|
||||
@ -155,10 +172,17 @@ class PatchedMatplotlib:
|
||||
else:
|
||||
mpl_fig = specific_fig
|
||||
|
||||
# mark as processed, so nested calls to figure.show will do nothing
|
||||
try:
|
||||
mpl_fig._trains_show = True
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# convert to plotly
|
||||
image = None
|
||||
plotly_fig = None
|
||||
image_format = 'jpeg'
|
||||
fig_dpi = 300
|
||||
if not force_save_as_image:
|
||||
image_format = 'svg'
|
||||
# noinspection PyBroadException
|
||||
@ -168,12 +192,39 @@ class PatchedMatplotlib:
|
||||
if matplotlylib:
|
||||
renderer = matplotlylib.PlotlyRenderer()
|
||||
matplotlylib.Exporter(renderer, close_mpl=False).run(fig)
|
||||
x_ticks = list(renderer.current_mpl_ax.get_xticklabels())
|
||||
if x_ticks:
|
||||
try:
|
||||
# check if all values can be cast to float
|
||||
values = [float(t.get_text().replace('−', '-')) for t in x_ticks]
|
||||
except:
|
||||
try:
|
||||
renderer.plotly_fig['layout']['xaxis1'].update({
|
||||
'ticktext': [t.get_text() for t in x_ticks],
|
||||
'tickvals': [t.get_position()[0] for t in x_ticks],
|
||||
})
|
||||
except:
|
||||
pass
|
||||
y_ticks = list(renderer.current_mpl_ax.get_yticklabels())
|
||||
if y_ticks:
|
||||
try:
|
||||
# check if all values can be cast to float
|
||||
values = [float(t.get_text().replace('−', '-')) for t in y_ticks]
|
||||
except:
|
||||
try:
|
||||
renderer.plotly_fig['layout']['yaxis1'].update({
|
||||
'ticktext': [t.get_text() for t in y_ticks],
|
||||
'tickvals': [t.get_position()[1] for t in y_ticks],
|
||||
})
|
||||
except:
|
||||
pass
|
||||
return renderer.plotly_fig
|
||||
|
||||
plotly_fig = our_mpl_to_plotly(mpl_fig)
|
||||
except Exception as ex:
|
||||
# this was an image, change format to png
|
||||
image_format = 'jpeg' if 'selfie' in str(ex) else 'png'
|
||||
fig_dpi = 300
|
||||
|
||||
# plotly could not serialize the plot, we should convert to image
|
||||
if not plotly_fig:
|
||||
@ -183,13 +234,13 @@ class PatchedMatplotlib:
|
||||
# first try SVG if we fail then fallback to png
|
||||
buffer_ = BytesIO()
|
||||
a_plt = specific_fig if specific_fig is not None else plt
|
||||
a_plt.savefig(buffer_, format=image_format, bbox_inches='tight', pad_inches=0, frameon=False)
|
||||
a_plt.savefig(buffer_, dpi=fig_dpi, format=image_format, bbox_inches='tight', pad_inches=0, frameon=False)
|
||||
buffer_.seek(0)
|
||||
except Exception:
|
||||
image_format = 'png'
|
||||
buffer_ = BytesIO()
|
||||
a_plt = specific_fig if specific_fig is not None else plt
|
||||
a_plt.savefig(buffer_, format=image_format, bbox_inches='tight', pad_inches=0, frameon=False)
|
||||
a_plt.savefig(buffer_, dpi=fig_dpi, format=image_format, bbox_inches='tight', pad_inches=0, frameon=False)
|
||||
buffer_.seek(0)
|
||||
fd, image = mkstemp(suffix='.'+image_format)
|
||||
os.write(fd, buffer_.read())
|
||||
@ -253,6 +304,17 @@ class PatchedMatplotlib:
|
||||
|
||||
return
|
||||
|
||||
@staticmethod
|
||||
def _get_output_figures(stored_figure, all_figures):
|
||||
try:
|
||||
from matplotlib import _pylab_helpers
|
||||
if all_figures:
|
||||
return list(_pylab_helpers.Gcf.figs.values())
|
||||
else:
|
||||
return [stored_figure] or [_pylab_helpers.Gcf.get_active()]
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
def __patched_draw_all(*args, **kwargs):
|
||||
recursion_guard = PatchedMatplotlib.__patched_draw_all_recursion_guard
|
||||
|
Loading…
Reference in New Issue
Block a user