From 8763a884da02758bf995ac2e7702915a3dffbfe7 Mon Sep 17 00:00:00 2001 From: Rizwan Hasan Date: Thu, 21 Apr 2022 13:18:30 +0600 Subject: [PATCH] Update Pytorch Lightning example for pytorch-lightning>=v1.6.0 (#650) --- .../pytorch-lightning/pytorch_lightning_example.py | 9 +++++---- examples/frameworks/pytorch-lightning/requirements.txt | 2 +- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/examples/frameworks/pytorch-lightning/pytorch_lightning_example.py b/examples/frameworks/pytorch-lightning/pytorch_lightning_example.py index ff111f2a..91f2d085 100644 --- a/examples/frameworks/pytorch-lightning/pytorch_lightning_example.py +++ b/examples/frameworks/pytorch-lightning/pytorch_lightning_example.py @@ -1,3 +1,4 @@ +import os from argparse import ArgumentParser import torch import pytorch_lightning as pl @@ -73,9 +74,9 @@ if __name__ == '__main__': mnist_test = MNIST('', train=False, download=True, transform=transforms.ToTensor()) mnist_train, mnist_val = random_split(dataset, [55000, 5000]) - train_loader = DataLoader(mnist_train, batch_size=args.batch_size) - val_loader = DataLoader(mnist_val, batch_size=args.batch_size) - test_loader = DataLoader(mnist_test, batch_size=args.batch_size) + train_loader = DataLoader(mnist_train, batch_size=args.batch_size, num_workers=os.cpu_count()) + val_loader = DataLoader(mnist_val, batch_size=args.batch_size, num_workers=os.cpu_count()) + test_loader = DataLoader(mnist_test, batch_size=args.batch_size, num_workers=os.cpu_count()) # ------------ # model @@ -91,4 +92,4 @@ if __name__ == '__main__': # ------------ # testing # ------------ - trainer.test(test_dataloaders=test_loader) + trainer.test(dataloaders=test_loader) diff --git a/examples/frameworks/pytorch-lightning/requirements.txt b/examples/frameworks/pytorch-lightning/requirements.txt index 9fed39e4..62e05989 100644 --- a/examples/frameworks/pytorch-lightning/requirements.txt +++ b/examples/frameworks/pytorch-lightning/requirements.txt @@ -1,4 +1,4 @@ clearml -pytorch_lightning >= 1.1.2 +pytorch-lightning >= 1.6.0 torch torchvision