mirror of
https://github.com/clearml/clearml
synced 2025-06-26 18:16:07 +00:00
Add support for MegEngine (#455)
* feat(framework): support MegEngine model * feat(framework): add MegEngine example * fix(example): change supported suffix to pkl
This commit is contained in:
123
examples/frameworks/megengine/megengine_mnist.py
Normal file
123
examples/frameworks/megengine/megengine_mnist.py
Normal file
@@ -0,0 +1,123 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
|
||||
import argparse
|
||||
import os
|
||||
from tempfile import gettempdir
|
||||
import numpy as np
|
||||
|
||||
import megengine as mge
|
||||
import megengine.module as M
|
||||
import megengine.functional as F
|
||||
from megengine.optimizer import SGD
|
||||
from megengine.autodiff import GradManager
|
||||
|
||||
from megengine.data import DataLoader, RandomSampler
|
||||
from megengine.data.transform import ToMode, Pad, Normalize, Compose
|
||||
from megengine.data.dataset import MNIST
|
||||
|
||||
from tensorboardX import SummaryWriter
|
||||
from clearml import Task
|
||||
|
||||
|
||||
class Net(M.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.conv0 = M.Conv2d(1, 20, kernel_size=5, bias=False)
|
||||
self.bn0 = M.BatchNorm2d(20)
|
||||
self.relu0 = M.ReLU()
|
||||
self.pool0 = M.MaxPool2d(2)
|
||||
self.conv1 = M.Conv2d(20, 20, kernel_size=5, bias=False)
|
||||
self.bn1 = M.BatchNorm2d(20)
|
||||
self.relu1 = M.ReLU()
|
||||
self.pool1 = M.MaxPool2d(2)
|
||||
self.fc0 = M.Linear(500, 64, bias=True)
|
||||
self.relu2 = M.ReLU()
|
||||
self.fc1 = M.Linear(64, 10, bias=True)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv0(x)
|
||||
x = self.bn0(x)
|
||||
x = self.relu0(x)
|
||||
x = self.pool0(x)
|
||||
x = self.conv1(x)
|
||||
x = self.bn1(x)
|
||||
x = self.relu1(x)
|
||||
x = self.pool1(x)
|
||||
x = F.flatten(x, 1)
|
||||
x = self.fc0(x)
|
||||
x = self.relu2(x)
|
||||
x = self.fc1(x)
|
||||
return x
|
||||
|
||||
|
||||
def build_dataloader():
|
||||
train_dataset = MNIST(root=gettempdir(), train=True, download=True)
|
||||
dataloader = DataLoader(
|
||||
train_dataset,
|
||||
transform=Compose([
|
||||
Normalize(mean=0.1307*255, std=0.3081*255),
|
||||
Pad(2),
|
||||
ToMode('CHW'),
|
||||
]),
|
||||
sampler=RandomSampler(dataset=train_dataset, batch_size=64),
|
||||
)
|
||||
return dataloader
|
||||
|
||||
|
||||
def train(dataloader, args):
|
||||
writer = SummaryWriter("runs")
|
||||
net = Net()
|
||||
net.train()
|
||||
optimizer = SGD(
|
||||
net.parameters(), lr=args.lr,
|
||||
momentum=args.momentum, weight_decay=args.wd
|
||||
)
|
||||
gm = GradManager().attach(net.parameters())
|
||||
|
||||
epoch_length = len(dataloader)
|
||||
for epoch in range(args.epoch):
|
||||
for step, (batch_data, batch_label) in enumerate(dataloader):
|
||||
batch_label = batch_label.astype(np.int32)
|
||||
data, label = mge.tensor(batch_data), mge.tensor(batch_label)
|
||||
with gm:
|
||||
pred = net(data)
|
||||
loss = F.loss.cross_entropy(pred, label)
|
||||
gm.backward(loss)
|
||||
optimizer.step().clear_grad()
|
||||
|
||||
if step % 50 == 0:
|
||||
print("epoch:{}, iter:{}, loss:{}".format(epoch + 1, step, float(loss))) # noqa
|
||||
writer.add_scalar("loss", float(loss), epoch * epoch_length + step)
|
||||
if (epoch + 1) % 5 == 0:
|
||||
mge.save(net.state_dict(), os.path.join(gettempdir(), f"mnist_net_e{epoch + 1}.pkl")) # noqa
|
||||
|
||||
|
||||
def main():
|
||||
task = Task.init(project_name='megengine', task_name='mge mnist train') # noqa
|
||||
|
||||
parser = argparse.ArgumentParser(description='MegEngine MNIST Example')
|
||||
parser.add_argument(
|
||||
'--epoch', type=int, default=10,
|
||||
help='number of training epoch(default: 10)',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--lr', type=float, default=0.01,
|
||||
help='learning rate(default: 0.01)'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--momentum', type=float, default=0.9,
|
||||
help='SGD momentum (default: 0.9)',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--wd', type=float, default=5e-4,
|
||||
help='SGD weight decay(default: 5e-4)',
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
dataloader = build_dataloader()
|
||||
train(dataloader, args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user