diff --git a/examples/frameworks/kerastuner/keras_tuner_cifar.py b/examples/frameworks/kerastuner/keras_tuner_cifar.py index 5f23d740..79002a1d 100644 --- a/examples/frameworks/kerastuner/keras_tuner_cifar.py +++ b/examples/frameworks/kerastuner/keras_tuner_cifar.py @@ -1,6 +1,6 @@ """Keras Tuner CIFAR10 example for the TensorFlow blog post.""" -import kerastuner as kt +import keras_tuner as kt import tensorflow as tf import tensorflow_datasets as tfds from clearml.external.kerastuner import ClearmlTunerLogger @@ -9,6 +9,7 @@ from clearml import Task physical_devices = tf.config.list_physical_devices('GPU') if physical_devices: + tf.config.experimental.set_visible_devices(physical_devices[0], 'GPU') tf.config.experimental.set_memory_growth(physical_devices[0], True)