"""MegEngine MNIST training example with ClearML experiment tracking.

Trains a small CNN on MNIST, logs the loss to TensorBoard (picked up
automatically by ClearML), and saves checkpoints every 5 epochs.
"""
import argparse
import os
import sys
from tempfile import gettempdir

import numpy as np

try:
    import megengine as mge
    import megengine.functional as F
    import megengine.module as M
    from megengine.autodiff import GradManager
    from megengine.data import DataLoader, RandomSampler
    from megengine.data.dataset import MNIST
    from megengine.data.transform import Compose, Normalize, Pad, ToMode
    from megengine.optimizer import SGD
except ImportError:
    raise ImportError(
        "megengine package is missing, you can install it using pip: pip install megengine"
        if sys.version_info.minor <= 8
        else "MegEngine does not support python version >= 3.9"
    )

from clearml import Task
from tensorboardX import SummaryWriter


class Net(M.Module):
    """Simple two-conv-layer CNN for MNIST classification."""

    def __init__(self):
        super().__init__()
        self.conv0 = M.Conv2d(1, 20, kernel_size=5, bias=False)
        self.bn0 = M.BatchNorm2d(20)
        self.relu0 = M.ReLU()
        self.pool0 = M.MaxPool2d(2)
        self.conv1 = M.Conv2d(20, 20, kernel_size=5, bias=False)
        self.bn1 = M.BatchNorm2d(20)
        self.relu1 = M.ReLU()
        self.pool1 = M.MaxPool2d(2)
        self.fc0 = M.Linear(500, 64, bias=True)
        self.relu2 = M.ReLU()
        self.fc1 = M.Linear(64, 10, bias=True)

    def forward(self, x):
        x = self.conv0(x)
        x = self.bn0(x)
        x = self.relu0(x)
        x = self.pool0(x)
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)
        x = self.pool1(x)
        x = F.flatten(x, 1)
        x = self.fc0(x)
        x = self.relu2(x)
        x = self.fc1(x)
        return x


def build_dataloader():
    """Build a shuffled MNIST train loader: normalize, pad 28x28 -> 32x32, CHW layout."""
    train_dataset = MNIST(root=gettempdir(), train=True, download=True)
    dataloader = DataLoader(
        train_dataset,
        transform=Compose(
            [
                Normalize(mean=0.1307 * 255, std=0.3081 * 255),
                Pad(2),
                ToMode("CHW"),
            ]
        ),
        sampler=RandomSampler(dataset=train_dataset, batch_size=64),
    )
    return dataloader


def train(dataloader, args):
    writer = SummaryWriter("runs")

    net = Net()
    net.train()

    optimizer = SGD(net.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.wd)
    # GradManager records the forward pass so backward() can compute gradients.
    gm = GradManager().attach(net.parameters())

    epoch_length = len(dataloader)
    for epoch in range(args.epoch):
        for step, (batch_data, batch_label) in enumerate(dataloader):
            batch_label = batch_label.astype(np.int32)
            data, label = mge.tensor(batch_data), mge.tensor(batch_label)
            with gm:
                pred = net(data)
                loss = F.loss.cross_entropy(pred, label)
                gm.backward(loss)
            optimizer.step().clear_grad()
            if step % 50 == 0:
                print("epoch:{}, iter:{}, loss:{}".format(epoch + 1, step, float(loss)))  # noqa
                writer.add_scalar("loss", float(loss), epoch * epoch_length + step)
        # Save a checkpoint every 5 epochs.
        if (epoch + 1) % 5 == 0:
            mge.save(
                net.state_dict(),
                os.path.join(gettempdir(), f"mnist_net_e{epoch + 1}.pkl"),
            )  # noqa


def main():
    task = Task.init(project_name="examples", task_name="MegEngine MNIST train")  # noqa
    parser = argparse.ArgumentParser(description="MegEngine MNIST Example")
    parser.add_argument(
        "--epoch",
        type=int,
        default=10,
        help="number of training epoch(default: 10)",
    )
    parser.add_argument("--lr", type=float, default=0.01, help="learning rate(default: 0.01)")
    parser.add_argument(
        "--momentum",
        type=float,
        default=0.9,
        help="SGD momentum (default: 0.9)",
    )
    parser.add_argument(
        "--wd",
        type=float,
        default=5e-4,
        help="SGD weight decay(default: 5e-4)",
    )
    args = parser.parse_args()

    dataloader = build_dataloader()
    train(dataloader, args)


if __name__ == "__main__":
    main()