mirror of
https://github.com/clearml/clearml
synced 2025-04-04 21:03:00 +00:00
Fix matplotlib support
This commit is contained in:
parent
fe938ca21f
commit
3ab1c261e7
@ -6,7 +6,9 @@
|
|||||||
# 2 seconds per epoch on a K520 GPU.
|
# 2 seconds per epoch on a K520 GPU.
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import io
|
import argparse
|
||||||
|
import tempfile
|
||||||
|
import os
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from keras.callbacks import TensorBoard, ModelCheckpoint
|
from keras.callbacks import TensorBoard, ModelCheckpoint
|
||||||
@ -15,8 +17,6 @@ from keras.models import Sequential
|
|||||||
from keras.layers.core import Dense, Activation
|
from keras.layers.core import Dense, Activation
|
||||||
from keras.optimizers import RMSprop
|
from keras.optimizers import RMSprop
|
||||||
from keras.utils import np_utils
|
from keras.utils import np_utils
|
||||||
# TODO: test these methods binding
|
|
||||||
from keras.models import load_model, save_model, model_from_json
|
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
from trains import Task
|
from trains import Task
|
||||||
|
|
||||||
@ -25,6 +25,7 @@ class TensorBoardImage(TensorBoard):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def make_image(tensor):
|
def make_image(tensor):
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
import io
|
||||||
tensor = np.stack((tensor, tensor, tensor), axis=2)
|
tensor = np.stack((tensor, tensor, tensor), axis=2)
|
||||||
height, width, channels = tensor.shape
|
height, width, channels = tensor.shape
|
||||||
image = Image.fromarray(tensor)
|
image = Image.fromarray(tensor)
|
||||||
@ -49,19 +50,17 @@ class TensorBoardImage(TensorBoard):
|
|||||||
self.writer.add_summary(summary, epoch)
|
self.writer.add_summary(summary, epoch)
|
||||||
|
|
||||||
|
|
||||||
batch_size = 128
|
parser = argparse.ArgumentParser(description='Keras MNIST Example')
|
||||||
nb_classes = 10
|
parser.add_argument('--batch-size', type=int, default=128, help='input batch size for training (default: 64)')
|
||||||
nb_epoch = 6
|
parser.add_argument('--epochs', type=int, default=6, help='number of epochs to train (default: 10)')
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
# the data, shuffled and split between train and test sets
|
# the data, shuffled and split between train and test sets
|
||||||
|
nb_classes = 10
|
||||||
(X_train, y_train), (X_test, y_test) = mnist.load_data()
|
(X_train, y_train), (X_test, y_test) = mnist.load_data()
|
||||||
|
|
||||||
X_train = X_train.reshape(60000, 784)
|
X_train = X_train.reshape(60000, 784).astype('float32')/255.
|
||||||
X_test = X_test.reshape(10000, 784)
|
X_test = X_test.reshape(10000, 784).astype('float32')/255.
|
||||||
X_train = X_train.astype('float32')
|
|
||||||
X_test = X_test.astype('float32')
|
|
||||||
X_train /= 255.
|
|
||||||
X_test /= 255.
|
|
||||||
print(X_train.shape[0], 'train samples')
|
print(X_train.shape[0], 'train samples')
|
||||||
print(X_test.shape[0], 'test samples')
|
print(X_test.shape[0], 'test samples')
|
||||||
|
|
||||||
@ -91,21 +90,23 @@ model.compile(loss='categorical_crossentropy',
|
|||||||
|
|
||||||
# Connecting TRAINS
|
# Connecting TRAINS
|
||||||
task = Task.init(project_name='examples', task_name='Keras with TensorBoard example')
|
task = Task.init(project_name='examples', task_name='Keras with TensorBoard example')
|
||||||
# setting model outputs
|
# Advanced: setting model class enumeration
|
||||||
labels = dict(('digit_%d' % i, i) for i in range(10))
|
labels = dict(('digit_%d' % i, i) for i in range(10))
|
||||||
task.set_model_label_enumeration(labels)
|
task.set_model_label_enumeration(labels)
|
||||||
|
|
||||||
board = TensorBoard(histogram_freq=1, log_dir='/tmp/histogram_example', write_images=False)
|
output_folder = os.path.join(tempfile.gettempdir(), 'keras_example')
|
||||||
model_store = ModelCheckpoint(filepath='/tmp/histogram_example/weight.{epoch}.hdf5')
|
|
||||||
|
board = TensorBoard(histogram_freq=1, log_dir=output_folder, write_images=False)
|
||||||
|
model_store = ModelCheckpoint(filepath=os.path.join(output_folder, 'weight.{epoch}.hdf5'))
|
||||||
|
|
||||||
# load previous model, if it is there
|
# load previous model, if it is there
|
||||||
try:
|
try:
|
||||||
model.load_weights('/tmp/histogram_example/weight.1.hdf5')
|
model.load_weights(os.path.join(output_folder, 'weight.1.hdf5'))
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
history = model.fit(X_train, Y_train,
|
history = model.fit(X_train, Y_train,
|
||||||
batch_size=batch_size, epochs=nb_epoch,
|
batch_size=args.batch_size, epochs=args.epochs,
|
||||||
callbacks=[board, model_store],
|
callbacks=[board, model_store],
|
||||||
verbose=1, validation_data=(X_test, Y_test))
|
verbose=1, validation_data=(X_test, Y_test))
|
||||||
score = model.evaluate(X_test, Y_test, verbose=0)
|
score = model.evaluate(X_test, Y_test, verbose=0)
|
||||||
|
@ -7,7 +7,7 @@ from trains import Task
|
|||||||
|
|
||||||
task = Task.init(project_name='examples', task_name='Manual reporting')
|
task = Task.init(project_name='examples', task_name='Manual reporting')
|
||||||
|
|
||||||
# example python logger
|
# standard python logging
|
||||||
logging.getLogger().setLevel('DEBUG')
|
logging.getLogger().setLevel('DEBUG')
|
||||||
logging.debug('This is a debug message')
|
logging.debug('This is a debug message')
|
||||||
logging.info('This is an info message')
|
logging.info('This is an info message')
|
||||||
@ -23,7 +23,7 @@ except ImportError:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
# get TRAINS logger object for any metrics / reports
|
# get TRAINS logger object for any metrics / reports
|
||||||
logger = task.get_logger()
|
logger = Task.current_task().get_logger()
|
||||||
|
|
||||||
# log text
|
# log text
|
||||||
logger.console("hello")
|
logger.console("hello")
|
||||||
@ -34,7 +34,7 @@ logger.report_scalar("example_scalar", "series A", iteration=1, value=200)
|
|||||||
|
|
||||||
# report histogram
|
# report histogram
|
||||||
histogram = np.random.randint(10, size=10)
|
histogram = np.random.randint(10, size=10)
|
||||||
logger.report_vector("example_histogram", "random histogram", iteration=1, values=histogram)
|
logger.report_histogram("example_histogram", "random histogram", iteration=1, values=histogram)
|
||||||
|
|
||||||
# report confusion matrix
|
# report confusion matrix
|
||||||
confusion = np.random.randint(10, size=(10, 10))
|
confusion = np.random.randint(10, size=(10, 10))
|
||||||
|
@ -2,6 +2,7 @@ import collections
|
|||||||
import json
|
import json
|
||||||
|
|
||||||
import six
|
import six
|
||||||
|
import numpy as np
|
||||||
from threading import Thread, Event
|
from threading import Thread, Event
|
||||||
|
|
||||||
from ..base import InterfaceBase
|
from ..base import InterfaceBase
|
||||||
@ -157,8 +158,15 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
|
|||||||
:param iter: Iteration number
|
:param iter: Iteration number
|
||||||
:type value: int
|
:type value: int
|
||||||
"""
|
"""
|
||||||
|
try:
|
||||||
|
def default(o):
|
||||||
|
if isinstance(o, np.int64):
|
||||||
|
return int(o)
|
||||||
|
except Exception:
|
||||||
|
default = None
|
||||||
|
|
||||||
if isinstance(plot, dict):
|
if isinstance(plot, dict):
|
||||||
plot = json.dumps(plot)
|
plot = json.dumps(plot, default=default)
|
||||||
elif not isinstance(plot, six.string_types):
|
elif not isinstance(plot, six.string_types):
|
||||||
raise ValueError('Plot should be a string or a dict')
|
raise ValueError('Plot should be a string or a dict')
|
||||||
ev = PlotEvent(metric=self._normalize_name(title),
|
ev = PlotEvent(metric=self._normalize_name(title),
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
from tempfile import mkstemp
|
from tempfile import mkstemp
|
||||||
@ -41,7 +43,8 @@ class PatchedMatplotlib:
|
|||||||
try:
|
try:
|
||||||
# we support matplotlib version 2.0.0 and above
|
# we support matplotlib version 2.0.0 and above
|
||||||
import matplotlib
|
import matplotlib
|
||||||
if int(matplotlib.__version__.split('.')[0]) < 2:
|
matplot_major_version = int(matplotlib.__version__.split('.')[0])
|
||||||
|
if matplot_major_version < 2:
|
||||||
LoggerRoot.get_base_logger().warning(
|
LoggerRoot.get_base_logger().warning(
|
||||||
'matplotlib binding supports version 2.0 and above, found version {}'.format(
|
'matplotlib binding supports version 2.0 and above, found version {}'.format(
|
||||||
matplotlib.__version__))
|
matplotlib.__version__))
|
||||||
@ -63,6 +66,7 @@ class PatchedMatplotlib:
|
|||||||
PatchedMatplotlib._patched_original_plot = plt.show
|
PatchedMatplotlib._patched_original_plot = plt.show
|
||||||
PatchedMatplotlib._patched_original_imshow = plt.imshow
|
PatchedMatplotlib._patched_original_imshow = plt.imshow
|
||||||
PatchedMatplotlib._patched_original_figure = figure.Figure.show
|
PatchedMatplotlib._patched_original_figure = figure.Figure.show
|
||||||
|
|
||||||
plt.show = PatchedMatplotlib.patched_show
|
plt.show = PatchedMatplotlib.patched_show
|
||||||
figure.Figure.show = PatchedMatplotlib.patched_figure_show
|
figure.Figure.show = PatchedMatplotlib.patched_figure_show
|
||||||
sys.modules['matplotlib'].pyplot.imshow = PatchedMatplotlib.patched_imshow
|
sys.modules['matplotlib'].pyplot.imshow = PatchedMatplotlib.patched_imshow
|
||||||
@ -111,13 +115,26 @@ class PatchedMatplotlib:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def patched_figure_show(self, *args, **kw):
|
def patched_figure_show(self, *args, **kw):
|
||||||
|
if hasattr(self, '_trains_show'):
|
||||||
|
# flag will be cleared when calling clf() (object will be replaced)
|
||||||
|
return PatchedMatplotlib._patched_original_figure(self, *args, **kw)
|
||||||
|
try:
|
||||||
|
self._trains_show = True
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
PatchedMatplotlib._report_figure(set_active=False, specific_fig=self)
|
PatchedMatplotlib._report_figure(set_active=False, specific_fig=self)
|
||||||
ret = PatchedMatplotlib._patched_original_figure(self, *args, **kw)
|
ret = PatchedMatplotlib._patched_original_figure(self, *args, **kw)
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def patched_show(*args, **kw):
|
def patched_show(*args, **kw):
|
||||||
PatchedMatplotlib._report_figure()
|
# noinspection PyBroadException
|
||||||
|
try:
|
||||||
|
figures = PatchedMatplotlib._get_output_figures(None, all_figures=True)
|
||||||
|
for figure in figures:
|
||||||
|
PatchedMatplotlib._report_figure(stored_figure=figure)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
ret = PatchedMatplotlib._patched_original_plot(*args, **kw)
|
ret = PatchedMatplotlib._patched_original_plot(*args, **kw)
|
||||||
if PatchedMatplotlib._current_task and running_remotely():
|
if PatchedMatplotlib._current_task and running_remotely():
|
||||||
# clear the current plot, because no one else will
|
# clear the current plot, because no one else will
|
||||||
@ -155,10 +172,17 @@ class PatchedMatplotlib:
|
|||||||
else:
|
else:
|
||||||
mpl_fig = specific_fig
|
mpl_fig = specific_fig
|
||||||
|
|
||||||
|
# mark as processed, so nested calls to figure.show will do nothing
|
||||||
|
try:
|
||||||
|
mpl_fig._trains_show = True
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
# convert to plotly
|
# convert to plotly
|
||||||
image = None
|
image = None
|
||||||
plotly_fig = None
|
plotly_fig = None
|
||||||
image_format = 'jpeg'
|
image_format = 'jpeg'
|
||||||
|
fig_dpi = 300
|
||||||
if not force_save_as_image:
|
if not force_save_as_image:
|
||||||
image_format = 'svg'
|
image_format = 'svg'
|
||||||
# noinspection PyBroadException
|
# noinspection PyBroadException
|
||||||
@ -168,12 +192,39 @@ class PatchedMatplotlib:
|
|||||||
if matplotlylib:
|
if matplotlylib:
|
||||||
renderer = matplotlylib.PlotlyRenderer()
|
renderer = matplotlylib.PlotlyRenderer()
|
||||||
matplotlylib.Exporter(renderer, close_mpl=False).run(fig)
|
matplotlylib.Exporter(renderer, close_mpl=False).run(fig)
|
||||||
|
x_ticks = list(renderer.current_mpl_ax.get_xticklabels())
|
||||||
|
if x_ticks:
|
||||||
|
try:
|
||||||
|
# check if all values can be cast to float
|
||||||
|
values = [float(t.get_text().replace('−', '-')) for t in x_ticks]
|
||||||
|
except:
|
||||||
|
try:
|
||||||
|
renderer.plotly_fig['layout']['xaxis1'].update({
|
||||||
|
'ticktext': [t.get_text() for t in x_ticks],
|
||||||
|
'tickvals': [t.get_position()[0] for t in x_ticks],
|
||||||
|
})
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
y_ticks = list(renderer.current_mpl_ax.get_yticklabels())
|
||||||
|
if y_ticks:
|
||||||
|
try:
|
||||||
|
# check if all values can be cast to float
|
||||||
|
values = [float(t.get_text().replace('−', '-')) for t in y_ticks]
|
||||||
|
except:
|
||||||
|
try:
|
||||||
|
renderer.plotly_fig['layout']['yaxis1'].update({
|
||||||
|
'ticktext': [t.get_text() for t in y_ticks],
|
||||||
|
'tickvals': [t.get_position()[1] for t in y_ticks],
|
||||||
|
})
|
||||||
|
except:
|
||||||
|
pass
|
||||||
return renderer.plotly_fig
|
return renderer.plotly_fig
|
||||||
|
|
||||||
plotly_fig = our_mpl_to_plotly(mpl_fig)
|
plotly_fig = our_mpl_to_plotly(mpl_fig)
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
# this was an image, change format to png
|
# this was an image, change format to png
|
||||||
image_format = 'jpeg' if 'selfie' in str(ex) else 'png'
|
image_format = 'jpeg' if 'selfie' in str(ex) else 'png'
|
||||||
|
fig_dpi = 300
|
||||||
|
|
||||||
# plotly could not serialize the plot, we should convert to image
|
# plotly could not serialize the plot, we should convert to image
|
||||||
if not plotly_fig:
|
if not plotly_fig:
|
||||||
@ -183,13 +234,13 @@ class PatchedMatplotlib:
|
|||||||
# first try SVG if we fail then fallback to png
|
# first try SVG if we fail then fallback to png
|
||||||
buffer_ = BytesIO()
|
buffer_ = BytesIO()
|
||||||
a_plt = specific_fig if specific_fig is not None else plt
|
a_plt = specific_fig if specific_fig is not None else plt
|
||||||
a_plt.savefig(buffer_, format=image_format, bbox_inches='tight', pad_inches=0, frameon=False)
|
a_plt.savefig(buffer_, dpi=fig_dpi, format=image_format, bbox_inches='tight', pad_inches=0, frameon=False)
|
||||||
buffer_.seek(0)
|
buffer_.seek(0)
|
||||||
except Exception:
|
except Exception:
|
||||||
image_format = 'png'
|
image_format = 'png'
|
||||||
buffer_ = BytesIO()
|
buffer_ = BytesIO()
|
||||||
a_plt = specific_fig if specific_fig is not None else plt
|
a_plt = specific_fig if specific_fig is not None else plt
|
||||||
a_plt.savefig(buffer_, format=image_format, bbox_inches='tight', pad_inches=0, frameon=False)
|
a_plt.savefig(buffer_, dpi=fig_dpi, format=image_format, bbox_inches='tight', pad_inches=0, frameon=False)
|
||||||
buffer_.seek(0)
|
buffer_.seek(0)
|
||||||
fd, image = mkstemp(suffix='.'+image_format)
|
fd, image = mkstemp(suffix='.'+image_format)
|
||||||
os.write(fd, buffer_.read())
|
os.write(fd, buffer_.read())
|
||||||
@ -253,6 +304,17 @@ class PatchedMatplotlib:
|
|||||||
|
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_output_figures(stored_figure, all_figures):
|
||||||
|
try:
|
||||||
|
from matplotlib import _pylab_helpers
|
||||||
|
if all_figures:
|
||||||
|
return list(_pylab_helpers.Gcf.figs.values())
|
||||||
|
else:
|
||||||
|
return [stored_figure] or [_pylab_helpers.Gcf.get_active()]
|
||||||
|
except Exception:
|
||||||
|
return []
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def __patched_draw_all(*args, **kwargs):
|
def __patched_draw_all(*args, **kwargs):
|
||||||
recursion_guard = PatchedMatplotlib.__patched_draw_all_recursion_guard
|
recursion_guard = PatchedMatplotlib.__patched_draw_all_recursion_guard
|
||||||
|
Loading…
Reference in New Issue
Block a user