Fixed requirements, refactored and formatted code in some examples (#567)

This commit is contained in:
Rizwan Hasan 2022-02-04 01:34:42 +06:00 committed by GitHub
parent 172c3e44f1
commit eb5350f551
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 34 additions and 39 deletions

View File

@ -1,23 +1,28 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
import argparse
import os
import sys
from tempfile import gettempdir
import numpy as np
try:
import megengine as mge
import megengine.module as M
import megengine.functional as F
from megengine.optimizer import SGD
import megengine.module as M
from megengine.autodiff import GradManager
from megengine.data import DataLoader, RandomSampler
from megengine.data.transform import ToMode, Pad, Normalize, Compose
from megengine.data.dataset import MNIST
from megengine.data.transform import Compose, Normalize, Pad, ToMode
from megengine.optimizer import SGD
except ImportError:
raise ImportError(
"megengine package is missing, you can install it using pip: pip install megengine"
if sys.version_info.minor <= 8
else "MegEngine does not support python version >= 3.9"
)
from tensorboardX import SummaryWriter
from clearml import Task
from tensorboardX import SummaryWriter
class Net(M.Module):
@ -55,11 +60,7 @@ def build_dataloader():
train_dataset = MNIST(root=gettempdir(), train=True, download=True)
dataloader = DataLoader(
train_dataset,
transform=Compose([
Normalize(mean=0.1307*255, std=0.3081*255),
Pad(2),
ToMode('CHW'),
]),
transform=Compose([Normalize(mean=0.1307 * 255, std=0.3081 * 255), Pad(2), ToMode("CHW"),]),
sampler=RandomSampler(dataset=train_dataset, batch_size=64),
)
return dataloader
@ -69,10 +70,7 @@ def train(dataloader, args):
writer = SummaryWriter("runs")
net = Net()
net.train()
optimizer = SGD(
net.parameters(), lr=args.lr,
momentum=args.momentum, weight_decay=args.wd
)
optimizer = SGD(net.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.wd)
gm = GradManager().attach(net.parameters())
epoch_length = len(dataloader)
@ -90,28 +88,24 @@ def train(dataloader, args):
print("epoch:{}, iter:{}, loss:{}".format(epoch + 1, step, float(loss))) # noqa
writer.add_scalar("loss", float(loss), epoch * epoch_length + step)
if (epoch + 1) % 5 == 0:
mge.save(net.state_dict(), os.path.join(gettempdir(), f"mnist_net_e{epoch + 1}.pkl")) # noqa
mge.save(
net.state_dict(), os.path.join(gettempdir(), f"mnist_net_e{epoch + 1}.pkl"),
) # noqa
def main():
task = Task.init(project_name='examples', task_name='megengine mnist train') # noqa
task = Task.init(project_name="examples", task_name="megengine mnist train") # noqa
parser = argparse.ArgumentParser(description='MegEngine MNIST Example')
parser = argparse.ArgumentParser(description="MegEngine MNIST Example")
parser.add_argument(
'--epoch', type=int, default=10,
help='number of training epoch(default: 10)',
"--epoch", type=int, default=10, help="number of training epoch(default: 10)",
)
parser.add_argument("--lr", type=float, default=0.01, help="learning rate(default: 0.01)")
parser.add_argument(
"--momentum", type=float, default=0.9, help="SGD momentum (default: 0.9)",
)
parser.add_argument(
'--lr', type=float, default=0.01,
help='learning rate(default: 0.01)'
)
parser.add_argument(
'--momentum', type=float, default=0.9,
help='SGD momentum (default: 0.9)',
)
parser.add_argument(
'--wd', type=float, default=5e-4,
help='SGD weight decay(default: 5e-4)',
"--wd", type=float, default=5e-4, help="SGD weight decay(default: 5e-4)",
)
args = parser.parse_args()

View File

@ -1,3 +1,3 @@
MegEngine
MegEngine ; python_version < '3.9'
tensorboardX
clearml

View File

@ -3,4 +3,5 @@ tensorboardX
tensorboard>=1.14.0
torch>=1.1.0
torchvision>=0.3.0
tqdm
clearml