import tensorflow as tf
from keras import layers as tfkl

import boilerplate as tfbp


@tfbp.default_export
class MLP(tfbp.Model):
    """Multi-layer perceptron classifier built from a stack of Dense layers.

    Hyperparameters (see `default_hparams`):
        layer_sizes: Output width of each Dense layer. Every entry except the
            last becomes a ReLU layer; the last entry is the softmax output
            layer.
        learning_rate: Learning rate passed to the Adam optimizer.
        num_epochs: Number of passes over the training data in `fit`/`train`.
    """

    default_hparams = {
        "layer_sizes": [512, 10],
        "learning_rate": 0.001,
        "num_epochs": 10,
    }

    def __init__(self, *args, **kwargs):
        """Build the layer stack, the loss, and the optimizer from hparams."""
        super().__init__(*args, **kwargs)

        self.forward = tf.keras.Sequential()
        for hidden_size in self.hparams.layer_sizes[:-1]:
            self.forward.add(tfkl.Dense(hidden_size, activation=tf.nn.relu))
        # The final layer emits class probabilities, which is what the
        # (non-logit) sparse categorical cross-entropy below consumes.
        self.forward.add(
            tfkl.Dense(self.hparams.layer_sizes[-1], activation=tf.nn.softmax)
        )

        self.loss = tf.losses.SparseCategoricalCrossentropy()
        self.optimizer = tf.optimizers.Adam(self.hparams.learning_rate)

    def call(self, x):
        """Return the softmax output of the layer stack for input `x`."""
        return self.forward(x)

    @tfbp.runnable
    def fit(self, data_loader):
        """Example using keras training loop."""
        # NOTE(review): this method calls `data_loader.load()` while `train`
        # and `evaluate` call `data_loader()` directly -- confirm which entry
        # point the data loader actually exposes.
        train_data, valid_data = data_loader.load()

        self.compile(self.optimizer, self.loss)
        super().fit(
            x=train_data,
            validation_data=valid_data,
            validation_steps=32,  # validate 32 batches at a time
            validation_freq=1,  # validate every 1 epoch
            epochs=self.hparams.num_epochs,
            shuffle=False,  # dataset instances already handle shuffling
        )
        self.save()

    @tfbp.runnable
    def train(self, data_loader):
        """Example using custom training loop."""
        step = 0
        train_data, valid_data = data_loader()

        # Allow to call `next` builtin indefinitely.
        valid_data = iter(valid_data.repeat())

        for epoch in range(self.hparams.num_epochs):
            for x, y in train_data:
                with tf.GradientTape() as g:
                    train_loss = self.loss(y, self(x))

                grads = g.gradient(train_loss, self.trainable_variables)
                self.optimizer.apply_gradients(zip(grads, self.trainable_variables))

                # Validate every 1000 training steps.
                if step % 1000 == 0:
                    # Separate names so the validation batch never shadows the
                    # training batch bound by the enclosing loop.
                    valid_x, valid_y = next(valid_data)
                    valid_loss = self.loss(valid_y, self(valid_x))
                    print(
                        f"step {step} (train_loss={train_loss} valid_loss={valid_loss})"
                    )
                step += 1

            print(f"epoch {epoch} finished")
            self.save()

    @tfbp.runnable
    def evaluate(self, data_loader):
        """Print the classification accuracy over the data from `data_loader()`.

        Accuracy is the fraction of examples whose arg-max prediction equals
        the label. Prints `0` when the dataset yields no examples.
        """
        correct = 0
        total = 0
        test_data = data_loader()

        for x, y in test_data:
            predictions = tf.math.argmax(self(x), axis=-1)
            # `argmax` yields int64 by default and `tf.math.equal` requires
            # both operands to share a dtype, so align the labels first.
            # assumes `y` holds one class id per example -- TODO confirm
            labels = tf.cast(y, predictions.dtype)
            matches = tf.math.equal(labels, predictions).numpy()

            # Count per batch instead of updating a running mean per element:
            # exact integer arithmetic and no Python-level inner loop.
            correct += int(matches.sum())
            total += matches.size

        accuracy = correct / total if total else 0
        print(accuracy)