From 772fd7f7508d0f5b31428084856b2e5485a013e2 Mon Sep 17 00:00:00 2001
From: allegroai <>
Date: Sat, 8 Aug 2020 13:13:21 +0300
Subject: [PATCH] Add Keras Tuner CIFAR10 example

---
 .../kerastuner/keras_tuner_cifar.py           | 74 +++++++++++++++++++
 .../frameworks/kerastuner/requirements.txt    |  3 +
 .../frameworks/pytorch/manual_model_upload.py |  2 +-
 examples/reporting/artifacts.py               |  1 -
 4 files changed, 78 insertions(+), 2 deletions(-)
 create mode 100644 examples/frameworks/kerastuner/keras_tuner_cifar.py
 create mode 100644 examples/frameworks/kerastuner/requirements.txt

diff --git a/examples/frameworks/kerastuner/keras_tuner_cifar.py b/examples/frameworks/kerastuner/keras_tuner_cifar.py
new file mode 100644
index 00000000..eb0dcf8a
--- /dev/null
+++ b/examples/frameworks/kerastuner/keras_tuner_cifar.py
@@ -0,0 +1,74 @@
+"""Keras Tuner CIFAR10 example for the TensorFlow blog post."""
+
+import kerastuner as kt
+import tensorflow as tf
+import tensorflow_datasets as tfds
+from trains.external.kerastuner import TrainsTunerLogger
+
+from trains import Task
+
+physical_devices = tf.config.list_physical_devices('GPU')
+if physical_devices:
+    tf.config.experimental.set_memory_growth(physical_devices[0], True)
+
+
+def build_model(hp):
+    inputs = tf.keras.Input(shape=(32, 32, 3))
+    x = inputs
+    for i in range(hp.Int('conv_blocks', 3, 5, default=3)):
+        filters = hp.Int('filters_' + str(i), 32, 256, step=32)
+        for _ in range(2):
+            x = tf.keras.layers.Convolution2D(
+                filters, kernel_size=(3, 3), padding='same')(x)
+            x = tf.keras.layers.BatchNormalization()(x)
+            x = tf.keras.layers.ReLU()(x)
+        if hp.Choice('pooling_' + str(i), ['avg', 'max']) == 'max':
+            x = tf.keras.layers.MaxPool2D()(x)
+        else:
+            x = tf.keras.layers.AvgPool2D()(x)
+    x = tf.keras.layers.GlobalAvgPool2D()(x)
+    x = tf.keras.layers.Dense(
+        hp.Int('hidden_size', 30, 100, step=10, default=50),
+        activation='relu')(x)
+    x = tf.keras.layers.Dropout(
+        hp.Float('dropout', 0, 0.5, step=0.1, default=0.5))(x)
+    outputs = tf.keras.layers.Dense(10, activation='softmax')(x)
+
+    model = tf.keras.Model(inputs, outputs)
+    model.compile(
+        optimizer=tf.keras.optimizers.Adam(
+            hp.Float('learning_rate', 1e-4, 1e-2, sampling='log')),
+        loss='sparse_categorical_crossentropy',
+        metrics=['accuracy'])
+    return model
+
+
+task = Task.init('examples', 'kerastuner cifar10 tuning')
+
+tuner = kt.Hyperband(
+    build_model,
+    project_name='kt examples',
+    logger=TrainsTunerLogger(),
+    objective='val_accuracy',
+    max_epochs=10,
+    hyperband_iterations=6)
+
+data = tfds.load('cifar10')
+train_ds, test_ds = data['train'], data['test']
+
+
+def standardize_record(record):
+    return tf.cast(record['image'], tf.float32) / 255., record['label']
+
+
+train_ds = train_ds.map(standardize_record).cache().batch(64).shuffle(10000)
+test_ds = test_ds.map(standardize_record).cache().batch(64)
+
+tuner.search(train_ds,
+             validation_data=test_ds,
+             callbacks=[tf.keras.callbacks.EarlyStopping(patience=1),
+                        tf.keras.callbacks.TensorBoard(),
+                        ])
+
+best_model = tuner.get_best_models(1)[0]
+best_hyperparameters = tuner.get_best_hyperparameters(1)[0]
diff --git a/examples/frameworks/kerastuner/requirements.txt b/examples/frameworks/kerastuner/requirements.txt
new file mode 100644
index 00000000..914c8f92
--- /dev/null
+++ b/examples/frameworks/kerastuner/requirements.txt
@@ -0,0 +1,3 @@
+keras-tuner
+tensorflow>=2.0
+tensorflow-datasets
\ No newline at end of file
diff --git a/examples/frameworks/pytorch/manual_model_upload.py b/examples/frameworks/pytorch/manual_model_upload.py
index 659e8125..dfd963e0 100644
--- a/examples/frameworks/pytorch/manual_model_upload.py
+++ b/examples/frameworks/pytorch/manual_model_upload.py
@@ -39,5 +39,5 @@ task.connect_label_enumeration(labels)
 
 # storing the model, it will have the task network configuration and label enumeration
 print('Any model stored from this point onwards, will contain both model_config and label_enumeration')
-torch.save(model, os.path.join(gettempdir(), "model"))
+torch.save(model, os.path.join(gettempdir(), "model.pt"))
 print('Model saved')
diff --git a/examples/reporting/artifacts.py b/examples/reporting/artifacts.py
index ec24b567..6ba0168c 100644
--- a/examples/reporting/artifacts.py
+++ b/examples/reporting/artifacts.py
@@ -40,7 +40,6 @@ task.upload_artifact('pillow_image', im)
 task.upload_artifact('local folder', artifact_object=os.path.join('data_samples'))
 # add and upload a wildcard
 task.upload_artifact('wildcard jpegs', artifact_object=os.path.join('data_samples', '*.jpg'))
-
 # do something here
 sleep(1.)
 print(df)