diff --git a/examples/execute_jupyter_notebook_server.py b/examples/execute_jupyter_notebook_server.py
index ff66021a..3b5cfce3 100644
--- a/examples/execute_jupyter_notebook_server.py
+++ b/examples/execute_jupyter_notebook_server.py
@@ -4,10 +4,14 @@ import subprocess
 from copy import deepcopy
 import socket
 from tempfile import mkstemp
+# make sure we have jupyter in the auto requirements
 import jupyter
 from trains import Task
 
 
+# set default docker image, with network configuration
+os.environ['TRAINS_DOCKER_IMAGE'] = 'nvidia/cuda --network host'
+
 # initialize TRAINS
 task = Task.init(project_name='examples', task_name='Remote Jupyter NoteBook')
 
@@ -21,33 +25,45 @@ for key in os.environ:
     if key.startswith('TRAINS') and key not in preserve:
         env.pop(key, None)
 
+# Add jupyter server base folder
+param = {
+    'jupyter_server_base_directory': '',
+}
+task.connect(param)
+
 # execute jupyter notebook
 fd, local_filename = mkstemp()
-print('Running Jupyter Notebook Server on {} [{}]'.format(socket.gethostname(), socket.gethostbyname(socket.gethostname())))
-process = subprocess.Popen([sys.executable, '-m', 'jupyter', 'notebook'], env=env, stdout=fd, stderr=fd)
+cwd = os.path.expandvars(os.path.expanduser(param['jupyter_server_base_directory'])) \
+    if param['jupyter_server_base_directory'] else os.getcwd()
+print('Running Jupyter Notebook Server on {} [{}] at {}'.format(socket.gethostname(),
+      socket.gethostbyname(socket.gethostname()), cwd))
+process = subprocess.Popen([sys.executable, '-m', 'jupyter', 'notebook', '--no-browser', '--allow-root'],
+                           env=env, stdout=fd, stderr=fd, cwd=cwd)
 
 # print stdout/stderr
 prev_line_count = 0
-while True:
+process_running = True
+while process_running:
+    process_running = False  # stays False when wait() returns, i.e. the server exited
     try:
         process.wait(timeout=2.0 if prev_line_count == 0 else 15.0)
     except subprocess.TimeoutExpired:
-        with open(local_filename, "rt") as f:
-            # read new lines
-            new_lines = f.readlines()
-            if not new_lines:
-                continue
-        output = ''.join(new_lines)
-        print(output)
-        # update task comment with jupyter notebook server links
-        if prev_line_count == 0:
-            task.comment += '\n' + ''.join(line for line in new_lines if 'http://' in line or 'https://' in line)
-        prev_line_count += len(new_lines)
+        process_running = True  # timed out: server still alive, keep polling
 
-        os.lseek(fd, 0, 0)
-        os.ftruncate(fd, 0)
-        continue
-    break
+    with open(local_filename, "rt") as f:
+        # read new lines
+        new_lines = f.readlines()
+        if not new_lines:
+            continue
+    output = ''.join(new_lines)
+    print(output)
+    # update task comment with jupyter notebook server links
+    if prev_line_count == 0:
+        task.comment += '\n' + ''.join(line for line in new_lines if 'http://' in line or 'https://' in line)
+    prev_line_count += len(new_lines)
+
+    os.lseek(fd, 0, 0)
+    os.ftruncate(fd, 0)
 
 # cleanup
 os.close(fd)
diff --git a/examples/tensorflow_mnist_with_summaries.py b/examples/tensorflow_mnist_with_summaries.py
index 51800539..81b2b879 100644
--- a/examples/tensorflow_mnist_with_summaries.py
+++ b/examples/tensorflow_mnist_with_summaries.py
@@ -151,7 +151,7 @@ def train():
   def feed_dict(train):
     """Make a TensorFlow feed_dict: maps data onto Tensor placeholders."""
     if train or FLAGS.fake_data:
-      xs, ys = mnist.train.next_batch(100, fake_data=FLAGS.fake_data)
+      xs, ys = mnist.train.next_batch(FLAGS.batch_size, fake_data=FLAGS.fake_data)
       k = FLAGS.dropout
     else:
       xs, ys = mnist.test.images, mnist.test.labels
@@ -165,7 +165,7 @@ def train():
       test_writer.add_summary(summary, i)
       print('Accuracy at step %s: %s' % (i, acc))
     else:  # Record train set summaries, and train
-      if i % 100 == 99:  # Record execution stats
+      if i % FLAGS.batch_size == FLAGS.batch_size - 1:  # Record execution stats
         run_metadata = tf.RunMetadata()
         summary, _ = sess.run([merged, train_step],
                               feed_dict=feed_dict(True),
@@ -213,5 +213,7 @@ if __name__ == '__main__':
                       help='Summaries log directory')
   parser.add_argument('--save_path', default=os.path.join(tempfile.gettempdir(), "model.ckpt"),
                       help='Save the trained model under this path')
+  parser.add_argument('--batch_size', default=100, type=int,  # int: used by next_batch() and in "%"
+                      help='Batch size for training')
   FLAGS, unparsed = parser.parse_known_args()
   tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)