Fixed requirements, refactored and formatted code in some examples (#567)

This commit is contained in:
Rizwan Hasan 2022-02-04 01:34:42 +06:00 committed by GitHub
parent 172c3e44f1
commit eb5350f551
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 34 additions and 39 deletions

View File

@ -1,23 +1,28 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
import argparse
import os
import sys
from tempfile import gettempdir
import numpy as np
import megengine as mge
import megengine.module as M
import megengine.functional as F
from megengine.optimizer import SGD
from megengine.autodiff import GradManager
try:
import megengine as mge
import megengine.functional as F
import megengine.module as M
from megengine.autodiff import GradManager
from megengine.data import DataLoader, RandomSampler
from megengine.data.dataset import MNIST
from megengine.data.transform import Compose, Normalize, Pad, ToMode
from megengine.optimizer import SGD
except ImportError:
raise ImportError(
"megengine package is missing, you can install it using pip: pip install megengine"
if sys.version_info.minor <= 8
else "MegEngine does not support python version >= 3.9"
)
from megengine.data import DataLoader, RandomSampler
from megengine.data.transform import ToMode, Pad, Normalize, Compose
from megengine.data.dataset import MNIST
from tensorboardX import SummaryWriter
from clearml import Task
from tensorboardX import SummaryWriter
class Net(M.Module):
@ -55,11 +60,7 @@ def build_dataloader():
train_dataset = MNIST(root=gettempdir(), train=True, download=True)
dataloader = DataLoader(
train_dataset,
transform=Compose([
Normalize(mean=0.1307*255, std=0.3081*255),
Pad(2),
ToMode('CHW'),
]),
transform=Compose([Normalize(mean=0.1307 * 255, std=0.3081 * 255), Pad(2), ToMode("CHW"),]),
sampler=RandomSampler(dataset=train_dataset, batch_size=64),
)
return dataloader
@ -69,10 +70,7 @@ def train(dataloader, args):
writer = SummaryWriter("runs")
net = Net()
net.train()
optimizer = SGD(
net.parameters(), lr=args.lr,
momentum=args.momentum, weight_decay=args.wd
)
optimizer = SGD(net.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.wd)
gm = GradManager().attach(net.parameters())
epoch_length = len(dataloader)
@ -90,28 +88,24 @@ def train(dataloader, args):
print("epoch:{}, iter:{}, loss:{}".format(epoch + 1, step, float(loss))) # noqa
writer.add_scalar("loss", float(loss), epoch * epoch_length + step)
if (epoch + 1) % 5 == 0:
mge.save(net.state_dict(), os.path.join(gettempdir(), f"mnist_net_e{epoch + 1}.pkl")) # noqa
mge.save(
net.state_dict(), os.path.join(gettempdir(), f"mnist_net_e{epoch + 1}.pkl"),
) # noqa
def main():
task = Task.init(project_name='examples', task_name='megengine mnist train') # noqa
task = Task.init(project_name="examples", task_name="megengine mnist train") # noqa
parser = argparse.ArgumentParser(description='MegEngine MNIST Example')
parser = argparse.ArgumentParser(description="MegEngine MNIST Example")
parser.add_argument(
'--epoch', type=int, default=10,
help='number of training epoch(default: 10)',
"--epoch", type=int, default=10, help="number of training epoch(default: 10)",
)
parser.add_argument("--lr", type=float, default=0.01, help="learning rate(default: 0.01)")
parser.add_argument(
"--momentum", type=float, default=0.9, help="SGD momentum (default: 0.9)",
)
parser.add_argument(
'--lr', type=float, default=0.01,
help='learning rate(default: 0.01)'
)
parser.add_argument(
'--momentum', type=float, default=0.9,
help='SGD momentum (default: 0.9)',
)
parser.add_argument(
'--wd', type=float, default=5e-4,
help='SGD weight decay(default: 5e-4)',
"--wd", type=float, default=5e-4, help="SGD weight decay(default: 5e-4)",
)
args = parser.parse_args()

View File

@ -1,3 +1,3 @@
MegEngine
MegEngine ; python_version < '3.9'
tensorboardX
clearml
clearml

View File

@ -3,4 +3,5 @@ tensorboardX
tensorboard>=1.14.0
torch>=1.1.0
torchvision>=0.3.0
tqdm
clearml