diff --git a/examples/keras_tensorboard.py b/examples/keras_tensorboard.py index dd369237..ecf137f8 100644 --- a/examples/keras_tensorboard.py +++ b/examples/keras_tensorboard.py @@ -6,7 +6,9 @@ # 2 seconds per epoch on a K520 GPU. from __future__ import print_function -import io +import argparse +import tempfile +import os import numpy as np from keras.callbacks import TensorBoard, ModelCheckpoint @@ -15,8 +17,6 @@ from keras.models import Sequential from keras.layers.core import Dense, Activation from keras.optimizers import RMSprop from keras.utils import np_utils -# TODO: test these methods binding -from keras.models import load_model, save_model, model_from_json import tensorflow as tf from trains import Task @@ -25,6 +25,7 @@ class TensorBoardImage(TensorBoard): @staticmethod def make_image(tensor): from PIL import Image + import io tensor = np.stack((tensor, tensor, tensor), axis=2) height, width, channels = tensor.shape image = Image.fromarray(tensor) @@ -49,19 +50,17 @@ class TensorBoardImage(TensorBoard): self.writer.add_summary(summary, epoch) -batch_size = 128 -nb_classes = 10 -nb_epoch = 6 +parser = argparse.ArgumentParser(description='Keras MNIST Example') +parser.add_argument('--batch-size', type=int, default=128, help='input batch size for training (default: 64)') +parser.add_argument('--epochs', type=int, default=6, help='number of epochs to train (default: 10)') +args = parser.parse_args() # the data, shuffled and split between train and test sets +nb_classes = 10 (X_train, y_train), (X_test, y_test) = mnist.load_data() -X_train = X_train.reshape(60000, 784) -X_test = X_test.reshape(10000, 784) -X_train = X_train.astype('float32') -X_test = X_test.astype('float32') -X_train /= 255. -X_test /= 255. +X_train = X_train.reshape(60000, 784).astype('float32')/255. +X_test = X_test.reshape(10000, 784).astype('float32')/255. print(X_train.shape[0], 'train samples') print(X_test.shape[0], 'test samples') @@ -91,21 +90,23 @@ model.compile(loss='categorical_crossentropy', # Connecting TRAINS task = Task.init(project_name='examples', task_name='Keras with TensorBoard example') -# setting model outputs +# Advanced: setting model class enumeration labels = dict(('digit_%d' % i, i) for i in range(10)) task.set_model_label_enumeration(labels) -board = TensorBoard(histogram_freq=1, log_dir='/tmp/histogram_example', write_images=False) -model_store = ModelCheckpoint(filepath='/tmp/histogram_example/weight.{epoch}.hdf5') +output_folder = os.path.join(tempfile.gettempdir(), 'keras_example') + +board = TensorBoard(histogram_freq=1, log_dir=output_folder, write_images=False) +model_store = ModelCheckpoint(filepath=os.path.join(output_folder, 'weight.{epoch}.hdf5')) # load previous model, if it is there try: - model.load_weights('/tmp/histogram_example/weight.1.hdf5') + model.load_weights(os.path.join(output_folder, 'weight.1.hdf5')) except: pass history = model.fit(X_train, Y_train, - batch_size=batch_size, epochs=nb_epoch, + batch_size=args.batch_size, epochs=args.epochs, callbacks=[board, model_store], verbose=1, validation_data=(X_test, Y_test)) score = model.evaluate(X_test, Y_test, verbose=0) diff --git a/examples/manual_reporting.py b/examples/manual_reporting.py index dcf25304..7745cbbc 100644 --- a/examples/manual_reporting.py +++ b/examples/manual_reporting.py @@ -7,7 +7,7 @@ from trains import Task task = Task.init(project_name='examples', task_name='Manual reporting') -# example python logger +# standard python logging logging.getLogger().setLevel('DEBUG') logging.debug('This is a debug message') logging.info('This is an info message') @@ -23,7 +23,7 @@ except ImportError: pass # get TRAINS logger object for any metrics / reports -logger = task.get_logger() +logger = Task.current_task().get_logger() # log text logger.console("hello") @@ -34,7 +34,7 @@ logger.report_scalar("example_scalar", "series A", iteration=1, value=200) # report histogram histogram = np.random.randint(10, size=10) -logger.report_vector("example_histogram", "random histogram", iteration=1, values=histogram) +logger.report_histogram("example_histogram", "random histogram", iteration=1, values=histogram) # report confusion matrix confusion = np.random.randint(10, size=(10, 10)) diff --git a/trains/backend_interface/metrics/reporter.py b/trains/backend_interface/metrics/reporter.py index c02b826c..63ef9e35 100644 --- a/trains/backend_interface/metrics/reporter.py +++ b/trains/backend_interface/metrics/reporter.py @@ -2,6 +2,7 @@ import collections import json import six +import numpy as np from threading import Thread, Event from ..base import InterfaceBase @@ -157,8 +158,15 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan :param iter: Iteration number :type value: int """ + try: + def default(o): + if isinstance(o, np.int64): + return int(o) + except Exception: + default = None + if isinstance(plot, dict): - plot = json.dumps(plot) + plot = json.dumps(plot, default=default) elif not isinstance(plot, six.string_types): raise ValueError('Plot should be a string or a dict') ev = PlotEvent(metric=self._normalize_name(title), diff --git a/trains/binding/matplotlib_bind.py b/trains/binding/matplotlib_bind.py index 11364a48..219d82be 100644 --- a/trains/binding/matplotlib_bind.py +++ b/trains/binding/matplotlib_bind.py @@ -1,3 +1,5 @@ +# -*- coding: utf-8 -*- + import os import sys from tempfile import mkstemp @@ -41,7 +43,8 @@ class PatchedMatplotlib: try: # we support matplotlib version 2.0.0 and above import matplotlib - if int(matplotlib.__version__.split('.')[0]) < 2: + matplot_major_version = int(matplotlib.__version__.split('.')[0]) + if matplot_major_version < 2: LoggerRoot.get_base_logger().warning( 'matplotlib binding supports version 2.0 and above, found version {}'.format( matplotlib.__version__)) @@ -63,6 +66,7 @@ class PatchedMatplotlib: PatchedMatplotlib._patched_original_plot = plt.show PatchedMatplotlib._patched_original_imshow = plt.imshow PatchedMatplotlib._patched_original_figure = figure.Figure.show + plt.show = PatchedMatplotlib.patched_show figure.Figure.show = PatchedMatplotlib.patched_figure_show sys.modules['matplotlib'].pyplot.imshow = PatchedMatplotlib.patched_imshow @@ -111,13 +115,26 @@ class PatchedMatplotlib: @staticmethod def patched_figure_show(self, *args, **kw): + if hasattr(self, '_trains_show'): + # flag will be cleared when calling clf() (object will be replaced) + return PatchedMatplotlib._patched_original_figure(self, *args, **kw) + try: + self._trains_show = True + except Exception: + pass PatchedMatplotlib._report_figure(set_active=False, specific_fig=self) ret = PatchedMatplotlib._patched_original_figure(self, *args, **kw) return ret @staticmethod def patched_show(*args, **kw): - PatchedMatplotlib._report_figure() + # noinspection PyBroadException + try: + figures = PatchedMatplotlib._get_output_figures(None, all_figures=True) + for figure in figures: + PatchedMatplotlib._report_figure(stored_figure=figure) + except Exception: + pass ret = PatchedMatplotlib._patched_original_plot(*args, **kw) if PatchedMatplotlib._current_task and running_remotely(): # clear the current plot, because no one else will @@ -155,10 +172,17 @@ class PatchedMatplotlib: else: mpl_fig = specific_fig + # mark as processed, so nested calls to figure.show will do nothing + try: + mpl_fig._trains_show = True + except Exception: + pass + # convert to plotly image = None plotly_fig = None image_format = 'jpeg' + fig_dpi = 300 if not force_save_as_image: image_format = 'svg' # noinspection PyBroadException @@ -168,12 +192,39 @@ class PatchedMatplotlib: if matplotlylib: renderer = matplotlylib.PlotlyRenderer() matplotlylib.Exporter(renderer, close_mpl=False).run(fig) + x_ticks = list(renderer.current_mpl_ax.get_xticklabels()) + if x_ticks: + try: + # check if all values can be cast to float + values = [float(t.get_text().replace('−', '-')) for t in x_ticks] + except: + try: + renderer.plotly_fig['layout']['xaxis1'].update({ + 'ticktext': [t.get_text() for t in x_ticks], + 'tickvals': [t.get_position()[0] for t in x_ticks], + }) + except: + pass + y_ticks = list(renderer.current_mpl_ax.get_yticklabels()) + if y_ticks: + try: + # check if all values can be cast to float + values = [float(t.get_text().replace('−', '-')) for t in y_ticks] + except: + try: + renderer.plotly_fig['layout']['yaxis1'].update({ + 'ticktext': [t.get_text() for t in y_ticks], + 'tickvals': [t.get_position()[1] for t in y_ticks], + }) + except: + pass return renderer.plotly_fig plotly_fig = our_mpl_to_plotly(mpl_fig) except Exception as ex: # this was an image, change format to png image_format = 'jpeg' if 'selfie' in str(ex) else 'png' + fig_dpi = 300 # plotly could not serialize the plot, we should convert to image if not plotly_fig: @@ -183,13 +234,13 @@ class PatchedMatplotlib: # first try SVG if we fail then fallback to png buffer_ = BytesIO() a_plt = specific_fig if specific_fig is not None else plt - a_plt.savefig(buffer_, format=image_format, bbox_inches='tight', pad_inches=0, frameon=False) + a_plt.savefig(buffer_, dpi=fig_dpi, format=image_format, bbox_inches='tight', pad_inches=0, frameon=False) buffer_.seek(0) except Exception: image_format = 'png' buffer_ = BytesIO() a_plt = specific_fig if specific_fig is not None else plt - a_plt.savefig(buffer_, format=image_format, bbox_inches='tight', pad_inches=0, frameon=False) + a_plt.savefig(buffer_, dpi=fig_dpi, format=image_format, bbox_inches='tight', pad_inches=0, frameon=False) buffer_.seek(0) fd, image = mkstemp(suffix='.'+image_format) os.write(fd, buffer_.read()) @@ -253,6 +304,17 @@ class PatchedMatplotlib: return + @staticmethod + def _get_output_figures(stored_figure, all_figures): + try: + from matplotlib import _pylab_helpers + if all_figures: + return list(_pylab_helpers.Gcf.figs.values()) + else: + return [stored_figure] or [_pylab_helpers.Gcf.get_active()] + except Exception: + return [] + @staticmethod def __patched_draw_all(*args, **kwargs): recursion_guard = PatchedMatplotlib.__patched_draw_all_recursion_guard