hexabot/nlu/data_loaders/mnist.py

30 lines
947 B
Python
Raw Permalink Normal View History

2024-09-10 09:50:11 +00:00
import tensorflow as tf
import boilerplate as tfbp
@tfbp.default_export
class MNIST(tfbp.DataLoader):
    """Serve the Keras MNIST digits as batched ``tf.data`` pipelines.

    Every dataset this loader returns goes through ``_transform_dataset``:
    it is batched by ``hparams.batch_size``, pixel values are divided by
    255.0 and flattened to ``(batch, 28 * 28)`` float32 rows, and labels
    are cast to int64.
    """

    default_hparams = {"batch_size": 32}

    # Buffer size passed to ``Dataset.shuffle`` for both splits (the value
    # previously hard-coded in two places).
    _SHUFFLE_BUFFER = 10000

    def __call__(self):
        """Build the dataset(s) appropriate for the current run method.

        ``self.method`` is presumably set by the ``boilerplate`` runner to
        the name of the method being executed -- confirm against tfbp.

        Returns:
            If ``self.method`` is ``"fit"`` or ``"train"``: a tuple
            ``(train_data, test_data)`` of shuffled, transformed datasets,
            where the MNIST test split serves as validation data.
            Otherwise: the unshuffled, transformed test split only.
        """
        train_data, test_data = tf.keras.datasets.mnist.load_data()
        # ``test_data`` is an ``(images, labels)`` tuple, so slicing it yields
        # ``(image, label)`` elements.
        test_data = tf.data.Dataset.from_tensor_slices(test_data)

        if self.method in ["fit", "train"]:
            train_data = tf.data.Dataset.from_tensor_slices(train_data).shuffle(
                self._SHUFFLE_BUFFER
            )
            test_data = test_data.shuffle(self._SHUFFLE_BUFFER)
            # Transform BOTH splits: validation data must have the same
            # batching, dtype and flattened shape the model trains on.
            # (Previously only train_data was transformed, so the validation
            # set was returned as raw, unbatched uint8 28x28 images.)
            train_data = self._transform_dataset(train_data)
            test_data = self._transform_dataset(test_data)
            return train_data, test_data

        return self._transform_dataset(test_data)

    def _transform_dataset(self, dataset):
        """Batch ``dataset`` and convert each batch into model-ready tensors.

        Args:
            dataset: A ``tf.data.Dataset`` of ``(image, label)`` pairs. The
                reshape below assumes each image holds 28 * 28 values.

        Returns:
            A dataset of ``(x, y)`` batches: ``x`` is float32 with shape
            ``(batch, 784)`` and pixel values divided by 255.0; ``y`` is int64.
        """
        dataset = dataset.batch(self.hparams.batch_size)
        return dataset.map(
            lambda x, y: (
                # Batch first, then flatten: -1 keeps the (possibly smaller)
                # final batch size intact.
                tf.reshape(tf.cast(x, tf.float32) / 255.0, [-1, 28 * 28]),  # type: ignore
                tf.cast(y, tf.int64),
            )
        )