Update Pytorch Lightning example for pytorch-lightning>=v1.6.0 (#650)

This commit is contained in:
Rizwan Hasan 2022-04-21 13:18:30 +06:00 committed by GitHub
parent 90d060dd7e
commit 8763a884da
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 6 additions and 5 deletions

View File

@ -1,3 +1,4 @@
import os
from argparse import ArgumentParser from argparse import ArgumentParser
import torch import torch
import pytorch_lightning as pl import pytorch_lightning as pl
@ -73,9 +74,9 @@ if __name__ == '__main__':
mnist_test = MNIST('', train=False, download=True, transform=transforms.ToTensor()) mnist_test = MNIST('', train=False, download=True, transform=transforms.ToTensor())
mnist_train, mnist_val = random_split(dataset, [55000, 5000]) mnist_train, mnist_val = random_split(dataset, [55000, 5000])
train_loader = DataLoader(mnist_train, batch_size=args.batch_size) train_loader = DataLoader(mnist_train, batch_size=args.batch_size, num_workers=os.cpu_count())
val_loader = DataLoader(mnist_val, batch_size=args.batch_size) val_loader = DataLoader(mnist_val, batch_size=args.batch_size, num_workers=os.cpu_count())
test_loader = DataLoader(mnist_test, batch_size=args.batch_size) test_loader = DataLoader(mnist_test, batch_size=args.batch_size, num_workers=os.cpu_count())
# ------------ # ------------
# model # model
@ -91,4 +92,4 @@ if __name__ == '__main__':
# ------------ # ------------
# testing # testing
# ------------ # ------------
trainer.test(test_dataloaders=test_loader) trainer.test(dataloaders=test_loader)

View File

@ -1,4 +1,4 @@
clearml clearml
pytorch_lightning >= 1.1.2 pytorch-lightning >= 1.6.0
torch torch
torchvision torchvision