diff --git a/examples/frameworks/keras/keras_tensorboard.py b/examples/frameworks/keras/keras_tensorboard.py index 6762965f..1d944898 100644 --- a/examples/frameworks/keras/keras_tensorboard.py +++ b/examples/frameworks/keras/keras_tensorboard.py @@ -106,12 +106,12 @@ task.set_model_label_enumeration(labels) output_folder = os.path.join(tempfile.gettempdir(), 'keras_example') board = TensorBoard(histogram_freq=1, log_dir=output_folder, write_images=False) -model_store = ModelCheckpoint(filepath=os.path.join(output_folder, 'weight.{epoch}.hdf5')) +model_store = ModelCheckpoint(filepath=os.path.join(output_folder, 'weight.{epoch}.keras')) # load previous model, if it is there # noinspection PyBroadException try: - model.load_weights(os.path.join(output_folder, 'weight.1.hdf5')) + model.load_weights(os.path.join(output_folder, 'weight.1.keras')) except Exception: pass diff --git a/examples/frameworks/tensorflow/tensorflow_mnist.py b/examples/frameworks/tensorflow/tensorflow_mnist.py index 4eb5333f..26211275 100644 --- a/examples/frameworks/tensorflow/tensorflow_mnist.py +++ b/examples/frameworks/tensorflow/tensorflow_mnist.py @@ -93,7 +93,7 @@ train_summary_writer = tf.summary.create_file_writer(train_log_dir) test_summary_writer = tf.summary.create_file_writer(test_log_dir) # Set up checkpoints manager -ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=optimizer, net=model) +ckpt = tf.train.Checkpoint(step=tf.Variable(1), net=model) manager = tf.train.CheckpointManager(ckpt, os.path.join(gettempdir(), 'tf_ckpts'), max_to_keep=3) ckpt.restore(manager.latest_checkpoint) if manager.latest_checkpoint: @@ -129,7 +129,13 @@ for epoch in range(EPOCHS): test_accuracy.result()*100)) # Reset the metrics for the next epoch - train_loss.reset_states() - train_accuracy.reset_states() - test_loss.reset_states() - test_accuracy.reset_states() + try: + train_loss.reset_states() + train_accuracy.reset_states() + test_loss.reset_states() + test_accuracy.reset_states() + except AttributeError: + train_loss.reset_state() + train_accuracy.reset_state() + test_loss.reset_state() + test_accuracy.reset_state()