Fixed requirements, refactored and formatted code in some examples (#567)

This commit is contained in:
Rizwan Hasan 2022-02-04 01:34:42 +06:00 committed by GitHub
parent 172c3e44f1
commit eb5350f551
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 34 additions and 39 deletions

View File

@ -1,23 +1,28 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
import argparse import argparse
import os import os
import sys
from tempfile import gettempdir from tempfile import gettempdir
import numpy as np import numpy as np
import megengine as mge try:
import megengine.module as M import megengine as mge
import megengine.functional as F import megengine.functional as F
from megengine.optimizer import SGD import megengine.module as M
from megengine.autodiff import GradManager from megengine.autodiff import GradManager
from megengine.data import DataLoader, RandomSampler
from megengine.data.dataset import MNIST
from megengine.data.transform import Compose, Normalize, Pad, ToMode
from megengine.optimizer import SGD
except ImportError:
raise ImportError(
"megengine package is missing, you can install it using pip: pip install megengine"
if sys.version_info.minor <= 8
else "MegEngine does not support python version >= 3.9"
)
from megengine.data import DataLoader, RandomSampler
from megengine.data.transform import ToMode, Pad, Normalize, Compose
from megengine.data.dataset import MNIST
from tensorboardX import SummaryWriter
from clearml import Task from clearml import Task
from tensorboardX import SummaryWriter
class Net(M.Module): class Net(M.Module):
@ -55,11 +60,7 @@ def build_dataloader():
train_dataset = MNIST(root=gettempdir(), train=True, download=True) train_dataset = MNIST(root=gettempdir(), train=True, download=True)
dataloader = DataLoader( dataloader = DataLoader(
train_dataset, train_dataset,
transform=Compose([ transform=Compose([Normalize(mean=0.1307 * 255, std=0.3081 * 255), Pad(2), ToMode("CHW"),]),
Normalize(mean=0.1307*255, std=0.3081*255),
Pad(2),
ToMode('CHW'),
]),
sampler=RandomSampler(dataset=train_dataset, batch_size=64), sampler=RandomSampler(dataset=train_dataset, batch_size=64),
) )
return dataloader return dataloader
@ -69,10 +70,7 @@ def train(dataloader, args):
writer = SummaryWriter("runs") writer = SummaryWriter("runs")
net = Net() net = Net()
net.train() net.train()
optimizer = SGD( optimizer = SGD(net.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.wd)
net.parameters(), lr=args.lr,
momentum=args.momentum, weight_decay=args.wd
)
gm = GradManager().attach(net.parameters()) gm = GradManager().attach(net.parameters())
epoch_length = len(dataloader) epoch_length = len(dataloader)
@ -90,28 +88,24 @@ def train(dataloader, args):
print("epoch:{}, iter:{}, loss:{}".format(epoch + 1, step, float(loss))) # noqa print("epoch:{}, iter:{}, loss:{}".format(epoch + 1, step, float(loss))) # noqa
writer.add_scalar("loss", float(loss), epoch * epoch_length + step) writer.add_scalar("loss", float(loss), epoch * epoch_length + step)
if (epoch + 1) % 5 == 0: if (epoch + 1) % 5 == 0:
mge.save(net.state_dict(), os.path.join(gettempdir(), f"mnist_net_e{epoch + 1}.pkl")) # noqa mge.save(
net.state_dict(), os.path.join(gettempdir(), f"mnist_net_e{epoch + 1}.pkl"),
) # noqa
def main(): def main():
task = Task.init(project_name='examples', task_name='megengine mnist train') # noqa task = Task.init(project_name="examples", task_name="megengine mnist train") # noqa
parser = argparse.ArgumentParser(description='MegEngine MNIST Example') parser = argparse.ArgumentParser(description="MegEngine MNIST Example")
parser.add_argument( parser.add_argument(
'--epoch', type=int, default=10, "--epoch", type=int, default=10, help="number of training epoch(default: 10)",
help='number of training epoch(default: 10)', )
parser.add_argument("--lr", type=float, default=0.01, help="learning rate(default: 0.01)")
parser.add_argument(
"--momentum", type=float, default=0.9, help="SGD momentum (default: 0.9)",
) )
parser.add_argument( parser.add_argument(
'--lr', type=float, default=0.01, "--wd", type=float, default=5e-4, help="SGD weight decay(default: 5e-4)",
help='learning rate(default: 0.01)'
)
parser.add_argument(
'--momentum', type=float, default=0.9,
help='SGD momentum (default: 0.9)',
)
parser.add_argument(
'--wd', type=float, default=5e-4,
help='SGD weight decay(default: 5e-4)',
) )
args = parser.parse_args() args = parser.parse_args()

View File

@ -1,3 +1,3 @@
MegEngine MegEngine ; python_version < '3.9'
tensorboardX tensorboardX
clearml clearml

View File

@ -3,4 +3,5 @@ tensorboardX
tensorboard>=1.14.0 tensorboard>=1.14.0
torch>=1.1.0 torch>=1.1.0
torchvision>=0.3.0 torchvision>=0.3.0
tqdm
clearml clearml