Fix Windows support

This commit is contained in:
allegroai 2019-11-08 22:14:42 +02:00
parent 10b3c7174c
commit d8e54b466c
2 changed files with 12 additions and 11 deletions

View File

@ -9,13 +9,14 @@
from __future__ import print_function
from os.path import exists
from os.path import exists, join
import tempfile
import numpy as np
import tensorflow as tf
from trains import Task
MODEL_PATH = "/tmp/module_no_signatures"
MODEL_PATH = join(tempfile.gettempdir(), "module_no_signatures")
task = Task.init(project_name='examples', task_name='Tensorflow mnist example')
## block
@ -39,14 +40,14 @@ with tf.Session(graph=tf.Graph()) as sess:
# Import MNIST data
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
mnist = input_data.read_data_sets("MNIST_data", one_hot=True)
# Parameters
parameters = {
'learning_rate': 0.001,
'batch_size': 100,
'display_step': 1,
'model_path': "/tmp/model.ckpt",
'model_path': join(tempfile.gettempdir(), "model.ckpt"),
# Network Parameters
'n_hidden_1': 256, # 1st layer number of features

View File

@ -27,7 +27,7 @@ from __future__ import print_function
import argparse
import os
import sys
import tempfile
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
@ -139,8 +139,8 @@ def train():
# Merge all the summaries and write them out to
# /tmp/tensorflow/mnist/logs/mnist_with_summaries (by default)
merged = tf.summary.merge_all()
train_writer = tf.summary.FileWriter(FLAGS.log_dir + '/train', sess.graph)
test_writer = tf.summary.FileWriter(FLAGS.log_dir + '/test')
train_writer = tf.summary.FileWriter(os.path.join(FLAGS.log_dir, 'train'), sess.graph)
test_writer = tf.summary.FileWriter(os.path.join(FLAGS.log_dir, 'test'))
tf.global_variables_initializer().run()
@ -205,13 +205,13 @@ if __name__ == '__main__':
parser.add_argument('--dropout', type=float, default=0.9,
help='Keep probability for training dropout.')
parser.add_argument('--data_dir', type=str,
default=os.path.join(os.getenv('TEST_TMPDIR', '/tmp'), 'tensorflow/mnist/input_data'),
default=os.path.join(tempfile.gettempdir(), 'tensorflow', 'mnist', 'input_data'),
help='Directory for storing input data')
parser.add_argument('--log_dir', type=str,
default=os.path.join(os.getenv('TEST_TMPDIR', '/tmp'),
'tensorflow/mnist/logs/mnist_with_summaries'),
default=os.path.join(tempfile.gettempdir(),
'tensorflow', 'mnist', 'logs', 'mnist_with_summaries'),
help='Summaries log directory')
parser.add_argument('--save_path', default="/tmp/model.ckpt",
parser.add_argument('--save_path', default=os.path.join(tempfile.gettempdir(), "model.ckpt"),
help='Save the trained model under this path')
FLAGS, unparsed = parser.parse_known_args()
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)