diff --git a/examples/frameworks/pytorch/pytorch_tensorboard.py b/examples/frameworks/pytorch/pytorch_tensorboard.py index cd3cbfa8..b7016cc9 100644 --- a/examples/frameworks/pytorch/pytorch_tensorboard.py +++ b/examples/frameworks/pytorch/pytorch_tensorboard.py @@ -131,7 +131,9 @@ def main(): for epoch in range(1, args.epochs + 1): train(model, epoch, train_loader, args, optimizer, writer) - torch.save(model, os.path.join(gettempdir(), 'model{}'.format(epoch))) + m = torch.jit.script(model) + m.save(os.path.join(gettempdir(), 'model{}'.format(epoch))) + #torch.save(model, os.path.join(gettempdir(), 'model{}'.format(epoch))) test(model, test_loader, args, optimizer, writer)