Update Pytorch Lightning example for pytorch-lightning>=v1.6.0 (#650)

This commit is contained in:
Rizwan Hasan 2022-04-21 13:18:30 +06:00 committed by GitHub
parent 90d060dd7e
commit 8763a884da
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 6 additions and 5 deletions

View File

@ -1,3 +1,4 @@
import os
from argparse import ArgumentParser
import torch
import pytorch_lightning as pl
@ -73,9 +74,9 @@ if __name__ == '__main__':
mnist_test = MNIST('', train=False, download=True, transform=transforms.ToTensor())
mnist_train, mnist_val = random_split(dataset, [55000, 5000])
train_loader = DataLoader(mnist_train, batch_size=args.batch_size)
val_loader = DataLoader(mnist_val, batch_size=args.batch_size)
test_loader = DataLoader(mnist_test, batch_size=args.batch_size)
train_loader = DataLoader(mnist_train, batch_size=args.batch_size, num_workers=os.cpu_count())
val_loader = DataLoader(mnist_val, batch_size=args.batch_size, num_workers=os.cpu_count())
test_loader = DataLoader(mnist_test, batch_size=args.batch_size, num_workers=os.cpu_count())
# ------------
# model
@ -91,4 +92,4 @@ if __name__ == '__main__':
# ------------
# testing
# ------------
trainer.test(test_dataloaders=test_loader)
trainer.test(dataloaders=test_loader)

View File

@ -1,4 +1,4 @@
clearml
pytorch_lightning >= 1.1.2
pytorch-lightning >= 1.6.0
torch
torchvision