mirror of
https://github.com/clearml/clearml
synced 2025-04-28 02:01:51 +00:00
Remove cast warning in Tensorflow v2 example
This commit is contained in:
parent
3d9683f290
commit
824808d38b
@ -19,8 +19,8 @@ mnist = tf.keras.datasets.mnist
|
|||||||
x_train, x_test = x_train / 255.0, x_test / 255.0
|
x_train, x_test = x_train / 255.0, x_test / 255.0
|
||||||
|
|
||||||
# Add a channels dimension
|
# Add a channels dimension
|
||||||
x_train = x_train[..., tf.newaxis]
|
x_train = x_train[..., tf.newaxis].astype('float32')
|
||||||
x_test = x_test[..., tf.newaxis]
|
x_test = x_test[..., tf.newaxis].astype('float32')
|
||||||
|
|
||||||
# Use tf.data to batch and shuffle the dataset
|
# Use tf.data to batch and shuffle the dataset
|
||||||
train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(10000).batch(32)
|
train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(10000).batch(32)
|
||||||
@ -31,10 +31,10 @@ test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(32)
|
|||||||
class MyModel(Model):
|
class MyModel(Model):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(MyModel, self).__init__()
|
super(MyModel, self).__init__()
|
||||||
self.conv1 = Conv2D(32, 3, activation='relu')
|
self.conv1 = Conv2D(32, 3, activation='relu', dtype=tf.float32)
|
||||||
self.flatten = Flatten()
|
self.flatten = Flatten()
|
||||||
self.d1 = Dense(128, activation='relu')
|
self.d1 = Dense(128, activation='relu', dtype=tf.float32)
|
||||||
self.d2 = Dense(10, activation='softmax')
|
self.d2 = Dense(10, activation='softmax', dtype=tf.float32)
|
||||||
|
|
||||||
def call(self, x):
|
def call(self, x):
|
||||||
x = self.conv1(x)
|
x = self.conv1(x)
|
||||||
|
Loading…
Reference in New Issue
Block a user