#!/usr/bin/env python3
# -*- coding:utf-8 -*-
"""Train a small CNN on MNIST with MegEngine.

Logs the training loss to TensorBoard (via tensorboardX) and registers the
run as a ClearML task. Checkpoints are saved every 5 epochs to the system
temp directory.
"""
import argparse
import os
from tempfile import gettempdir

import numpy as np
import megengine as mge
import megengine.module as M
import megengine.functional as F
from megengine.optimizer import SGD
from megengine.autodiff import GradManager
from megengine.data import DataLoader, RandomSampler
from megengine.data.transform import ToMode, Pad, Normalize, Compose
from megengine.data.dataset import MNIST
from tensorboardX import SummaryWriter
from clearml import Task


class Net(M.Module):
    """LeNet-style CNN: two conv/BN/ReLU/max-pool stages, then two FC layers.

    Expects input of shape (batch, 1, 32, 32) — 28x28 MNIST images padded
    by 2 on each side (see ``build_dataloader``). Outputs raw logits of
    shape (batch, 10).
    """

    def __init__(self):
        super().__init__()
        self.conv0 = M.Conv2d(1, 20, kernel_size=5, bias=False)
        self.bn0 = M.BatchNorm2d(20)
        self.relu0 = M.ReLU()
        self.pool0 = M.MaxPool2d(2)
        self.conv1 = M.Conv2d(20, 20, kernel_size=5, bias=False)
        self.bn1 = M.BatchNorm2d(20)
        self.relu1 = M.ReLU()
        self.pool1 = M.MaxPool2d(2)
        # After two 5x5 convs and two 2x2 pools on a 32x32 input the feature
        # map is 20 channels x 5 x 5 = 500 values per sample.
        self.fc0 = M.Linear(500, 64, bias=True)
        self.relu2 = M.ReLU()
        self.fc1 = M.Linear(64, 10, bias=True)

    def forward(self, x):
        """Return class logits of shape (batch, 10) for image batch ``x``."""
        x = self.conv0(x)
        x = self.bn0(x)
        x = self.relu0(x)
        x = self.pool0(x)
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)
        x = self.pool1(x)
        x = F.flatten(x, 1)
        x = self.fc0(x)
        x = self.relu2(x)
        x = self.fc1(x)
        return x


def build_dataloader():
    """Build a shuffled MNIST training DataLoader (batch size 64).

    Images are normalized with the standard MNIST mean/std (scaled to the
    0-255 pixel range), padded from 28x28 to 32x32, and converted to CHW
    layout. The dataset is downloaded to the system temp directory on
    first use.
    """
    train_dataset = MNIST(root=gettempdir(), train=True, download=True)
    dataloader = DataLoader(
        train_dataset,
        transform=Compose([
            Normalize(mean=0.1307 * 255, std=0.3081 * 255),
            Pad(2),
            ToMode('CHW'),
        ]),
        sampler=RandomSampler(dataset=train_dataset, batch_size=64),
    )
    return dataloader


def train(dataloader, args):
    """Train ``Net`` on ``dataloader`` for ``args.epoch`` epochs with SGD.

    Args:
        dataloader: iterable yielding (image_batch, label_batch) numpy pairs.
        args: parsed CLI namespace with ``epoch``, ``lr``, ``momentum``, ``wd``.

    Side effects: writes TensorBoard scalars under ``runs/`` and saves a
    checkpoint to the temp directory every 5 epochs.
    """
    writer = SummaryWriter("runs")
    net = Net()
    net.train()
    optimizer = SGD(
        net.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.wd
    )
    gm = GradManager().attach(net.parameters())
    epoch_length = len(dataloader)
    for epoch in range(args.epoch):
        for step, (batch_data, batch_label) in enumerate(dataloader):
            # Labels arrive as int64/uint8 from the dataset; MegEngine's
            # cross_entropy expects int32.
            batch_label = batch_label.astype(np.int32)
            data, label = mge.tensor(batch_data), mge.tensor(batch_label)
            with gm:
                pred = net(data)
                loss = F.loss.cross_entropy(pred, label)
                gm.backward(loss)
            # Apply the update, then reset gradients for the next iteration.
            optimizer.step()
            optimizer.clear_grad()
            if step % 50 == 0:
                print("epoch:{}, iter:{}, loss:{}".format(epoch + 1, step, float(loss)))  # noqa
            writer.add_scalar("loss", float(loss), epoch * epoch_length + step)
        if (epoch + 1) % 5 == 0:
            mge.save(net.state_dict(), os.path.join(gettempdir(), f"mnist_net_e{epoch + 1}.pkl"))  # noqa
    # Flush and close the event file so the final scalars are persisted.
    writer.close()


def main():
    """Parse CLI arguments, register the ClearML task, and run training."""
    task = Task.init(project_name='examples', task_name='megengine mnist train')  # noqa
    parser = argparse.ArgumentParser(description='MegEngine MNIST Example')
    parser.add_argument(
        '--epoch',
        type=int,
        default=10,
        help='number of training epoch(default: 10)',
    )
    parser.add_argument(
        '--lr', type=float, default=0.01, help='learning rate(default: 0.01)'
    )
    parser.add_argument(
        '--momentum',
        type=float,
        default=0.9,
        help='SGD momentum (default: 0.9)',
    )
    parser.add_argument(
        '--wd',
        type=float,
        default=5e-4,
        help='SGD weight decay(default: 5e-4)',
    )
    args = parser.parse_args()

    dataloader = build_dataloader()
    train(dataloader, args)


if __name__ == "__main__":
    main()