Fix examples

This commit is contained in:
allegroai 2024-03-17 15:06:40 +02:00
parent ed51acb2ea
commit 4a08a5f79b
2 changed files with 13 additions and 7 deletions

View File

@ -106,12 +106,12 @@ task.set_model_label_enumeration(labels)
output_folder = os.path.join(tempfile.gettempdir(), 'keras_example') output_folder = os.path.join(tempfile.gettempdir(), 'keras_example')
board = TensorBoard(histogram_freq=1, log_dir=output_folder, write_images=False) board = TensorBoard(histogram_freq=1, log_dir=output_folder, write_images=False)
model_store = ModelCheckpoint(filepath=os.path.join(output_folder, 'weight.{epoch}.hdf5')) model_store = ModelCheckpoint(filepath=os.path.join(output_folder, 'weight.{epoch}.keras'))
# load previous model, if it is there # load previous model, if it is there
# noinspection PyBroadException # noinspection PyBroadException
try: try:
model.load_weights(os.path.join(output_folder, 'weight.1.hdf5')) model.load_weights(os.path.join(output_folder, 'weight.1.keras'))
except Exception: except Exception:
pass pass

View File

@ -93,7 +93,7 @@ train_summary_writer = tf.summary.create_file_writer(train_log_dir)
test_summary_writer = tf.summary.create_file_writer(test_log_dir) test_summary_writer = tf.summary.create_file_writer(test_log_dir)
# Set up checkpoints manager # Set up checkpoints manager
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=optimizer, net=model) ckpt = tf.train.Checkpoint(step=tf.Variable(1), net=model)
manager = tf.train.CheckpointManager(ckpt, os.path.join(gettempdir(), 'tf_ckpts'), max_to_keep=3) manager = tf.train.CheckpointManager(ckpt, os.path.join(gettempdir(), 'tf_ckpts'), max_to_keep=3)
ckpt.restore(manager.latest_checkpoint) ckpt.restore(manager.latest_checkpoint)
if manager.latest_checkpoint: if manager.latest_checkpoint:
@ -129,7 +129,13 @@ for epoch in range(EPOCHS):
test_accuracy.result()*100)) test_accuracy.result()*100))
# Reset the metrics for the next epoch # Reset the metrics for the next epoch
try:
train_loss.reset_states() train_loss.reset_states()
train_accuracy.reset_states() train_accuracy.reset_states()
test_loss.reset_states() test_loss.reset_states()
test_accuracy.reset_states() test_accuracy.reset_states()
except AttributeError:
train_loss.reset_state()
train_accuracy.reset_state()
test_loss.reset_state()
test_accuracy.reset_state()