mirror of
https://github.com/clearml/clearml
synced 2025-02-07 13:23:40 +00:00
Fix examples
This commit is contained in:
parent
ed51acb2ea
commit
4a08a5f79b
@ -106,12 +106,12 @@ task.set_model_label_enumeration(labels)
|
|||||||
output_folder = os.path.join(tempfile.gettempdir(), 'keras_example')
|
output_folder = os.path.join(tempfile.gettempdir(), 'keras_example')
|
||||||
|
|
||||||
board = TensorBoard(histogram_freq=1, log_dir=output_folder, write_images=False)
|
board = TensorBoard(histogram_freq=1, log_dir=output_folder, write_images=False)
|
||||||
model_store = ModelCheckpoint(filepath=os.path.join(output_folder, 'weight.{epoch}.hdf5'))
|
model_store = ModelCheckpoint(filepath=os.path.join(output_folder, 'weight.{epoch}.keras'))
|
||||||
|
|
||||||
# load previous model, if it is there
|
# load previous model, if it is there
|
||||||
# noinspection PyBroadException
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
model.load_weights(os.path.join(output_folder, 'weight.1.hdf5'))
|
model.load_weights(os.path.join(output_folder, 'weight.1.keras'))
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -93,7 +93,7 @@ train_summary_writer = tf.summary.create_file_writer(train_log_dir)
|
|||||||
test_summary_writer = tf.summary.create_file_writer(test_log_dir)
|
test_summary_writer = tf.summary.create_file_writer(test_log_dir)
|
||||||
|
|
||||||
# Set up checkpoints manager
|
# Set up checkpoints manager
|
||||||
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=optimizer, net=model)
|
ckpt = tf.train.Checkpoint(step=tf.Variable(1), net=model)
|
||||||
manager = tf.train.CheckpointManager(ckpt, os.path.join(gettempdir(), 'tf_ckpts'), max_to_keep=3)
|
manager = tf.train.CheckpointManager(ckpt, os.path.join(gettempdir(), 'tf_ckpts'), max_to_keep=3)
|
||||||
ckpt.restore(manager.latest_checkpoint)
|
ckpt.restore(manager.latest_checkpoint)
|
||||||
if manager.latest_checkpoint:
|
if manager.latest_checkpoint:
|
||||||
@ -129,7 +129,13 @@ for epoch in range(EPOCHS):
|
|||||||
test_accuracy.result()*100))
|
test_accuracy.result()*100))
|
||||||
|
|
||||||
# Reset the metrics for the next epoch
|
# Reset the metrics for the next epoch
|
||||||
|
try:
|
||||||
train_loss.reset_states()
|
train_loss.reset_states()
|
||||||
train_accuracy.reset_states()
|
train_accuracy.reset_states()
|
||||||
test_loss.reset_states()
|
test_loss.reset_states()
|
||||||
test_accuracy.reset_states()
|
test_accuracy.reset_states()
|
||||||
|
except AttributeError:
|
||||||
|
train_loss.reset_state()
|
||||||
|
train_accuracy.reset_state()
|
||||||
|
test_loss.reset_state()
|
||||||
|
test_accuracy.reset_state()
|
||||||
|
Loading…
Reference in New Issue
Block a user