mirror of
https://github.com/clearml/clearml
synced 2025-03-09 21:40:51 +00:00
Fix examples windows support
This commit is contained in:
parent
aedd3fc87e
commit
c0cfe3ccb2
@ -1,5 +1,6 @@
|
||||
# TRAINS - Example of manual graphs and statistics reporting
|
||||
#
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
import logging
|
||||
from trains import Task
|
||||
@ -49,11 +50,12 @@ logger.report_scatter3d("example_scatter_3d", "series_xyz", iteration=1, scatter
|
||||
|
||||
# reporting images
|
||||
m = np.eye(256, 256, dtype=np.float)
|
||||
logger.report_image("test case", "image float", iteration=1, matrix=m)
|
||||
logger.report_image("test case", "image float", iteration=1, image=m)
|
||||
m = np.eye(256, 256, dtype=np.uint8)*255
|
||||
logger.report_image("test case", "image uint8", iteration=1, matrix=m)
|
||||
logger.report_image("test case", "image uint8", iteration=1, image=m)
|
||||
m = np.concatenate((np.atleast_3d(m), np.zeros((256, 256, 2), dtype=np.uint8)), axis=2)
|
||||
logger.report_image("test case", "image color red", iteration=1, matrix=m)
|
||||
|
||||
logger.report_image("test case", "image color red", iteration=1, image=m)
|
||||
image_open = Image.open('./samples/picasso.jpg')
|
||||
logger.report_image("test case", "image PIL", iteration=1, image=image_open)
|
||||
# flush reports (otherwise it will be flushed in the background, every couple of seconds)
|
||||
logger.flush()
|
||||
|
@ -2,6 +2,9 @@
|
||||
#
|
||||
from __future__ import print_function
|
||||
import argparse
|
||||
import os
|
||||
from tempfile import gettempdir
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
@ -117,7 +120,7 @@ def main():
|
||||
test(args, model, device, test_loader)
|
||||
|
||||
if (args.save_model):
|
||||
torch.save(model.state_dict(), "/tmp/mnist_cnn.pt")
|
||||
torch.save(model.state_dict(), os.path.join(gettempdir(), "mnist_cnn.pt"))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -3,6 +3,9 @@
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import os
|
||||
from tempfile import gettempdir
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
@ -122,5 +125,5 @@ def test():
|
||||
|
||||
for epoch in range(1, args.epochs + 1):
|
||||
train(epoch)
|
||||
torch.save(model, '/tmp/model{}'.format(epoch))
|
||||
torch.save(model, os.path.join(gettempdir(), 'model{}'.format(epoch)))
|
||||
test()
|
||||
|
@ -3,6 +3,9 @@
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import os
|
||||
from tempfile import gettempdir
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
@ -122,5 +125,5 @@ def test():
|
||||
|
||||
for epoch in range(1, args.epochs + 1):
|
||||
train(epoch)
|
||||
torch.save(model, '/tmp/model{}'.format(epoch))
|
||||
torch.save(model, os.path.join(gettempdir(), 'model{}'.format(epoch)))
|
||||
test()
|
||||
|
@ -2,7 +2,7 @@ absl-py>=0.7.1
|
||||
Keras>=2.2.4
|
||||
joblib>=0.13.2
|
||||
matplotlib>=3.1.1 ; python_version >= '3.6'
|
||||
matplotlib == 3.0.3 ; python_version < '3.6'
|
||||
matplotlib >= 2.2.4 ; python_version < '3.6'
|
||||
seaborn>=0.9.0
|
||||
sklearn>=0.0
|
||||
tensorboard>=1.14.0
|
||||
@ -10,7 +10,7 @@ tensorboardX>=1.8
|
||||
tensorflow>=1.14.0
|
||||
torch>=1.1.0
|
||||
torchvision>=0.3.0
|
||||
xgboost>=0.90
|
||||
|
||||
xgboost>=0.90 ; python_version >= '3'
|
||||
xgboost >= 0.82 ; python_version < '3'
|
||||
# sudo apt-get install graphviz
|
||||
graphviz>=0.8
|
||||
|
@ -30,6 +30,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os.path
|
||||
from tempfile import gettempdir
|
||||
|
||||
from absl import app
|
||||
from absl import flags
|
||||
@ -42,8 +43,8 @@ task = Task.init(project_name='examples', task_name='tensorboard pr_curve')
|
||||
|
||||
tf.compat.v1.disable_v2_behavior()
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
flags.DEFINE_string('logdir', '/tmp/pr_curve_demo', 'Directory into which to write TensorBoard data.')
|
||||
flags.DEFINE_string('logdir', os.path.join(gettempdir(), "pr_curve_demo"),
|
||||
"Directory into which to write TensorBoard data.")
|
||||
|
||||
flags.DEFINE_integer('steps', 10,
|
||||
'Number of steps to generate for each PR curve.')
|
||||
|
Loading…
Reference in New Issue
Block a user