Add support for MegEngine (#455)

* feat(framework): support MegEngine model

* feat(framework): add MegEngine example

* fix(example): change supported suffix to pkl
Feng Wang 2021-09-30 05:16:20 +08:00, committed by GitHub
commit 8c9c0eacc7 (parent 400c6ec103)
4 changed files with 304 additions and 1 deletion
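
For context, a minimal sketch of what this change enables: once a ClearML task is active, calls to mge.save / mge.load are intercepted and the saved/loaded files are registered as output/input models of the task. The snippet below is illustrative only and not part of the commit; the project name, task name, and save path are placeholders.

import os
from tempfile import gettempdir

import megengine as mge
import megengine.module as M
from clearml import Task

task = Task.init(project_name='megengine', task_name='binding sketch')  # patches mge.save / mge.load

net = M.Linear(4, 2)  # any MegEngine module
ckpt = os.path.join(gettempdir(), 'toy_model.pkl')
mge.save(net.state_dict(), ckpt)   # intercepted -> registered as an output model of the task
state = mge.load(ckpt)             # intercepted -> registered as an input model of the task
net.load_state_dict(state)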

clearml/binding/frameworks/megengine_bind.py (new file)

@@ -0,0 +1,176 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
import sys

import six
from pathlib2 import Path

from ..frameworks.base_bind import PatchBaseModelIO
from ..frameworks import _patched_call, WeightsFileHandler, _Empty
from ..import_bind import PostImportHookPatching
from ...model import Framework


class PatchMegEngineModelIO(PatchBaseModelIO):
    __main_task = None
    __patched = None
    # __patched_lightning = None

    @staticmethod
    def update_current_task(task, **_):
        PatchMegEngineModelIO.__main_task = task
        PatchMegEngineModelIO._patch_model_io()
        PostImportHookPatching.add_on_import(
            'megengine', PatchMegEngineModelIO._patch_model_io
        )

    @staticmethod
    def _patch_model_io():
        if PatchMegEngineModelIO.__patched:
            return
        if 'megengine' not in sys.modules:
            return
        PatchMegEngineModelIO.__patched = True
        # noinspection PyBroadException
        try:
            import megengine as mge  # noqa
            mge.save = _patched_call(mge.save, PatchMegEngineModelIO._save)
            mge.load = _patched_call(mge.load, PatchMegEngineModelIO._load)
            # no need to worry about recursive calls, _patched_call takes care of that  # noqa
            if hasattr(mge, 'serialization') and hasattr(mge.serialization, 'save'):  # noqa
                mge.serialization.save = _patched_call(
                    mge.serialization.save, PatchMegEngineModelIO._save
                )
            if hasattr(mge, 'serialization') and hasattr(mge.serialization, 'load'):  # noqa
                mge.serialization.load = _patched_call(
                    mge.serialization.load, PatchMegEngineModelIO._load,
                )
        except ImportError:
            pass
        except Exception:
            pass

    @staticmethod
    def _save(original_fn, obj, f, *args, **kwargs):
        ret = original_fn(obj, f, *args, **kwargs)
        # if there is no main task or this is a nested call
        if not PatchMegEngineModelIO.__main_task:
            return ret
        # noinspection PyBroadException
        try:
            if isinstance(f, six.string_types):
                filename = f
            elif hasattr(f, 'as_posix'):
                filename = f.as_posix()
            elif hasattr(f, 'name'):
                # noinspection PyBroadException
                try:
                    f.flush()
                except Exception:
                    pass
                if not isinstance(f.name, six.string_types):
                    # Probably a BufferedRandom object that has no meaningful name (still no harm flushing)  # noqa
                    return ret
                filename = f.name
            else:
                filename = None
        except Exception:
            filename = None

        # give the model a descriptive name based on the file name
        # noinspection PyBroadException
        try:
            model_name = Path(filename).stem if filename is not None else None
        except Exception:
            model_name = None
        WeightsFileHandler.create_output_model(
            obj, filename, Framework.megengine,
            PatchMegEngineModelIO.__main_task,
            singlefile=True, model_name=model_name,
        )
        return ret

    @staticmethod
    def _load(original_fn, f, *args, **kwargs):
        # if there is no main task or this is a nested call
        if not PatchMegEngineModelIO.__main_task:
            return original_fn(f, *args, **kwargs)
        # noinspection PyBroadException
        try:
            if isinstance(f, six.string_types):
                filename = f
            elif hasattr(f, 'as_posix'):
                filename = f.as_posix()
            elif hasattr(f, 'name'):
                filename = f.name
            else:
                filename = None
        except Exception:
            filename = None

        # register input model
        empty = _Empty()
        # try to load model before registering, in case we fail
        model = original_fn(f, *args, **kwargs)
        WeightsFileHandler.restore_weights_file(
            empty, filename, Framework.megengine,
            PatchMegEngineModelIO.__main_task
        )
        if empty.trains_in_model:
            # noinspection PyBroadException
            try:
                model.trains_in_model = empty.trains_in_model
            except Exception:
                pass
        return model

    @staticmethod
    def _load_from_obj(original_fn, obj, f, *args, **kwargs):
        # if there is no main task or this is a nested call
        if not PatchMegEngineModelIO.__main_task:
            return original_fn(obj, f, *args, **kwargs)
        # noinspection PyBroadException
        try:
            if isinstance(f, six.string_types):
                filename = f
            elif hasattr(f, 'as_posix'):
                filename = f.as_posix()
            elif hasattr(f, 'name'):
                filename = f.name
            else:
                filename = None
        except Exception:
            filename = None

        # register input model
        empty = _Empty()
        # try to load model before registering, in case we fail
        model = original_fn(obj, f, *args, **kwargs)
        WeightsFileHandler.restore_weights_file(
            empty, filename, Framework.megengine,
            PatchMegEngineModelIO.__main_task,
        )
        if empty.trains_in_model:
            # noinspection PyBroadException
            try:
                model.trains_in_model = empty.trains_in_model
            except Exception:
                pass
        return model
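
Note on the mechanism (editorial, not part of the commit): _save and _load above receive the original function as their first argument, which implies _patched_call returns a wrapper roughly like the sketch below. This is an illustration of the pattern only, not ClearML's actual helper; the real one additionally guards against double patching and recursive calls, as the comment in _patch_model_io notes.

import functools


def _patched_call_sketch(original_fn, patched_fn):
    # return a wrapper that hands the original callable to the patched
    # callback; the callback decides when to invoke it (e.g. _save calls
    # original_fn first, then registers the written file)
    @functools.wraps(original_fn)
    def _wrapper(*args, **kwargs):
        return patched_fn(original_fn, *args, **kwargs)
    return _wrapper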

clearml/model.py

@@ -49,6 +49,7 @@ class Framework(Options):
     xgboost = 'XGBoost'
     lightgbm = 'LightGBM'
     parquet = 'Parquet'
+    megengine = 'MegEngine'

     __file_extensions_mapping = {
         '.pb': (tensorflow, tensorflowjs, onnx, ),
@@ -75,7 +76,7 @@ class Framework(Options):
         '.t7': (torch, ),
         '.cfg': (darknet, ),
         '__model__': (paddlepaddle, ),
-        '.pkl': (scikitlearn, keras, xgboost),
+        '.pkl': (scikitlearn, keras, xgboost, megengine),
         '.parquet': (parquet, ),
     }

clearml/task.py

@@ -44,6 +44,7 @@ from .binding.frameworks.lightgbm_bind import PatchLIGHTgbmModelIO
 from .binding.frameworks.pytorch_bind import PatchPyTorchModelIO
 from .binding.frameworks.tensorflow_bind import TensorflowBinding
 from .binding.frameworks.xgboost_bind import PatchXGBoostModelIO
+from .binding.frameworks.megengine_bind import PatchMegEngineModelIO
 from .binding.joblib_bind import PatchedJoblib
 from .binding.matplotlib_bind import PatchedMatplotlib
 from .binding.hydra_bind import PatchHydra
@@ -565,6 +566,8 @@ class Task(_Task):
                 )
             if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('pytorch', True):
                 PatchPyTorchModelIO.update_current_task(task)
+            if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('megengine', True):
+                PatchMegEngineModelIO.update_current_task(task)
             if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('xgboost', True):
                 PatchXGBoostModelIO.update_current_task(task)
             if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('fastai', True):
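
Because the gate added here defaults to True, the MegEngine binding is active whenever auto_connect_frameworks is left at its default. A user who prefers to log MegEngine models manually can opt out per framework; a short sketch with placeholder names:

from clearml import Task

task = Task.init(
    project_name='megengine',
    task_name='manual model logging',
    auto_connect_frameworks={'megengine': False},  # skips PatchMegEngineModelIO; frameworks not listed stay enabled
)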

New example: MegEngine MNIST training script

@@ -0,0 +1,123 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
import argparse
import os
from tempfile import gettempdir

import numpy as np

import megengine as mge
import megengine.module as M
import megengine.functional as F
from megengine.optimizer import SGD
from megengine.autodiff import GradManager
from megengine.data import DataLoader, RandomSampler
from megengine.data.transform import ToMode, Pad, Normalize, Compose
from megengine.data.dataset import MNIST

from tensorboardX import SummaryWriter

from clearml import Task


class Net(M.Module):
    def __init__(self):
        super().__init__()
        self.conv0 = M.Conv2d(1, 20, kernel_size=5, bias=False)
        self.bn0 = M.BatchNorm2d(20)
        self.relu0 = M.ReLU()
        self.pool0 = M.MaxPool2d(2)
        self.conv1 = M.Conv2d(20, 20, kernel_size=5, bias=False)
        self.bn1 = M.BatchNorm2d(20)
        self.relu1 = M.ReLU()
        self.pool1 = M.MaxPool2d(2)
        self.fc0 = M.Linear(500, 64, bias=True)
        self.relu2 = M.ReLU()
        self.fc1 = M.Linear(64, 10, bias=True)

    def forward(self, x):
        x = self.conv0(x)
        x = self.bn0(x)
        x = self.relu0(x)
        x = self.pool0(x)
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)
        x = self.pool1(x)
        x = F.flatten(x, 1)
        x = self.fc0(x)
        x = self.relu2(x)
        x = self.fc1(x)
        return x


def build_dataloader():
    train_dataset = MNIST(root=gettempdir(), train=True, download=True)
    dataloader = DataLoader(
        train_dataset,
        transform=Compose([
            Normalize(mean=0.1307*255, std=0.3081*255),
            Pad(2),
            ToMode('CHW'),
        ]),
        sampler=RandomSampler(dataset=train_dataset, batch_size=64),
    )
    return dataloader


def train(dataloader, args):
    writer = SummaryWriter("runs")

    net = Net()
    net.train()

    optimizer = SGD(
        net.parameters(), lr=args.lr,
        momentum=args.momentum, weight_decay=args.wd
    )
    gm = GradManager().attach(net.parameters())

    epoch_length = len(dataloader)
    for epoch in range(args.epoch):
        for step, (batch_data, batch_label) in enumerate(dataloader):
            batch_label = batch_label.astype(np.int32)
            data, label = mge.tensor(batch_data), mge.tensor(batch_label)
            with gm:
                pred = net(data)
                loss = F.loss.cross_entropy(pred, label)
                gm.backward(loss)
                optimizer.step().clear_grad()
            if step % 50 == 0:
                print("epoch:{}, iter:{}, loss:{}".format(epoch + 1, step, float(loss)))  # noqa
                writer.add_scalar("loss", float(loss), epoch * epoch_length + step)
        if (epoch + 1) % 5 == 0:
            mge.save(net.state_dict(), os.path.join(gettempdir(), f"mnist_net_e{epoch + 1}.pkl"))  # noqa


def main():
    task = Task.init(project_name='megengine', task_name='mge mnist train')  # noqa
    parser = argparse.ArgumentParser(description='MegEngine MNIST Example')
    parser.add_argument(
        '--epoch', type=int, default=10,
        help='number of training epoch (default: 10)',
    )
    parser.add_argument(
        '--lr', type=float, default=0.01,
        help='learning rate (default: 0.01)',
    )
    parser.add_argument(
        '--momentum', type=float, default=0.9,
        help='SGD momentum (default: 0.9)',
    )
    parser.add_argument(
        '--wd', type=float, default=5e-4,
        help='SGD weight decay (default: 5e-4)',
    )
    args = parser.parse_args()

    dataloader = build_dataloader()
    train(dataloader, args)


if __name__ == "__main__":
    main()
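
Resuming from one of the checkpoints the script writes is a matter of loading the same path back; because mge.load is now patched, the checkpoint is also registered as an input model of the task. A short sketch, reusing the Net class defined above (the epoch number in the filename is illustrative):

import os
from tempfile import gettempdir

import megengine as mge

state_dict = mge.load(os.path.join(gettempdir(), "mnist_net_e5.pkl"))  # registered as an input model
net = Net()
net.load_state_dict(state_dict)
net.eval()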