diff --git a/examples/tensorflow_v2_mnist.py b/examples/tensorflow_v2_mnist.py index 96a23869..65edd22a 100644 --- a/examples/tensorflow_v2_mnist.py +++ b/examples/tensorflow_v2_mnist.py @@ -19,8 +19,8 @@ mnist = tf.keras.datasets.mnist x_train, x_test = x_train / 255.0, x_test / 255.0 # Add a channels dimension -x_train = x_train[..., tf.newaxis] -x_test = x_test[..., tf.newaxis] +x_train = x_train[..., tf.newaxis].astype('float32') +x_test = x_test[..., tf.newaxis].astype('float32') # Use tf.data to batch and shuffle the dataset train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(10000).batch(32) @@ -31,10 +31,10 @@ test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(32) class MyModel(Model): def __init__(self): super(MyModel, self).__init__() - self.conv1 = Conv2D(32, 3, activation='relu') + self.conv1 = Conv2D(32, 3, activation='relu', dtype=tf.float32) self.flatten = Flatten() - self.d1 = Dense(128, activation='relu') - self.d2 = Dense(10, activation='softmax') + self.d1 = Dense(128, activation='relu', dtype=tf.float32) + self.d2 = Dense(10, activation='softmax', dtype=tf.float32) def call(self, x): x = self.conv1(x)