# TRAINS - Example of tensorflow eager mode, model logging and tensorboard # # Copyright 2017 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """A deep MNIST classifier using convolutional layers. Sample usage: python mnist.py --help """ from __future__ import absolute_import from __future__ import division from __future__ import print_function import argparse import os import sys import time import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data from trains import Task tf.compat.v1.enable_eager_execution() task = Task.init(project_name='examples', task_name='Tensorflow eager mode') FLAGS = tf.app.flags.FLAGS tf.app.flags.DEFINE_integer('data_num', 100, """Flag of type integer""") tf.app.flags.DEFINE_string('img_path', './img', """Flag of type string""") layers = tf.keras.layers FLAGS = None class Discriminator(tf.keras.Model): """GAN Discriminator. A network to differentiate between generated and real handwritten digits. """ def __init__(self, data_format): """Creates a model for discriminating between real and generated digits. Args: data_format: Either 'channels_first' or 'channels_last'. 'channels_first' is typically faster on GPUs while 'channels_last' is typically faster on CPUs. See https://www.tensorflow.org/performance/performance_guide#data_formats """ super(Discriminator, self).__init__(name='') if data_format == 'channels_first': self._input_shape = [-1, 1, 28, 28] else: assert data_format == 'channels_last' self._input_shape = [-1, 28, 28, 1] self.conv1 = layers.Conv2D( 64, 5, padding='SAME', data_format=data_format, activation=tf.tanh) self.pool1 = layers.AveragePooling2D(2, 2, data_format=data_format) self.conv2 = layers.Conv2D( 128, 5, data_format=data_format, activation=tf.tanh) self.pool2 = layers.AveragePooling2D(2, 2, data_format=data_format) self.flatten = layers.Flatten() self.fc1 = layers.Dense(1024, activation=tf.tanh) self.fc2 = layers.Dense(1, activation=None) def call(self, inputs): """Return two logits per image estimating input authenticity. Users should invoke __call__ to run the network, which delegates to this method (and not call this method directly). Args: inputs: A batch of images as a Tensor with shape [batch_size, 28, 28, 1] or [batch_size, 1, 28, 28] Returns: A Tensor with shape [batch_size] containing logits estimating the probability that corresponding digit is real. """ x = tf.reshape(inputs, self._input_shape) x = self.conv1(x) x = self.pool1(x) x = self.conv2(x) x = self.pool2(x) x = self.flatten(x) x = self.fc1(x) x = self.fc2(x) return x class Generator(tf.keras.Model): """Generator of handwritten digits similar to the ones in the MNIST dataset. """ def __init__(self, data_format): """Creates a model for discriminating between real and generated digits. Args: data_format: Either 'channels_first' or 'channels_last'. 'channels_first' is typically faster on GPUs while 'channels_last' is typically faster on CPUs. See https://www.tensorflow.org/performance/performance_guide#data_formats """ super(Generator, self).__init__(name='') self.data_format = data_format # We are using 128 6x6 channels as input to the first deconvolution layer if data_format == 'channels_first': self._pre_conv_shape = [-1, 128, 6, 6] else: assert data_format == 'channels_last' self._pre_conv_shape = [-1, 6, 6, 128] self.fc1 = layers.Dense(6 * 6 * 128, activation=tf.tanh) # In call(), we reshape the output of fc1 to _pre_conv_shape # Deconvolution layer. Resulting image shape: (batch, 14, 14, 64) self.conv1 = layers.Conv2DTranspose( 64, 4, strides=2, activation=None, data_format=data_format) # Deconvolution layer. Resulting image shape: (batch, 28, 28, 1) self.conv2 = layers.Conv2DTranspose( 1, 2, strides=2, activation=tf.nn.sigmoid, data_format=data_format) def call(self, inputs): """Return a batch of generated images. Users should invoke __call__ to run the network, which delegates to this method (and not call this method directly). Args: inputs: A batch of noise vectors as a Tensor with shape [batch_size, length of noise vectors]. Returns: A Tensor containing generated images. If data_format is 'channels_last', the shape of returned images is [batch_size, 28, 28, 1], else [batch_size, 1, 28, 28] """ x = self.fc1(inputs) x = tf.reshape(x, shape=self._pre_conv_shape) x = self.conv1(x) x = self.conv2(x) return x def discriminator_loss(discriminator_real_outputs, discriminator_gen_outputs): """Original discriminator loss for GANs, with label smoothing. See `Generative Adversarial Nets` (https://arxiv.org/abs/1406.2661) for more details. Args: discriminator_real_outputs: Discriminator output on real data. discriminator_gen_outputs: Discriminator output on generated data. Expected to be in the range of (-inf, inf). Returns: A scalar loss Tensor. """ loss_on_real = tf.compat.v1.losses.sigmoid_cross_entropy( tf.ones_like(discriminator_real_outputs), discriminator_real_outputs, label_smoothing=0.25) loss_on_generated = tf.compat.v1.losses.sigmoid_cross_entropy( tf.zeros_like(discriminator_gen_outputs), discriminator_gen_outputs) loss = loss_on_real + loss_on_generated tf.contrib.summary.scalar('discriminator_loss', loss) return loss def generator_loss(discriminator_gen_outputs): """Original generator loss for GANs. L = -log(sigmoid(D(G(z)))) See `Generative Adversarial Nets` (https://arxiv.org/abs/1406.2661) for more details. Args: discriminator_gen_outputs: Discriminator output on generated data. Expected to be in the range of (-inf, inf). Returns: A scalar loss Tensor. """ loss = tf.compat.v1.losses.sigmoid_cross_entropy( tf.ones_like(discriminator_gen_outputs), discriminator_gen_outputs) tf.contrib.summary.scalar('generator_loss', loss) return loss def train_one_epoch(generator, discriminator, generator_optimizer, discriminator_optimizer, dataset, step_counter, log_interval, noise_dim): """Train `generator` and `discriminator` models on `dataset`. Args: generator: Generator model. discriminator: Discriminator model. generator_optimizer: Optimizer to use for generator. discriminator_optimizer: Optimizer to use for discriminator. dataset: Dataset of images to train on. step_counter: An integer variable, used to write summaries regularly. log_interval: How many steps to wait between logging and collecting summaries. noise_dim: Dimension of noise vector to use. """ total_generator_loss = 0.0 total_discriminator_loss = 0.0 for (batch_index, images) in enumerate(dataset): with tf.device('/cpu:0'): tf.compat.v1.assign_add(step_counter, 1) with tf.contrib.summary.record_summaries_every_n_global_steps( log_interval, global_step=step_counter): current_batch_size = images.shape[0] noise = tf.random.uniform( shape=[current_batch_size, noise_dim], minval=-1., maxval=1., seed=batch_index) # we can use 2 tapes or a single persistent tape. # Using two tapes is memory efficient since intermediate tensors can be # released between the two .gradient() calls below with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape: generated_images = generator(noise) tf.contrib.summary.image( 'generated_images', tf.reshape(generated_images, [-1, 28, 28, 1]), max_images=10) discriminator_gen_outputs = discriminator(generated_images) discriminator_real_outputs = discriminator(images) discriminator_loss_val = discriminator_loss(discriminator_real_outputs, discriminator_gen_outputs) total_discriminator_loss += discriminator_loss_val generator_loss_val = generator_loss(discriminator_gen_outputs) total_generator_loss += generator_loss_val generator_grad = gen_tape.gradient(generator_loss_val, generator.variables) discriminator_grad = disc_tape.gradient(discriminator_loss_val, discriminator.variables) generator_optimizer.apply_gradients( zip(generator_grad, generator.variables)) discriminator_optimizer.apply_gradients( zip(discriminator_grad, discriminator.variables)) if log_interval and batch_index > 0 and batch_index % log_interval == 0: print('Batch #%d\tAverage Generator Loss: %.6f\t' 'Average Discriminator Loss: %.6f' % (batch_index, total_generator_loss / batch_index, total_discriminator_loss / batch_index)) def main(_): (device, data_format) = ('/gpu:0', 'channels_first') if FLAGS.no_gpu or tf.contrib.eager.num_gpus() <= 0: (device, data_format) = ('/cpu:0', 'channels_last') print('Using device %s, and data format %s.' % (device, data_format)) # Load the datasets data = input_data.read_data_sets(FLAGS.data_dir) dataset = ( tf.data.Dataset.from_tensor_slices(data.train.images[:1280]).shuffle(60000) .batch(FLAGS.batch_size)) # Create the models and optimizers. model_objects = { 'generator': Generator(data_format), 'discriminator': Discriminator(data_format), 'generator_optimizer': tf.compat.v1.train.AdamOptimizer(FLAGS.lr), 'discriminator_optimizer': tf.compat.v1.train.AdamOptimizer(FLAGS.lr), 'step_counter': tf.compat.v1.train.get_or_create_global_step(), } # Prepare summary writer and checkpoint info summary_writer = tf.contrib.summary.create_file_writer( FLAGS.output_dir, flush_millis=1000) checkpoint_prefix = os.path.join(FLAGS.checkpoint_dir, 'ckpt') latest_cpkt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir) if latest_cpkt: print('Using latest checkpoint at ' + latest_cpkt) checkpoint = tf.train.Checkpoint(**model_objects) # Restore variables on creation if a checkpoint exists. checkpoint.restore(latest_cpkt) with tf.device(device): for _ in range(3): start = time.time() with summary_writer.as_default(): train_one_epoch(dataset=dataset, log_interval=FLAGS.log_interval, noise_dim=FLAGS.noise, **model_objects) end = time.time() checkpoint.save(checkpoint_prefix) print('\nTrain time for epoch #%d (step %d): %f' % (checkpoint.save_counter.numpy(), checkpoint.step_counter.numpy(), end - start)) if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument( '--data-dir', type=str, default='/tmp/tensorflow/mnist/input_data', help=('Directory for storing input data (default ' '/tmp/tensorflow/mnist/input_data)')) parser.add_argument( '--batch-size', type=int, default=16, metavar='N', help='input batch size for training (default: 128)') parser.add_argument( '--log-interval', type=int, default=1, metavar='N', help=('number of batches between logging and writing summaries ' '(default: 100)')) parser.add_argument( '--output_dir', type=str, default='/tmp/tensorflow/', metavar='DIR', help='Directory to write TensorBoard summaries (defaults to none)') parser.add_argument( '--checkpoint_dir', type=str, default='/tmp/tensorflow/mnist/checkpoints/', metavar='DIR', help=('Directory to save checkpoints in (once per epoch) (default ' '/tmp/tensorflow/mnist/checkpoints/)')) parser.add_argument( '--lr', type=float, default=0.001, metavar='LR', help='learning rate (default: 0.001)') parser.add_argument( '--noise', type=int, default=100, metavar='N', help='Length of noise vector for generator input (default: 100)') parser.add_argument( '--no-gpu', action='store_true', default=False, help='disables GPU usage even if a GPU is available') FLAGS, unparsed = parser.parse_known_args() tf.compat.v1.app.run(main=main, argv=[sys.argv[0]] + unparsed)