diff --git a/examples/frameworks/pytorch/pytorch_tensorboard.py b/examples/frameworks/pytorch/pytorch_tensorboard.py index 8392f67d..ece50a7b 100644 --- a/examples/frameworks/pytorch/pytorch_tensorboard.py +++ b/examples/frameworks/pytorch/pytorch_tensorboard.py @@ -118,7 +118,7 @@ def main(): transform=transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])), - batch_size=args.batch_size, shuffle=True, **kwargs) + batch_size=args.test_batch_size, shuffle=True, **kwargs) model = Net() if args.cuda: diff --git a/examples/frameworks/tensorboardx/pytorch_tensorboardX.py b/examples/frameworks/tensorboardx/pytorch_tensorboardX.py index f995e6e6..b0a1901f 100644 --- a/examples/frameworks/tensorboardx/pytorch_tensorboardX.py +++ b/examples/frameworks/tensorboardx/pytorch_tensorboardX.py @@ -118,7 +118,7 @@ def main(): transform=transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])), - batch_size=args.batch_size, shuffle=True, **kwargs) + batch_size=args.test_batch_size, shuffle=True, **kwargs) model = Net() if args.cuda: