Fix examples

This commit is contained in:
allegroai 2024-03-17 15:06:40 +02:00
parent ed51acb2ea
commit 4a08a5f79b
2 changed files with 13 additions and 7 deletions

View File

@ -106,12 +106,12 @@ task.set_model_label_enumeration(labels)
output_folder = os.path.join(tempfile.gettempdir(), 'keras_example')
board = TensorBoard(histogram_freq=1, log_dir=output_folder, write_images=False)
model_store = ModelCheckpoint(filepath=os.path.join(output_folder, 'weight.{epoch}.hdf5'))
model_store = ModelCheckpoint(filepath=os.path.join(output_folder, 'weight.{epoch}.keras'))
# load previous model, if it is there
# noinspection PyBroadException
try:
model.load_weights(os.path.join(output_folder, 'weight.1.hdf5'))
model.load_weights(os.path.join(output_folder, 'weight.1.keras'))
except Exception:
pass

View File

@ -93,7 +93,7 @@ train_summary_writer = tf.summary.create_file_writer(train_log_dir)
test_summary_writer = tf.summary.create_file_writer(test_log_dir)
# Set up checkpoints manager
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=optimizer, net=model)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), net=model)
manager = tf.train.CheckpointManager(ckpt, os.path.join(gettempdir(), 'tf_ckpts'), max_to_keep=3)
ckpt.restore(manager.latest_checkpoint)
if manager.latest_checkpoint:
@ -129,7 +129,13 @@ for epoch in range(EPOCHS):
test_accuracy.result()*100))
# Reset the metrics for the next epoch
train_loss.reset_states()
train_accuracy.reset_states()
test_loss.reset_states()
test_accuracy.reset_states()
try:
train_loss.reset_states()
train_accuracy.reset_states()
test_loss.reset_states()
test_accuracy.reset_states()
except AttributeError:
train_loss.reset_state()
train_accuracy.reset_state()
test_loss.reset_state()
test_accuracy.reset_state()