hexabot/nlu/models/mlp.py

90 lines
2.7 KiB
Python
Raw Normal View History

2024-09-10 09:50:11 +00:00
import tensorflow as tf
from keras import layers as tfkl
import boilerplate as tfbp
@tfbp.default_export
class MLP(tfbp.Model):
    """Multi-layer perceptron classifier.

    The network is a stack of fully connected layers: every entry of
    `layer_sizes` except the last becomes a ReLU hidden layer, and the last
    entry becomes the softmax output layer (one unit per class).

    NOTE(review): `self.hparams` and `self.save()` are not defined in this
    class; they are presumably provided by `tfbp.Model` -- confirm there.
    """

    # Default hyperparameters:
    #   layer_sizes:   widths of the dense layers; the last one is the output.
    #   learning_rate: step size handed to the Adam optimizer.
    #   num_epochs:    number of passes over the training data.
    default_hparams = {
        "layer_sizes": [512, 10],
        "learning_rate": 0.001,
        "num_epochs": 10,
    }

    def __init__(self, *args, **kwargs):
        """Build the layer stack, the loss and the optimizer.

        All positional and keyword arguments are forwarded unchanged to the
        base class constructor.
        """
        super().__init__(*args, **kwargs)

        self.forward = tf.keras.Sequential()

        # Hidden layers: everything but the last configured size.
        for hidden_size in self.hparams.layer_sizes[:-1]:
            self.forward.add(tfkl.Dense(hidden_size, activation=tf.nn.relu))

        # Output layer: softmax so the model emits class probabilities, which
        # is what the loss below expects with its default settings.
        self.forward.add(
            tfkl.Dense(self.hparams.layer_sizes[-1], activation=tf.nn.softmax)
        )

        self.loss = tf.losses.SparseCategoricalCrossentropy()
        self.optimizer = tf.optimizers.Adam(self.hparams.learning_rate)

    def call(self, x):
        """Run the forward pass and return the output of the layer stack."""
        return self.forward(x)

    @tfbp.runnable
    def fit(self, data_loader):
        """Example using keras training loop.

        Args:
            data_loader: object whose `load()` returns a
                `(train_data, valid_data)` pair.

        NOTE(review): this method calls `data_loader.load()` whereas `train`
        and `evaluate` call `data_loader()` directly -- confirm the loader
        supports both entry points.
        """
        train_data, valid_data = data_loader.load()

        self.compile(self.optimizer, self.loss)
        super().fit(
            x=train_data,
            validation_data=valid_data,
            validation_steps=32,  # draw 32 validation batches per validation run
            validation_freq=1,  # validate every 1 epoch
            epochs=self.hparams.num_epochs,
            shuffle=False,  # dataset instances are expected to handle shuffling
        )
        self.save()

    @tfbp.runnable
    def train(self, data_loader):
        """Example using custom training loop.

        Args:
            data_loader: callable returning a `(train_data, valid_data)` pair.
                `train_data` must be iterable as `(x, y)` batches and
                `valid_data` must expose `repeat()`.
        """
        step = 0
        train_data, valid_data = data_loader()

        # Allow to call `next` builtin indefinitely.
        valid_data = iter(valid_data.repeat())

        for epoch in range(self.hparams.num_epochs):
            for x, y in train_data:
                with tf.GradientTape() as g:
                    train_loss = self.loss(y, self(x))

                grads = g.gradient(train_loss, self.trainable_variables)
                self.optimizer.apply_gradients(zip(grads, self.trainable_variables))

                # Validate every 1000 training steps.
                if step % 1000 == 0:
                    # Separate names so the training batch is not shadowed.
                    valid_x, valid_y = next(valid_data)
                    valid_loss = self.loss(valid_y, self(valid_x))
                    print(
                        f"step {step} (train_loss={train_loss} valid_loss={valid_loss})"
                    )
                step += 1

            print(f"epoch {epoch} finished")
            # Save after every epoch so an interrupted run keeps its progress.
            self.save()

    @tfbp.runnable
    def evaluate(self, data_loader):
        """Print the classification accuracy over the data returned by the loader.

        Accuracy is the fraction of examples whose highest-scoring class
        equals the label. `0` is printed when the loader yields no examples.

        Args:
            data_loader: callable returning an iterable of `(x, y)` batches.
        """
        correct = 0
        total = 0
        test_data = data_loader()

        for x, y in test_data:
            predictions = tf.math.argmax(self(x), axis=-1)
            # `argmax` yields int64 by default and `tf.math.equal` needs both
            # operands to share a dtype, so align the labels first.
            labels = tf.cast(y, predictions.dtype)
            matches = tf.math.equal(labels, predictions).numpy()

            # Count per batch instead of updating a running mean per example.
            correct += int(matches.sum())
            total += int(matches.size)

        accuracy = correct / total if total else 0
        print(accuracy)