From 783b0b99c99d19fb15dac6e122c0d7ded5de8ef7 Mon Sep 17 00:00:00 2001
From: allegroai <>
Date: Fri, 14 Jun 2019 01:50:33 +0300
Subject: [PATCH] pep8

---
 examples/tensorflow_mnist_with_summaries.py | 73 ++++++++++-----------
 1 file changed, 33 insertions(+), 40 deletions(-)

diff --git a/examples/tensorflow_mnist_with_summaries.py b/examples/tensorflow_mnist_with_summaries.py
index 29060e7b..88dcfca6 100644
--- a/examples/tensorflow_mnist_with_summaries.py
+++ b/examples/tensorflow_mnist_with_summaries.py
@@ -34,13 +34,12 @@ from tensorflow.examples.tutorials.mnist import input_data
 from trains import Task
 
 FLAGS = None
-task = Task.init(project_name='examples', task_name='Tensorflow mnist with summaries example')
+task = Task.init(project_name='examples', task_name='Tensorflow mnist with summaries')
 
 
 def train():
     # Import data
-    mnist = input_data.read_data_sets(FLAGS.data_dir,
-                                      fake_data=FLAGS.fake_data)
+    mnist = input_data.read_data_sets(FLAGS.data_dir, fake_data=FLAGS.fake_data)
 
     sess = tf.InteractiveSession()
     # Create a multilayer model.
@@ -123,12 +122,12 @@ def train():
         # the batch.
         with tf.name_scope('total'):
             cross_entropy = tf.losses.sparse_softmax_cross_entropy(
-                    labels=y_, logits=y)
+                labels=y_, logits=y)
     tf.summary.scalar('cross_entropy', cross_entropy)
 
     with tf.name_scope('train'):
         train_step = tf.train.AdamOptimizer(FLAGS.learning_rate).minimize(
-                cross_entropy)
+            cross_entropy)
 
     with tf.name_scope('accuracy'):
         with tf.name_scope('correct_prediction'):
@@ -142,7 +141,7 @@ def train():
     merged = tf.summary.merge_all()
     train_writer = tf.summary.FileWriter(FLAGS.log_dir + '/train', sess.graph)
     test_writer = tf.summary.FileWriter(FLAGS.log_dir + '/test')
-    
+
     tf.global_variables_initializer().run()
 
     # Train the model, and also write summaries.
@@ -161,29 +160,31 @@ def train():
 
     saver = tf.train.Saver()
     for i in range(FLAGS.max_steps):
-        if i % 10 == 0: # Record summaries and test-set accuracy
+        if i % 10 == 0:  # Record summaries and test-set accuracy
             summary, acc = sess.run([merged, accuracy], feed_dict=feed_dict(False))
             test_writer.add_summary(summary, i)
             print('Accuracy at step %s: %s' % (i, acc))
-        else: # Record train set summaries, and train
-            if i % 100 == 99: # Record execution stats
+        else:  # Record train set summaries, and train
+            if i % 100 == 99:  # Record execution stats
                 run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
                 run_metadata = tf.RunMetadata()
                 summary, _ = sess.run([merged, train_step],
-                                       feed_dict=feed_dict(True),
-                                       options=run_options,
-                                       run_metadata=run_metadata)
-                train_writer.add_run_metadata(run_metadata, 'step%03d' % i)
+                                      feed_dict=feed_dict(True),
+                                      options=run_options,
+                                      run_metadata=run_metadata)
+                train_writer.add_run_metadata(run_metadata, 'step%04d' % i)
                 train_writer.add_summary(summary, i)
                 print('Adding run metadata for', i)
-            else: # Record a summary
+            else:  # Record a summary
                 summary, _ = sess.run([merged, train_step], feed_dict=feed_dict(True))
-                train_writer.add_summary(summary, i)
+                # train_writer.add_summary(summary, i)
 
-    save_path = saver.save(sess,FLAGS.save_path)
-    print("Model saved in path: %s" % save_path)
+    save_path = saver.save(sess, FLAGS.save_path)
+    print("Saved model: %s" % save_path)
+    print('Flushing all images, this may take a couple of minutes')
     train_writer.close()
     test_writer.close()
+    print('Finished storing all metrics & images')
 
 
 def main(_):
@@ -197,30 +198,22 @@ def main(_):
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
     parser.add_argument('--fake_data', nargs='?', const=True, type=bool,
-                      default=False,
-                      help='If true, uses fake data for unit testing.')
-    parser.add_argument('--max_steps', type=int, default=1000,
-                      help='Number of steps to run trainer.')
+                        default=False,
+                        help='If true, uses fake data for unit testing.')
+    parser.add_argument('--max_steps', type=int, default=300,
+                        help='Number of steps to run trainer.')
     parser.add_argument('--learning_rate', type=float, default=0.001,
-                      help='Initial learning rate')
+                        help='Initial learning rate')
     parser.add_argument('--dropout', type=float, default=0.9,
-                      help='Keep probability for training dropout.')
-    parser.add_argument(
-        '--data_dir',
-        type=str,
-        default=os.path.join(os.getenv('TEST_TMPDIR', '/tmp'),
-                             'tensorflow/mnist/input_data'),
-        help='Directory for storing input data')
-    parser.add_argument(
-        '--log_dir',
-        type=str,
-        default=os.path.join(os.getenv('TEST_TMPDIR', '/tmp'),
-                             'tensorflow/mnist/logs/mnist_with_summaries'),
-        help='Summaries log directory')
-    parser.add_argument(
-        '--save_path',
-        default="/tmp/model.ckpt",
-        help='Save the trained model under this path'
-    )
+                        help='Keep probability for training dropout.')
+    parser.add_argument('--data_dir', type=str,
+                        default=os.path.join(os.getenv('TEST_TMPDIR', '/tmp'), 'tensorflow/mnist/input_data'),
+                        help='Directory for storing input data')
+    parser.add_argument('--log_dir', type=str,
+                        default=os.path.join(os.getenv('TEST_TMPDIR', '/tmp'),
+                                             'tensorflow/mnist/logs/mnist_with_summaries'),
+                        help='Summaries log directory')
+    parser.add_argument('--save_path', default="/tmp/model.ckpt",
+                        help='Save the trained model under this path')
     FLAGS, unparsed = parser.parse_known_args()
     tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)