2019-06-10 17:00:28 +00:00
|
|
|
# TRAINS - Example of TensorFlow eager mode, model logging, and TensorBoard
|
|
|
|
#
|
|
|
|
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
# ==============================================================================
|
|
|
|
"""A deep MNIST classifier using convolutional layers.
|
|
|
|
Sample usage:
|
|
|
|
python mnist.py --help
|
|
|
|
"""
|
|
|
|
|
|
|
|
from __future__ import absolute_import
|
|
|
|
from __future__ import division
|
|
|
|
from __future__ import print_function
|
|
|
|
|
|
|
|
import argparse
|
|
|
|
import os
|
|
|
|
import sys
|
|
|
|
import time
|
|
|
|
import tensorflow as tf
|
|
|
|
from tensorflow.examples.tutorials.mnist import input_data
|
|
|
|
from trains import Task
|
|
|
|
|
2019-07-17 06:10:32 +00:00
|
|
|
# Switch TensorFlow 1.x into eager mode. TF requires this call at program
# startup, before any graph/tensor is created.
tf.compat.v1.enable_eager_execution()

# Create the TRAINS task for this run (project 'examples').
task = Task.init(project_name='examples', task_name='Tensorflow eager mode')

# absl flags, registered through tf.compat.v1.flags (the non-deprecated alias of
# tf.app.flags, matching the tf.compat.v1 usage elsewhere in this file).
# NOTE(review): neither flag is read anywhere in this file; runtime options are
# parsed with argparse in the __main__ block. They are kept so that passing
# --data_num / --img_path through tf.compat.v1.app.run still parses - confirm
# whether they can be removed.
tf.compat.v1.flags.DEFINE_integer('data_num', 100, """Flag of type integer""")
tf.compat.v1.flags.DEFINE_string('img_path', './img', """Flag of type string""")

layers = tf.keras.layers

# Replaced by the argparse namespace in the __main__ block before main() runs.
FLAGS = None
|
|
|
|
|
|
|
|
|
|
|
|
class Discriminator(tf.keras.Model):
    """GAN Discriminator.

    A network to differentiate between generated and real handwritten digits.
    """

    def __init__(self, data_format):
        """Creates a model for discriminating between real and generated digits.

        Args:
          data_format: Either 'channels_first' or 'channels_last'.
            'channels_first' is typically faster on GPUs while 'channels_last' is
            typically faster on CPUs. See
            https://www.tensorflow.org/performance/performance_guide#data_formats

        Raises:
          ValueError: if `data_format` is not one of the two values above.
        """
        super(Discriminator, self).__init__(name='')
        if data_format == 'channels_first':
            self._input_shape = [-1, 1, 28, 28]
        elif data_format == 'channels_last':
            self._input_shape = [-1, 28, 28, 1]
        else:
            # An explicit exception instead of `assert`, which is stripped
            # under `python -O` and would silently fall back to channels_last.
            raise ValueError(
                "data_format must be 'channels_first' or 'channels_last', "
                "got %r" % (data_format,))
        self.conv1 = layers.Conv2D(
            64, 5, padding='SAME', data_format=data_format, activation=tf.tanh)
        self.pool1 = layers.AveragePooling2D(2, 2, data_format=data_format)
        self.conv2 = layers.Conv2D(
            128, 5, data_format=data_format, activation=tf.tanh)
        self.pool2 = layers.AveragePooling2D(2, 2, data_format=data_format)
        self.flatten = layers.Flatten()
        self.fc1 = layers.Dense(1024, activation=tf.tanh)
        # Single output unit with no activation: a raw logit per image.
        self.fc2 = layers.Dense(1, activation=None)

    def call(self, inputs):
        """Return one logit per image estimating input authenticity.

        Users should invoke __call__ to run the network, which delegates to this
        method (and not call this method directly).

        Args:
          inputs: A batch of images as a Tensor with shape [batch_size, 28, 28, 1]
            or [batch_size, 1, 28, 28]. Any tensor with 784 elements per image
            works, since it is reshaped to the model's data format first.

        Returns:
          A Tensor with shape [batch_size, 1] containing logits estimating
          the probability that the corresponding digit is real.
        """
        x = tf.reshape(inputs, self._input_shape)
        x = self.conv1(x)
        x = self.pool1(x)
        x = self.conv2(x)
        x = self.pool2(x)
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.fc2(x)
        return x
|
|
|
|
|
|
|
|
|
|
|
|
class Generator(tf.keras.Model):
    """Generator of handwritten digits similar to the ones in the MNIST dataset.
    """

    def __init__(self, data_format):
        """Creates a model for generating handwritten digits from noise vectors.

        Args:
          data_format: Either 'channels_first' or 'channels_last'.
            'channels_first' is typically faster on GPUs while 'channels_last' is
            typically faster on CPUs. See
            https://www.tensorflow.org/performance/performance_guide#data_formats

        Raises:
          ValueError: if `data_format` is not one of the two values above.
        """
        super(Generator, self).__init__(name='')
        self.data_format = data_format
        # We are using 128 6x6 channels as input to the first deconvolution layer
        if data_format == 'channels_first':
            self._pre_conv_shape = [-1, 128, 6, 6]
        elif data_format == 'channels_last':
            self._pre_conv_shape = [-1, 6, 6, 128]
        else:
            # An explicit exception instead of `assert`, which is stripped
            # under `python -O` and would silently fall back to channels_last.
            raise ValueError(
                "data_format must be 'channels_first' or 'channels_last', "
                "got %r" % (data_format,))
        self.fc1 = layers.Dense(6 * 6 * 128, activation=tf.tanh)

        # In call(), we reshape the output of fc1 to _pre_conv_shape

        # Deconvolution layer ('valid' padding: (6 - 1) * 2 + 4 = 14).
        # Resulting image shape: (batch, 14, 14, 64)
        self.conv1 = layers.Conv2DTranspose(
            64, 4, strides=2, activation=None, data_format=data_format)

        # Deconvolution layer ((14 - 1) * 2 + 2 = 28). Sigmoid keeps pixels in
        # (0, 1). Resulting image shape: (batch, 28, 28, 1)
        self.conv2 = layers.Conv2DTranspose(
            1, 2, strides=2, activation=tf.nn.sigmoid, data_format=data_format)

    def call(self, inputs):
        """Return a batch of generated images.

        Users should invoke __call__ to run the network, which delegates to this
        method (and not call this method directly).

        Args:
          inputs: A batch of noise vectors as a Tensor with shape
            [batch_size, length of noise vectors].

        Returns:
          A Tensor containing generated images. If data_format is 'channels_last',
          the shape of returned images is [batch_size, 28, 28, 1], else
          [batch_size, 1, 28, 28]
        """
        x = self.fc1(inputs)
        x = tf.reshape(x, shape=self._pre_conv_shape)
        x = self.conv1(x)
        x = self.conv2(x)
        return x
|
|
|
|
|
|
|
|
|
|
|
|
def discriminator_loss(discriminator_real_outputs, discriminator_gen_outputs):
    """Compute the standard GAN discriminator loss with one-sided label smoothing.

    Real images are scored against smoothed "real" targets (label_smoothing=0.25);
    generated images are scored against hard "fake" (zero) targets. The sum of
    both sigmoid cross-entropies is logged as the 'discriminator_loss' summary.
    See `Generative Adversarial Nets` (https://arxiv.org/abs/1406.2661).

    Args:
      discriminator_real_outputs: Discriminator logits on real data.
      discriminator_gen_outputs: Discriminator logits on generated data, in the
        range (-inf, inf).

    Returns:
      A scalar loss Tensor.
    """
    real_targets = tf.ones_like(discriminator_real_outputs)
    fake_targets = tf.zeros_like(discriminator_gen_outputs)

    real_term = tf.compat.v1.losses.sigmoid_cross_entropy(
        multi_class_labels=real_targets,
        logits=discriminator_real_outputs,
        label_smoothing=0.25)
    fake_term = tf.compat.v1.losses.sigmoid_cross_entropy(
        multi_class_labels=fake_targets,
        logits=discriminator_gen_outputs)

    combined = real_term + fake_term
    tf.contrib.summary.scalar('discriminator_loss', combined)
    return combined
|
|
|
|
|
|
|
|
|
|
|
|
def generator_loss(discriminator_gen_outputs):
    """Compute the non-saturating GAN generator loss.

    The generator is rewarded when the discriminator labels its samples as real:
    L = -log(sigmoid(D(G(z)))), i.e. sigmoid cross-entropy against all-ones
    targets. The value is logged as the 'generator_loss' summary.
    See `Generative Adversarial Nets` (https://arxiv.org/abs/1406.2661).

    Args:
      discriminator_gen_outputs: Discriminator logits on generated data, in the
        range (-inf, inf).

    Returns:
      A scalar loss Tensor.
    """
    wanted_labels = tf.ones_like(discriminator_gen_outputs)
    result = tf.compat.v1.losses.sigmoid_cross_entropy(
        multi_class_labels=wanted_labels, logits=discriminator_gen_outputs)
    tf.contrib.summary.scalar('generator_loss', result)
    return result
|
|
|
|
|
|
|
|
|
|
|
|
def train_one_epoch(generator, discriminator, generator_optimizer,
                    discriminator_optimizer, dataset, step_counter,
                    log_interval, noise_dim):
    """Train `generator` and `discriminator` models on `dataset`.

    For every batch: increments `step_counter`, samples uniform noise in [-1, 1],
    runs both networks, and applies one optimizer step to each. Every
    `log_interval` batches the running-average losses are printed.

    Args:
      generator: Generator model.
      discriminator: Discriminator model.
      generator_optimizer: Optimizer to use for generator.
      discriminator_optimizer: Optimizer to use for discriminator.
      dataset: Dataset of images to train on.
      step_counter: An integer variable, used to write summaries regularly.
      log_interval: How many steps to wait between logging and collecting
        summaries. A falsy value disables the printed averages.
      noise_dim: Dimension of noise vector to use.
    """
    total_generator_loss = 0.0
    total_discriminator_loss = 0.0

    for (batch_index, images) in enumerate(dataset):
        with tf.device('/cpu:0'):
            tf.compat.v1.assign_add(step_counter, 1)

        with tf.contrib.summary.record_summaries_every_n_global_steps(
                log_interval, global_step=step_counter):
            current_batch_size = images.shape[0]
            noise = tf.random.uniform(
                shape=[current_batch_size, noise_dim],
                minval=-1.,
                maxval=1.,
                seed=batch_index)

            # we can use 2 tapes or a single persistent tape.
            # Using two tapes is memory efficient since intermediate tensors can be
            # released between the two .gradient() calls below
            with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
                generated_images = generator(noise)
                tf.contrib.summary.image(
                    'generated_images',
                    tf.reshape(generated_images, [-1, 28, 28, 1]),
                    max_images=10)

                discriminator_gen_outputs = discriminator(generated_images)
                discriminator_real_outputs = discriminator(images)
                discriminator_loss_val = discriminator_loss(
                    discriminator_real_outputs, discriminator_gen_outputs)
                generator_loss_val = generator_loss(discriminator_gen_outputs)

            # Accumulate the running totals outside the tapes: they are only used
            # for reporting, so there is no reason for the tapes to record them.
            total_discriminator_loss += discriminator_loss_val
            total_generator_loss += generator_loss_val

            generator_grad = gen_tape.gradient(generator_loss_val,
                                               generator.variables)
            discriminator_grad = disc_tape.gradient(discriminator_loss_val,
                                                    discriminator.variables)

            generator_optimizer.apply_gradients(
                zip(generator_grad, generator.variables))
            discriminator_optimizer.apply_gradients(
                zip(discriminator_grad, discriminator.variables))

        # batch 0 is skipped to keep the original logging cadence.
        if log_interval and batch_index > 0 and batch_index % log_interval == 0:
            # Batches 0..batch_index have been summed, i.e. batch_index + 1 of
            # them; dividing by batch_index alone overstated the averages.
            batches_seen = batch_index + 1
            print('Batch #%d\tAverage Generator Loss: %.6f\t'
                  'Average Discriminator Loss: %.6f' %
                  (batch_index, total_generator_loss / batches_seen,
                   total_discriminator_loss / batches_seen))
|
|
|
|
|
|
|
|
|
|
|
|
def main(_):
    """Train the GAN for three epochs on an MNIST subset, checkpointing each epoch.

    Reads its configuration from the module-level argparse namespace `FLAGS`.
    Restores from the newest checkpoint in FLAGS.checkpoint_dir when one exists.

    Args:
      _: Unused argv list passed in by `tf.compat.v1.app.run`.
    """
    # Prefer the GPU with NCHW layout; fall back to CPU with NHWC when GPUs are
    # disabled or unavailable.
    if FLAGS.no_gpu or tf.contrib.eager.num_gpus() <= 0:
        device, data_format = '/cpu:0', 'channels_last'
    else:
        device, data_format = '/gpu:0', 'channels_first'
    print('Using device %s, and data format %s.' % (device, data_format))

    # Load the datasets (only the first 1280 training images are used).
    mnist = input_data.read_data_sets(FLAGS.data_dir)
    train_images = mnist.train.images[:1280]
    dataset = tf.data.Dataset.from_tensor_slices(train_images)
    dataset = dataset.shuffle(60000).batch(FLAGS.batch_size)

    # Create the models and optimizers. The keys double as train_one_epoch
    # keyword arguments and as tf.train.Checkpoint attribute names.
    model_objects = dict(
        generator=Generator(data_format),
        discriminator=Discriminator(data_format),
        generator_optimizer=tf.compat.v1.train.AdamOptimizer(FLAGS.lr),
        discriminator_optimizer=tf.compat.v1.train.AdamOptimizer(FLAGS.lr),
        step_counter=tf.compat.v1.train.get_or_create_global_step(),
    )

    # Prepare summary writer and checkpoint info
    summary_writer = tf.contrib.summary.create_file_writer(
        FLAGS.output_dir, flush_millis=1000)
    checkpoint_prefix = os.path.join(FLAGS.checkpoint_dir, 'ckpt')
    latest_ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
    if latest_ckpt:
        print('Using latest checkpoint at ' + latest_ckpt)
    checkpoint = tf.train.Checkpoint(**model_objects)
    # Restore variables on creation if a checkpoint exists.
    checkpoint.restore(latest_ckpt)

    with tf.device(device):
        for _ in range(3):
            epoch_start = time.time()
            with summary_writer.as_default():
                train_one_epoch(dataset=dataset,
                                log_interval=FLAGS.log_interval,
                                noise_dim=FLAGS.noise,
                                **model_objects)
            epoch_end = time.time()
            checkpoint.save(checkpoint_prefix)
            elapsed = epoch_end - epoch_start
            print('\nTrain time for epoch #%d (step %d): %f' %
                  (checkpoint.save_counter.numpy(),
                   checkpoint.step_counter.numpy(),
                   elapsed))
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Command-line options. Help strings use argparse's %(default)s
    # interpolation so the documented defaults always match the real ones
    # (the previous hand-written values for batch size, log interval and
    # output dir contradicted the actual defaults).
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--data-dir',
        type=str,
        default='/tmp/tensorflow/mnist/input_data',
        help='Directory for storing input data (default: %(default)s)')
    parser.add_argument(
        '--batch-size',
        type=int,
        default=16,
        metavar='N',
        help='input batch size for training (default: %(default)s)')
    parser.add_argument(
        '--log-interval',
        type=int,
        default=1,
        metavar='N',
        help=('number of batches between logging and writing summaries '
              '(default: %(default)s)'))
    parser.add_argument(
        '--output_dir',
        type=str,
        default='/tmp/tensorflow/',
        metavar='DIR',
        help='Directory to write TensorBoard summaries (default: %(default)s)')
    parser.add_argument(
        '--checkpoint_dir',
        type=str,
        default='/tmp/tensorflow/mnist/checkpoints/',
        metavar='DIR',
        help=('Directory to save checkpoints in (once per epoch) '
              '(default: %(default)s)'))
    parser.add_argument(
        '--lr',
        type=float,
        default=0.001,
        metavar='LR',
        help='learning rate (default: %(default)s)')
    parser.add_argument(
        '--noise',
        type=int,
        default=100,
        metavar='N',
        help='Length of noise vector for generator input (default: %(default)s)')
    parser.add_argument(
        '--no-gpu',
        action='store_true',
        default=False,
        help='disables GPU usage even if a GPU is available')

    # Rebinds the module-level FLAGS read by main(); arguments argparse does not
    # recognise are forwarded to tf.compat.v1.app.run.
    FLAGS, unparsed = parser.parse_known_args()
    tf.compat.v1.app.run(main=main, argv=[sys.argv[0]] + unparsed)
|