Fix examples windows support

This commit is contained in:
allegroai 2019-10-10 20:33:58 +03:00
parent aedd3fc87e
commit c0cfe3ccb2
6 changed files with 24 additions and 12 deletions

View File

@ -1,5 +1,6 @@
# TRAINS - Example of manual graphs and statistics reporting # TRAINS - Example of manual graphs and statistics reporting
# #
from PIL import Image
import numpy as np import numpy as np
import logging import logging
from trains import Task from trains import Task
@ -49,11 +50,12 @@ logger.report_scatter3d("example_scatter_3d", "series_xyz", iteration=1, scatter
# reporting images # reporting images
m = np.eye(256, 256, dtype=np.float) m = np.eye(256, 256, dtype=np.float)
logger.report_image("test case", "image float", iteration=1, matrix=m) logger.report_image("test case", "image float", iteration=1, image=m)
m = np.eye(256, 256, dtype=np.uint8)*255 m = np.eye(256, 256, dtype=np.uint8)*255
logger.report_image("test case", "image uint8", iteration=1, matrix=m) logger.report_image("test case", "image uint8", iteration=1, image=m)
m = np.concatenate((np.atleast_3d(m), np.zeros((256, 256, 2), dtype=np.uint8)), axis=2) m = np.concatenate((np.atleast_3d(m), np.zeros((256, 256, 2), dtype=np.uint8)), axis=2)
logger.report_image("test case", "image color red", iteration=1, matrix=m) logger.report_image("test case", "image color red", iteration=1, image=m)
image_open = Image.open('./samples/picasso.jpg')
logger.report_image("test case", "image PIL", iteration=1, image=image_open)
# flush reports (otherwise it will be flushed in the background, every couple of seconds) # flush reports (otherwise it will be flushed in the background, every couple of seconds)
logger.flush() logger.flush()

View File

@ -2,6 +2,9 @@
# #
from __future__ import print_function from __future__ import print_function
import argparse import argparse
import os
from tempfile import gettempdir
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
@ -117,7 +120,7 @@ def main():
test(args, model, device, test_loader) test(args, model, device, test_loader)
if (args.save_model): if (args.save_model):
torch.save(model.state_dict(), "/tmp/mnist_cnn.pt") torch.save(model.state_dict(), os.path.join(gettempdir(), "mnist_cnn.pt"))
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -3,6 +3,9 @@
from __future__ import print_function from __future__ import print_function
import argparse import argparse
import os
from tempfile import gettempdir
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
@ -122,5 +125,5 @@ def test():
for epoch in range(1, args.epochs + 1): for epoch in range(1, args.epochs + 1):
train(epoch) train(epoch)
torch.save(model, '/tmp/model{}'.format(epoch)) torch.save(model, os.path.join(gettempdir(), 'model{}'.format(epoch)))
test() test()

View File

@ -3,6 +3,9 @@
from __future__ import print_function from __future__ import print_function
import argparse import argparse
import os
from tempfile import gettempdir
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
@ -122,5 +125,5 @@ def test():
for epoch in range(1, args.epochs + 1): for epoch in range(1, args.epochs + 1):
train(epoch) train(epoch)
torch.save(model, '/tmp/model{}'.format(epoch)) torch.save(model, os.path.join(gettempdir(), 'model{}'.format(epoch)))
test() test()

View File

@ -2,7 +2,7 @@ absl-py>=0.7.1
Keras>=2.2.4 Keras>=2.2.4
joblib>=0.13.2 joblib>=0.13.2
matplotlib>=3.1.1 ; python_version >= '3.6' matplotlib>=3.1.1 ; python_version >= '3.6'
matplotlib == 3.0.3 ; python_version < '3.6' matplotlib >= 2.2.4 ; python_version < '3.6'
seaborn>=0.9.0 seaborn>=0.9.0
sklearn>=0.0 sklearn>=0.0
tensorboard>=1.14.0 tensorboard>=1.14.0
@ -10,7 +10,7 @@ tensorboardX>=1.8
tensorflow>=1.14.0 tensorflow>=1.14.0
torch>=1.1.0 torch>=1.1.0
torchvision>=0.3.0 torchvision>=0.3.0
xgboost>=0.90 xgboost>=0.90 ; python_version >= '3'
xgboost >= 0.82 ; python_version < '3'
# sudo apt-get install graphviz # sudo apt-get install graphviz
graphviz>=0.8 graphviz>=0.8

View File

@ -30,6 +30,7 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import os.path import os.path
from tempfile import gettempdir
from absl import app from absl import app
from absl import flags from absl import flags
@ -42,8 +43,8 @@ task = Task.init(project_name='examples', task_name='tensorboard pr_curve')
tf.compat.v1.disable_v2_behavior() tf.compat.v1.disable_v2_behavior()
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
flags.DEFINE_string('logdir', os.path.join(gettempdir(), "pr_curve_demo"),
flags.DEFINE_string('logdir', '/tmp/pr_curve_demo', 'Directory into which to write TensorBoard data.') "Directory into which to write TensorBoard data.")
flags.DEFINE_integer('steps', 10, flags.DEFINE_integer('steps', 10,
'Number of steps to generate for each PR curve.') 'Number of steps to generate for each PR curve.')