clearml/examples/tensorflow_eager.py

359 lines
13 KiB
Python
Raw Normal View History

2019-06-10 17:00:28 +00:00
# TRAINS - Example of tensorflow eager mode, model logging and tensorboard
#
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A deep MNIST classifier using convolutional layers.
Sample usage:
python tensorflow_eager.py --help
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import os
import sys
import time
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from trains import Task
2019-07-17 06:10:32 +00:00
# Switch TensorFlow to eager mode. The training code below iterates the
# tf.data dataset with a plain Python loop and calls .numpy() on variables,
# both of which rely on eager execution being active.
tf.compat.v1.enable_eager_execution()
2019-06-10 17:00:28 +00:00
# Create the TRAINS task for this run (project 'examples').
task = Task.init(project_name='examples', task_name='Tensorflow eager mode')
# Two sample tf.app.flags definitions. Nothing else in this file reads them.
# NOTE(review): presumably they only exist to demonstrate that TRAINS picks
# up tf.app.flags values - confirm before removing.
FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_integer('data_num', 100, """Flag of type integer""")
tf.app.flags.DEFINE_string('img_path', './img', """Flag of type string""")
# Shorthand used by the model classes below.
layers = tf.keras.layers
# Rebinds FLAGS, dropping the tf.app.flags.FLAGS reference assigned above.
# The __main__ block replaces this None with the parsed argparse namespace
# before main() runs.
FLAGS = None
class Discriminator(tf.keras.Model):
    """GAN discriminator for 28x28 single-channel digit images.

    Maps each input image to one unnormalized logit estimating whether the
    image is a real handwritten digit or a generated one.
    """

    def __init__(self, data_format):
        """Build the discriminator's layers.

        Args:
          data_format: Either 'channels_first' or 'channels_last'. It selects
            the shape inputs are reshaped to in call() and is forwarded to the
            convolution and pooling layers.
            'channels_first' is typically faster on GPUs while 'channels_last'
            is typically faster on CPUs. See
            https://www.tensorflow.org/performance/performance_guide#data_formats
        """
        super(Discriminator, self).__init__(name='')

        channels_first = data_format == 'channels_first'
        if not channels_first:
            assert data_format == 'channels_last'
        # -1 leaves the batch dimension to be inferred by tf.reshape.
        self._input_shape = (
            [-1, 1, 28, 28] if channels_first else [-1, 28, 28, 1])

        # Two convolution + average-pooling stages.
        self.conv1 = layers.Conv2D(
            filters=64,
            kernel_size=5,
            padding='SAME',
            data_format=data_format,
            activation=tf.tanh)
        self.pool1 = layers.AveragePooling2D(
            pool_size=2, strides=2, data_format=data_format)
        self.conv2 = layers.Conv2D(
            filters=128,
            kernel_size=5,
            data_format=data_format,
            activation=tf.tanh)
        self.pool2 = layers.AveragePooling2D(
            pool_size=2, strides=2, data_format=data_format)

        # Dense head ending in a single linear unit (the logit).
        self.flatten = layers.Flatten()
        self.fc1 = layers.Dense(units=1024, activation=tf.tanh)
        self.fc2 = layers.Dense(units=1, activation=None)

    def call(self, inputs):
        """Return one logit per image estimating input authenticity.

        Users should invoke __call__ to run the network, which delegates to
        this method (and not call this method directly).

        Args:
          inputs: A batch of images as a Tensor that can be reshaped to
            [batch_size, 28, 28, 1] or [batch_size, 1, 28, 28], depending on
            the data_format given at construction.

        Returns:
          A Tensor with shape [batch_size, 1] containing logits estimating
          the probability that the corresponding digit is real.
        """
        features = tf.reshape(inputs, self._input_shape)
        stages = (self.conv1, self.pool1, self.conv2, self.pool2,
                  self.flatten, self.fc1, self.fc2)
        for stage in stages:
            features = stage(features)
        return features
class Generator(tf.keras.Model):
    """Generator of handwritten digits similar to those in the MNIST dataset.

    Turns a batch of noise vectors into 28x28 single-channel images through
    one dense layer followed by two transposed convolutions.
    """

    def __init__(self, data_format):
        """Build the generator's layers.

        Args:
          data_format: Either 'channels_first' or 'channels_last'. It selects
            the shape the dense output is reshaped to in call() and is
            forwarded to the transposed-convolution layers.
            'channels_first' is typically faster on GPUs while 'channels_last'
            is typically faster on CPUs. See
            https://www.tensorflow.org/performance/performance_guide#data_formats
        """
        super(Generator, self).__init__(name='')
        self.data_format = data_format

        # The first deconvolution layer receives 128 channels of size 6x6.
        channels_first = data_format == 'channels_first'
        if not channels_first:
            assert data_format == 'channels_last'
        self._pre_conv_shape = (
            [-1, 128, 6, 6] if channels_first else [-1, 6, 6, 128])

        # Dense projection; call() reshapes its output to _pre_conv_shape.
        self.fc1 = layers.Dense(units=6 * 6 * 128, activation=tf.tanh)

        # Deconvolution layer. Resulting image shape: (batch, 14, 14, 64)
        self.conv1 = layers.Conv2DTranspose(
            filters=64,
            kernel_size=4,
            strides=2,
            activation=None,
            data_format=data_format)

        # Deconvolution layer. Resulting image shape: (batch, 28, 28, 1)
        self.conv2 = layers.Conv2DTranspose(
            filters=1,
            kernel_size=2,
            strides=2,
            activation=tf.nn.sigmoid,
            data_format=data_format)

    def call(self, inputs):
        """Return a batch of generated images.

        Users should invoke __call__ to run the network, which delegates to
        this method (and not call this method directly).

        Args:
          inputs: A batch of noise vectors as a Tensor with shape
            [batch_size, length of noise vectors].

        Returns:
          A Tensor containing generated images. If data_format is
          'channels_last', the shape of returned images is
          [batch_size, 28, 28, 1], else [batch_size, 1, 28, 28].
        """
        projected = self.fc1(inputs)
        feature_maps = tf.reshape(projected, shape=self._pre_conv_shape)
        upsampled = self.conv1(feature_maps)
        return self.conv2(upsampled)
def discriminator_loss(discriminator_real_outputs, discriminator_gen_outputs):
    """Original discriminator loss for GANs, with label smoothing.

    The loss is the sum of two sigmoid cross-entropy terms: real outputs
    scored against all-ones labels (smoothed by 0.25) and generated outputs
    scored against all-zeros labels. The summed value is also emitted as the
    'discriminator_loss' scalar summary.

    See `Generative Adversarial Nets` (https://arxiv.org/abs/1406.2661) for
    more details.

    Args:
      discriminator_real_outputs: Discriminator output on real data.
      discriminator_gen_outputs: Discriminator output on generated data.
        Expected to be in the range of (-inf, inf).

    Returns:
      A scalar loss Tensor.
    """
    cross_entropy = tf.compat.v1.losses.sigmoid_cross_entropy

    real_term = cross_entropy(
        multi_class_labels=tf.ones_like(discriminator_real_outputs),
        logits=discriminator_real_outputs,
        label_smoothing=0.25)
    generated_term = cross_entropy(
        multi_class_labels=tf.zeros_like(discriminator_gen_outputs),
        logits=discriminator_gen_outputs)

    combined = real_term + generated_term
    tf.contrib.summary.scalar('discriminator_loss', combined)
    return combined
def generator_loss(discriminator_gen_outputs):
    """Original generator loss for GANs.

    L = -log(sigmoid(D(G(z))))

    Computed as the sigmoid cross-entropy of the discriminator's outputs on
    generated data against all-ones labels. The value is also emitted as the
    'generator_loss' scalar summary.

    See `Generative Adversarial Nets` (https://arxiv.org/abs/1406.2661)
    for more details.

    Args:
      discriminator_gen_outputs: Discriminator output on generated data.
        Expected to be in the range of (-inf, inf).

    Returns:
      A scalar loss Tensor.
    """
    target_labels = tf.ones_like(discriminator_gen_outputs)
    value = tf.compat.v1.losses.sigmoid_cross_entropy(
        multi_class_labels=target_labels,
        logits=discriminator_gen_outputs)
    tf.contrib.summary.scalar('generator_loss', value)
    return value
def train_one_epoch(generator, discriminator, generator_optimizer,
                    discriminator_optimizer, dataset, step_counter,
                    log_interval, noise_dim):
    """Train `generator` and `discriminator` models on `dataset`.

    For every batch the step counter is incremented, a batch of uniform noise
    in [-1, 1) is drawn, both losses are computed, and one optimizer step is
    applied to each model. Running loss averages are printed every
    `log_interval` batches (never for the very first batch).

    Args:
      generator: Generator model.
      discriminator: Discriminator model.
      generator_optimizer: Optimizer to use for generator.
      discriminator_optimizer: Optimizer to use for discriminator.
      dataset: Dataset of images to train on.
      step_counter: An integer variable, used to write summaries regularly.
      log_interval: How many steps to wait between logging and collecting
        summaries.
      noise_dim: Dimension of noise vector to use.
    """
    total_generator_loss = 0.0
    total_discriminator_loss = 0.0

    for (batch_index, images) in enumerate(dataset):
        with tf.device('/cpu:0'):
            tf.compat.v1.assign_add(step_counter, 1)

        with tf.contrib.summary.record_summaries_every_n_global_steps(
                log_interval, global_step=step_counter):
            current_batch_size = images.shape[0]
            # The seed is the batch index, so it differs from batch to batch.
            noise = tf.random.uniform(
                shape=[current_batch_size, noise_dim],
                minval=-1.,
                maxval=1.,
                seed=batch_index)

            # we can use 2 tapes or a single persistent tape.
            # Using two tapes is memory efficient since intermediate tensors
            # can be released between the two .gradient() calls below
            with tf.GradientTape() as gen_tape, \
                    tf.GradientTape() as disc_tape:
                generated_images = generator(noise)
                tf.contrib.summary.image(
                    'generated_images',
                    tf.reshape(generated_images, [-1, 28, 28, 1]),
                    max_images=10)

                discriminator_gen_outputs = discriminator(generated_images)
                discriminator_real_outputs = discriminator(images)
                discriminator_loss_val = discriminator_loss(
                    discriminator_real_outputs, discriminator_gen_outputs)
                generator_loss_val = generator_loss(discriminator_gen_outputs)

            # Running totals are bookkeeping only, so they are accumulated
            # outside the tapes' recording scope.
            total_discriminator_loss += discriminator_loss_val
            total_generator_loss += generator_loss_val

            generator_grad = gen_tape.gradient(generator_loss_val,
                                               generator.variables)
            discriminator_grad = disc_tape.gradient(discriminator_loss_val,
                                                    discriminator.variables)

            generator_optimizer.apply_gradients(
                zip(generator_grad, generator.variables))
            discriminator_optimizer.apply_gradients(
                zip(discriminator_grad, discriminator.variables))

            if (log_interval and batch_index > 0
                    and batch_index % log_interval == 0):
                # batch_index is zero-based: losses of batches
                # 0..batch_index have been summed, so the mean is taken over
                # batch_index + 1 batches.
                batches_seen = batch_index + 1
                print('Batch #%d\tAverage Generator Loss: %.6f\t'
                      'Average Discriminator Loss: %.6f' %
                      (batch_index, total_generator_loss / batches_seen,
                       total_discriminator_loss / batches_seen))
def main(_):
    """Train the GAN for three epochs, checkpointing after each one.

    Reads its configuration from the module-level FLAGS namespace: picks the
    device and data format, loads MNIST, builds the models and optimizers,
    restores the latest checkpoint if one exists, then trains while writing
    summaries to FLAGS.output_dir.

    Args:
      _: Unused positional argument supplied by the tf app runner.
    """
    cpu_only = FLAGS.no_gpu or tf.contrib.eager.num_gpus() <= 0
    if cpu_only:
        device, data_format = '/cpu:0', 'channels_last'
    else:
        device, data_format = '/gpu:0', 'channels_first'
    print('Using device %s, and data format %s.' % (device, data_format))

    # Load the datasets (only the first 1280 training images are used).
    mnist = input_data.read_data_sets(FLAGS.data_dir)
    train_images = mnist.train.images[:1280]
    dataset = tf.data.Dataset.from_tensor_slices(train_images)
    dataset = dataset.shuffle(60000)
    dataset = dataset.batch(FLAGS.batch_size)

    # Create the models and optimizers. The keys double as keyword arguments
    # of train_one_epoch() and as attribute names on the checkpoint object.
    model_objects = dict(
        generator=Generator(data_format),
        discriminator=Discriminator(data_format),
        generator_optimizer=tf.compat.v1.train.AdamOptimizer(FLAGS.lr),
        discriminator_optimizer=tf.compat.v1.train.AdamOptimizer(FLAGS.lr),
        step_counter=tf.compat.v1.train.get_or_create_global_step(),
    )

    # Prepare summary writer and checkpoint info
    summary_writer = tf.contrib.summary.create_file_writer(
        FLAGS.output_dir, flush_millis=1000)
    checkpoint_prefix = os.path.join(FLAGS.checkpoint_dir, 'ckpt')
    latest_cpkt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
    if latest_cpkt:
        print('Using latest checkpoint at ' + latest_cpkt)
    checkpoint = tf.train.Checkpoint(**model_objects)
    # Restore variables on creation if a checkpoint exists.
    checkpoint.restore(latest_cpkt)

    with tf.device(device):
        for _unused_epoch in range(3):
            started_at = time.time()
            with summary_writer.as_default():
                train_one_epoch(
                    dataset=dataset,
                    log_interval=FLAGS.log_interval,
                    noise_dim=FLAGS.noise,
                    **model_objects)
            # Measured before saving, so checkpoint time is not included.
            elapsed = time.time() - started_at
            checkpoint.save(checkpoint_prefix)
            print('\nTrain time for epoch #%d (step %d): %f' %
                  (checkpoint.save_counter.numpy(),
                   checkpoint.step_counter.numpy(),
                   elapsed))
if __name__ == '__main__':
    # Parse the options this script knows about into the module-level FLAGS
    # namespace read by main(); anything unrecognized is forwarded to the
    # TensorFlow app runner.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--data-dir',
        type=str,
        default='/tmp/tensorflow/mnist/input_data',
        help=('Directory for storing input data (default '
              '/tmp/tensorflow/mnist/input_data)'))
    # %(default)s is expanded by argparse, so the help text cannot drift away
    # from the actual default (it previously advertised 128).
    parser.add_argument(
        '--batch-size',
        type=int,
        default=16,
        metavar='N',
        help='input batch size for training (default: %(default)s)')
    # Previously advertised a default of 100.
    parser.add_argument(
        '--log-interval',
        type=int,
        default=1,
        metavar='N',
        help=('number of batches between logging and writing summaries '
              '(default: %(default)s)'))
    # Previously claimed "defaults to none" although a directory is set.
    parser.add_argument(
        '--output_dir',
        type=str,
        default='/tmp/tensorflow/',
        metavar='DIR',
        help=('Directory to write TensorBoard summaries '
              '(default %(default)s)'))
    parser.add_argument(
        '--checkpoint_dir',
        type=str,
        default='/tmp/tensorflow/mnist/checkpoints/',
        metavar='DIR',
        help=('Directory to save checkpoints in (once per epoch) (default '
              '/tmp/tensorflow/mnist/checkpoints/)'))
    parser.add_argument(
        '--lr',
        type=float,
        default=0.001,
        metavar='LR',
        help='learning rate (default: 0.001)')
    parser.add_argument(
        '--noise',
        type=int,
        default=100,
        metavar='N',
        help='Length of noise vector for generator input (default: 100)')
    parser.add_argument(
        '--no-gpu',
        action='store_true',
        default=False,
        help='disables GPU usage even if a GPU is available')

    FLAGS, unparsed = parser.parse_known_args()
    tf.compat.v1.app.run(main=main, argv=[sys.argv[0]] + unparsed)