mirror of
https://github.com/clearml/clearml
synced 2025-03-03 18:52:12 +00:00
Fix Windows support
This commit is contained in:
parent
10b3c7174c
commit
d8e54b466c
@ -9,13 +9,14 @@
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
from os.path import exists
|
||||
from os.path import exists, join
|
||||
import tempfile
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
from trains import Task
|
||||
|
||||
MODEL_PATH = "/tmp/module_no_signatures"
|
||||
MODEL_PATH = join(tempfile.gettempdir(), "module_no_signatures")
|
||||
task = Task.init(project_name='examples', task_name='Tensorflow mnist example')
|
||||
|
||||
## block
|
||||
@ -39,14 +40,14 @@ with tf.Session(graph=tf.Graph()) as sess:
|
||||
|
||||
# Import MNIST data
|
||||
from tensorflow.examples.tutorials.mnist import input_data
|
||||
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
|
||||
mnist = input_data.read_data_sets("MNIST_data", one_hot=True)
|
||||
|
||||
# Parameters
|
||||
parameters = {
|
||||
'learning_rate': 0.001,
|
||||
'batch_size': 100,
|
||||
'display_step': 1,
|
||||
'model_path': "/tmp/model.ckpt",
|
||||
'model_path': join(tempfile.gettempdir(), "model.ckpt"),
|
||||
|
||||
# Network Parameters
|
||||
'n_hidden_1': 256, # 1st layer number of features
|
||||
|
@ -27,7 +27,7 @@ from __future__ import print_function
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
|
||||
import tempfile
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow.examples.tutorials.mnist import input_data
|
||||
@ -139,8 +139,8 @@ def train():
|
||||
# Merge all the summaries and write them out to
|
||||
# /tmp/tensorflow/mnist/logs/mnist_with_summaries (by default)
|
||||
merged = tf.summary.merge_all()
|
||||
train_writer = tf.summary.FileWriter(FLAGS.log_dir + '/train', sess.graph)
|
||||
test_writer = tf.summary.FileWriter(FLAGS.log_dir + '/test')
|
||||
train_writer = tf.summary.FileWriter(os.path.join(FLAGS.log_dir, 'train'), sess.graph)
|
||||
test_writer = tf.summary.FileWriter(os.path.join(FLAGS.log_dir, 'test'))
|
||||
|
||||
tf.global_variables_initializer().run()
|
||||
|
||||
@ -205,13 +205,13 @@ if __name__ == '__main__':
|
||||
parser.add_argument('--dropout', type=float, default=0.9,
|
||||
help='Keep probability for training dropout.')
|
||||
parser.add_argument('--data_dir', type=str,
|
||||
default=os.path.join(os.getenv('TEST_TMPDIR', '/tmp'), 'tensorflow/mnist/input_data'),
|
||||
default=os.path.join(tempfile.gettempdir(), 'tensorflow', 'mnist', 'input_data'),
|
||||
help='Directory for storing input data')
|
||||
parser.add_argument('--log_dir', type=str,
|
||||
default=os.path.join(os.getenv('TEST_TMPDIR', '/tmp'),
|
||||
'tensorflow/mnist/logs/mnist_with_summaries'),
|
||||
default=os.path.join(tempfile.gettempdir(),
|
||||
'tensorflow', 'mnist', 'logs', 'mnist_with_summaries'),
|
||||
help='Summaries log directory')
|
||||
parser.add_argument('--save_path', default="/tmp/model.ckpt",
|
||||
parser.add_argument('--save_path', default=os.path.join(tempfile.gettempdir(), "model.ckpt"),
|
||||
help='Save the trained model under this path')
|
||||
FLAGS, unparsed = parser.parse_known_args()
|
||||
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
|
||||
|
Loading…
Reference in New Issue
Block a user