mirror of
https://github.com/clearml/clearml
synced 2025-03-09 21:40:51 +00:00
Add support for MegEngine (#455)
* feat(framework): support MegEngine model * feat(framework): add MegEngine example * fix(example): change supported suffix to pkl
This commit is contained in:
parent
400c6ec103
commit
8c9c0eacc7
176
clearml/binding/frameworks/megengine_bind.py
Normal file
176
clearml/binding/frameworks/megengine_bind.py
Normal file
@ -0,0 +1,176 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
|
||||
import sys
|
||||
|
||||
import six
|
||||
from pathlib2 import Path
|
||||
|
||||
from ..frameworks.base_bind import PatchBaseModelIO
|
||||
from ..frameworks import _patched_call, WeightsFileHandler, _Empty
|
||||
from ..import_bind import PostImportHookPatching
|
||||
from ...model import Framework
|
||||
|
||||
|
||||
class PatchMegEngineModelIO(PatchBaseModelIO):
|
||||
__main_task = None
|
||||
__patched = None
|
||||
# __patched_lightning = None
|
||||
|
||||
@staticmethod
|
||||
def update_current_task(task, **_):
|
||||
PatchMegEngineModelIO.__main_task = task
|
||||
PatchMegEngineModelIO._patch_model_io()
|
||||
PostImportHookPatching.add_on_import(
|
||||
'megengine', PatchMegEngineModelIO._patch_model_io
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _patch_model_io():
|
||||
if PatchMegEngineModelIO.__patched:
|
||||
return
|
||||
|
||||
if 'megengine' not in sys.modules:
|
||||
return
|
||||
|
||||
PatchMegEngineModelIO.__patched = True
|
||||
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
import megengine as mge # noqa
|
||||
mge.save = _patched_call(mge.save, PatchMegEngineModelIO._save)
|
||||
mge.load = _patched_call(mge.load, PatchMegEngineModelIO._load)
|
||||
|
||||
# no need to worry about recursive calls, _patched_call takes care of that # noqa
|
||||
if hasattr(mge, 'serialization') and hasattr(mge.serialization, 'save'): # noqa
|
||||
mge.serialization.save = _patched_call(
|
||||
mge.serialization.save, PatchMegEngineModelIO._save
|
||||
)
|
||||
if hasattr(mge, 'serialization') and hasattr(mge.serialization, 'load'): # noqa
|
||||
mge.serialization.load = _patched_call(
|
||||
mge.serialization.load, PatchMegEngineModelIO._load,
|
||||
)
|
||||
except ImportError:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def _save(original_fn, obj, f, *args, **kwargs):
|
||||
ret = original_fn(obj, f, *args, **kwargs)
|
||||
|
||||
# if there is no main task or this is a nested call
|
||||
if not PatchMegEngineModelIO.__main_task:
|
||||
return ret
|
||||
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
if isinstance(f, six.string_types):
|
||||
filename = f
|
||||
elif hasattr(f, 'as_posix'):
|
||||
filename = f.as_posix()
|
||||
elif hasattr(f, 'name'):
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
f.flush()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if not isinstance(f.name, six.string_types):
|
||||
# Probably a BufferedRandom object that has no meaningful name (still no harm flushing) # noqa
|
||||
return ret
|
||||
|
||||
filename = f.name
|
||||
else:
|
||||
filename = None
|
||||
except Exception:
|
||||
filename = None
|
||||
|
||||
# give the model a descriptive name based on the file name
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
model_name = Path(filename).stem if filename is not None else None
|
||||
except Exception:
|
||||
model_name = None
|
||||
|
||||
WeightsFileHandler.create_output_model(
|
||||
obj, filename, Framework.megengine,
|
||||
PatchMegEngineModelIO.__main_task,
|
||||
singlefile=True, model_name=model_name,
|
||||
)
|
||||
|
||||
return ret
|
||||
|
||||
@staticmethod
|
||||
def _load(original_fn, f, *args, **kwargs):
|
||||
# if there is no main task or this is a nested call
|
||||
if not PatchMegEngineModelIO.__main_task:
|
||||
return original_fn(f, *args, **kwargs)
|
||||
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
if isinstance(f, six.string_types):
|
||||
filename = f
|
||||
elif hasattr(f, 'as_posix'):
|
||||
filename = f.as_posix()
|
||||
elif hasattr(f, 'name'):
|
||||
filename = f.name
|
||||
else:
|
||||
filename = None
|
||||
except Exception:
|
||||
filename = None
|
||||
|
||||
# register input model
|
||||
empty = _Empty()
|
||||
# try to load model before registering, in case we fail
|
||||
model = original_fn(f, *args, **kwargs)
|
||||
WeightsFileHandler.restore_weights_file(
|
||||
empty, filename, Framework.megengine,
|
||||
PatchMegEngineModelIO.__main_task
|
||||
)
|
||||
|
||||
if empty.trains_in_model:
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
model.trains_in_model = empty.trains_in_model
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
def _load_from_obj(original_fn, obj, f, *args, **kwargs):
|
||||
# if there is no main task or this is a nested call
|
||||
if not PatchMegEngineModelIO.__main_task:
|
||||
return original_fn(obj, f, *args, **kwargs)
|
||||
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
if isinstance(f, six.string_types):
|
||||
filename = f
|
||||
elif hasattr(f, 'as_posix'):
|
||||
filename = f.as_posix()
|
||||
elif hasattr(f, 'name'):
|
||||
filename = f.name
|
||||
else:
|
||||
filename = None
|
||||
except Exception:
|
||||
filename = None
|
||||
|
||||
# register input model
|
||||
empty = _Empty()
|
||||
# try to load model before registering, in case we fail
|
||||
model = original_fn(obj, f, *args, **kwargs)
|
||||
WeightsFileHandler.restore_weights_file(
|
||||
empty, filename, Framework.megengine,
|
||||
PatchMegEngineModelIO.__main_task,
|
||||
)
|
||||
|
||||
if empty.trains_in_model:
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
model.trains_in_model = empty.trains_in_model
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return model
|
@ -49,6 +49,7 @@ class Framework(Options):
|
||||
xgboost = 'XGBoost'
|
||||
lightgbm = 'LightGBM'
|
||||
parquet = 'Parquet'
|
||||
megengine = 'MegEngine'
|
||||
|
||||
__file_extensions_mapping = {
|
||||
'.pb': (tensorflow, tensorflowjs, onnx, ),
|
||||
@ -75,7 +76,7 @@ class Framework(Options):
|
||||
'.t7': (torch, ),
|
||||
'.cfg': (darknet, ),
|
||||
'__model__': (paddlepaddle, ),
|
||||
'.pkl': (scikitlearn, keras, xgboost),
|
||||
'.pkl': (scikitlearn, keras, xgboost, megengine),
|
||||
'.parquet': (parquet, ),
|
||||
}
|
||||
|
||||
|
@ -44,6 +44,7 @@ from .binding.frameworks.lightgbm_bind import PatchLIGHTgbmModelIO
|
||||
from .binding.frameworks.pytorch_bind import PatchPyTorchModelIO
|
||||
from .binding.frameworks.tensorflow_bind import TensorflowBinding
|
||||
from .binding.frameworks.xgboost_bind import PatchXGBoostModelIO
|
||||
from .binding.frameworks.megengine_bind import PatchMegEngineModelIO
|
||||
from .binding.joblib_bind import PatchedJoblib
|
||||
from .binding.matplotlib_bind import PatchedMatplotlib
|
||||
from .binding.hydra_bind import PatchHydra
|
||||
@ -565,6 +566,8 @@ class Task(_Task):
|
||||
)
|
||||
if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('pytorch', True):
|
||||
PatchPyTorchModelIO.update_current_task(task)
|
||||
if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('megengine', True):
|
||||
PatchMegEngineModelIO.update_current_task(task)
|
||||
if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('xgboost', True):
|
||||
PatchXGBoostModelIO.update_current_task(task)
|
||||
if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('fastai', True):
|
||||
|
123
examples/frameworks/megengine/megengine_mnist.py
Normal file
123
examples/frameworks/megengine/megengine_mnist.py
Normal file
@ -0,0 +1,123 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
|
||||
import argparse
|
||||
import os
|
||||
from tempfile import gettempdir
|
||||
import numpy as np
|
||||
|
||||
import megengine as mge
|
||||
import megengine.module as M
|
||||
import megengine.functional as F
|
||||
from megengine.optimizer import SGD
|
||||
from megengine.autodiff import GradManager
|
||||
|
||||
from megengine.data import DataLoader, RandomSampler
|
||||
from megengine.data.transform import ToMode, Pad, Normalize, Compose
|
||||
from megengine.data.dataset import MNIST
|
||||
|
||||
from tensorboardX import SummaryWriter
|
||||
from clearml import Task
|
||||
|
||||
|
||||
class Net(M.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.conv0 = M.Conv2d(1, 20, kernel_size=5, bias=False)
|
||||
self.bn0 = M.BatchNorm2d(20)
|
||||
self.relu0 = M.ReLU()
|
||||
self.pool0 = M.MaxPool2d(2)
|
||||
self.conv1 = M.Conv2d(20, 20, kernel_size=5, bias=False)
|
||||
self.bn1 = M.BatchNorm2d(20)
|
||||
self.relu1 = M.ReLU()
|
||||
self.pool1 = M.MaxPool2d(2)
|
||||
self.fc0 = M.Linear(500, 64, bias=True)
|
||||
self.relu2 = M.ReLU()
|
||||
self.fc1 = M.Linear(64, 10, bias=True)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv0(x)
|
||||
x = self.bn0(x)
|
||||
x = self.relu0(x)
|
||||
x = self.pool0(x)
|
||||
x = self.conv1(x)
|
||||
x = self.bn1(x)
|
||||
x = self.relu1(x)
|
||||
x = self.pool1(x)
|
||||
x = F.flatten(x, 1)
|
||||
x = self.fc0(x)
|
||||
x = self.relu2(x)
|
||||
x = self.fc1(x)
|
||||
return x
|
||||
|
||||
|
||||
def build_dataloader():
|
||||
train_dataset = MNIST(root=gettempdir(), train=True, download=True)
|
||||
dataloader = DataLoader(
|
||||
train_dataset,
|
||||
transform=Compose([
|
||||
Normalize(mean=0.1307*255, std=0.3081*255),
|
||||
Pad(2),
|
||||
ToMode('CHW'),
|
||||
]),
|
||||
sampler=RandomSampler(dataset=train_dataset, batch_size=64),
|
||||
)
|
||||
return dataloader
|
||||
|
||||
|
||||
def train(dataloader, args):
|
||||
writer = SummaryWriter("runs")
|
||||
net = Net()
|
||||
net.train()
|
||||
optimizer = SGD(
|
||||
net.parameters(), lr=args.lr,
|
||||
momentum=args.momentum, weight_decay=args.wd
|
||||
)
|
||||
gm = GradManager().attach(net.parameters())
|
||||
|
||||
epoch_length = len(dataloader)
|
||||
for epoch in range(args.epoch):
|
||||
for step, (batch_data, batch_label) in enumerate(dataloader):
|
||||
batch_label = batch_label.astype(np.int32)
|
||||
data, label = mge.tensor(batch_data), mge.tensor(batch_label)
|
||||
with gm:
|
||||
pred = net(data)
|
||||
loss = F.loss.cross_entropy(pred, label)
|
||||
gm.backward(loss)
|
||||
optimizer.step().clear_grad()
|
||||
|
||||
if step % 50 == 0:
|
||||
print("epoch:{}, iter:{}, loss:{}".format(epoch + 1, step, float(loss))) # noqa
|
||||
writer.add_scalar("loss", float(loss), epoch * epoch_length + step)
|
||||
if (epoch + 1) % 5 == 0:
|
||||
mge.save(net.state_dict(), os.path.join(gettempdir(), f"mnist_net_e{epoch + 1}.pkl")) # noqa
|
||||
|
||||
|
||||
def main():
|
||||
task = Task.init(project_name='megengine', task_name='mge mnist train') # noqa
|
||||
|
||||
parser = argparse.ArgumentParser(description='MegEngine MNIST Example')
|
||||
parser.add_argument(
|
||||
'--epoch', type=int, default=10,
|
||||
help='number of training epoch(default: 10)',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--lr', type=float, default=0.01,
|
||||
help='learning rate(default: 0.01)'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--momentum', type=float, default=0.9,
|
||||
help='SGD momentum (default: 0.9)',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--wd', type=float, default=5e-4,
|
||||
help='SGD weight decay(default: 5e-4)',
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
dataloader = build_dataloader()
|
||||
train(dataloader, args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
Loading…
Reference in New Issue
Block a user