# Mirror of https://github.com/hexastack/hexabot
# (synced 2024-12-28 23:02:03 +00:00; original file: 90 lines, 2.7 KiB, Python).
import tensorflow as tf
from keras import layers as tfkl
import boilerplate as tfbp


@tfbp.default_export
class MLP(tfbp.Model):
    """Multi-layer perceptron classifier built on the `tfbp.Model` boilerplate.

    Hyperparameters (`default_hparams`, presumably merged with user overrides
    into `self.hparams` by `tfbp.Model` -- TODO confirm in boilerplate):
        layer_sizes: widths of the dense layers. Every entry except the last
            is a ReLU hidden layer; the last is the softmax output layer, so
            it should equal the number of classes.
        learning_rate: learning rate for the Adam optimizer.
        num_epochs: number of passes over the training data in `fit`/`train`.
    """

    default_hparams = {
        "layer_sizes": [512, 10],
        "learning_rate": 0.001,
        "num_epochs": 10,
    }

    def __init__(self, *args, **kwargs):
        """Build the layer stack, loss and optimizer from `self.hparams`.

        All arguments are forwarded unchanged to `tfbp.Model.__init__`.
        """
        super().__init__(*args, **kwargs)

        self.forward = tf.keras.Sequential()
        # Hidden layers: every configured size except the last one.
        for hidden_size in self.hparams.layer_sizes[:-1]:
            self.forward.add(tfkl.Dense(hidden_size, activation=tf.nn.relu))
        # The output layer emits probabilities (softmax), which matches the
        # loss below being used with its default `from_logits=False`.
        self.forward.add(
            tfkl.Dense(self.hparams.layer_sizes[-1], activation=tf.nn.softmax)
        )

        self.loss = tf.losses.SparseCategoricalCrossentropy()
        self.optimizer = tf.optimizers.Adam(self.hparams.learning_rate)

    def call(self, x):
        """Return the softmax class probabilities for the input batch `x`."""
        return self.forward(x)

    @tfbp.runnable
    def fit(self, data_loader):
        """Example using keras training loop.

        Args:
            data_loader: called with no arguments to obtain a
                `(train_data, valid_data)` pair -- the same contract `train`
                relies on. The model is saved once training finishes.
        """
        # Invoke the loader exactly like `train`/`evaluate` do, so all three
        # runnables share one data-loader contract.
        train_data, valid_data = data_loader()

        self.compile(self.optimizer, self.loss)
        super().fit(
            x=train_data,
            validation_data=valid_data,
            validation_steps=32,  # use 32 validation batches per validation run
            validation_freq=1,  # validate every 1 epoch
            epochs=self.hparams.num_epochs,
            shuffle=False,  # dataset instances already handle shuffling
        )
        self.save()

    @tfbp.runnable
    def train(self, data_loader):
        """Example using custom training loop.

        Runs `num_epochs` epochs of gradient updates with `self.optimizer`,
        prints the train loss and the loss on one validation batch every 1000
        steps, and saves the model at the end of every epoch.

        Args:
            data_loader: called with no arguments to obtain a
                `(train_data, valid_data)` pair of iterable `(x, y)` batches;
                `valid_data` must support `.repeat()`.
        """
        train_data, valid_data = data_loader()

        # Repeat the validation set so the `next` builtin can be called
        # indefinitely without exhausting the iterator.
        valid_iter = iter(valid_data.repeat())

        step = 0
        for epoch in range(self.hparams.num_epochs):
            for x, y in train_data:
                with tf.GradientTape() as tape:
                    train_loss = self.loss(y, self(x))

                grads = tape.gradient(train_loss, self.trainable_variables)
                self.optimizer.apply_gradients(zip(grads, self.trainable_variables))

                # Validate on a single batch every 1000 training steps.
                if step % 1000 == 0:
                    x_valid, y_valid = next(valid_iter)
                    valid_loss = self.loss(y_valid, self(x_valid))
                    print(
                        f"step {step} (train_loss={train_loss} valid_loss={valid_loss})"
                    )
                step += 1

            print(f"epoch {epoch} finished")
            self.save()

    @tfbp.runnable
    def evaluate(self, data_loader):
        """Print the classification accuracy over the batches of `data_loader()`.

        Accuracy is the fraction of label entries equal to the arg-max of the
        predicted probabilities; 0 is printed when the data is empty.

        Args:
            data_loader: called with no arguments to obtain an iterable of
                `(x, y)` batches with integer class labels `y`.
        """
        num_correct = 0
        num_total = 0
        for x, y in data_loader():
            predictions = tf.math.argmax(self(x), axis=-1)
            # `argmax` yields int64 and `tf.math.equal` does not auto-cast, so
            # align the label dtype (datasets often store labels as int32/uint8).
            matches = tf.math.equal(tf.cast(y, predictions.dtype), predictions)
            num_correct += int(tf.math.count_nonzero(matches))
            num_total += int(tf.size(matches))

        accuracy = num_correct / num_total if num_total else 0
        print(accuracy)