mirror of
https://github.com/clearml/clearml
synced 2025-03-03 10:42:00 +00:00
Update Pytorch Lightning example for pytorch-lightning>=v1.6.0 (#650)
This commit is contained in:
parent
90d060dd7e
commit
8763a884da
@ -1,3 +1,4 @@
|
||||
import os
|
||||
from argparse import ArgumentParser
|
||||
import torch
|
||||
import pytorch_lightning as pl
|
||||
@ -73,9 +74,9 @@ if __name__ == '__main__':
|
||||
mnist_test = MNIST('', train=False, download=True, transform=transforms.ToTensor())
|
||||
mnist_train, mnist_val = random_split(dataset, [55000, 5000])
|
||||
|
||||
train_loader = DataLoader(mnist_train, batch_size=args.batch_size)
|
||||
val_loader = DataLoader(mnist_val, batch_size=args.batch_size)
|
||||
test_loader = DataLoader(mnist_test, batch_size=args.batch_size)
|
||||
train_loader = DataLoader(mnist_train, batch_size=args.batch_size, num_workers=os.cpu_count())
|
||||
val_loader = DataLoader(mnist_val, batch_size=args.batch_size, num_workers=os.cpu_count())
|
||||
test_loader = DataLoader(mnist_test, batch_size=args.batch_size, num_workers=os.cpu_count())
|
||||
|
||||
# ------------
|
||||
# model
|
||||
@ -91,4 +92,4 @@ if __name__ == '__main__':
|
||||
# ------------
|
||||
# testing
|
||||
# ------------
|
||||
trainer.test(test_dataloaders=test_loader)
|
||||
trainer.test(dataloaders=test_loader)
|
||||
|
@ -1,4 +1,4 @@
|
||||
clearml
|
||||
pytorch_lightning >= 1.1.2
|
||||
pytorch-lightning >= 1.6.0
|
||||
torch
|
||||
torchvision
|
||||
|
Loading…
Reference in New Issue
Block a user