Mirror of https://github.com/clearml/clearml (synced 2025-03-10 05:50:13 +00:00)
Add support for MegEngine (#455)
* feat(framework): support MegEngine model
* feat(framework): add MegEngine example
* fix(example): change supported suffix to pkl
This commit is contained in:
parent 400c6ec103
commit 8c9c0eacc7
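In practice, this commit means that once a ClearML task is initialized, ordinary MegEngine checkpoint calls are captured automatically: the binding below wraps mge.save and mge.load so that saved weights are registered as output models and loaded weights as input models. A minimal usage sketch (the project name, task name, file name, and the tiny M.Linear module are illustrative; mge.save / mge.load and state_dict / load_state_dict are MegEngine's standard API):

import megengine as mge
import megengine.module as M
from clearml import Task

# Initializing the task installs the MegEngine binding added in this commit,
# so plain mge.save / mge.load calls are tracked without extra code.
task = Task.init(project_name='megengine', task_name='checkpoint autologging demo')

model = M.Linear(4, 2)                            # any MegEngine module works here
mge.save(model.state_dict(), 'demo_model.pkl')    # registered as an output model
state = mge.load('demo_model.pkl')                # registered as an input model
model.load_state_dict(state)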
clearml/binding/frameworks/megengine_bind.py (new file, 176 lines)
@@ -0,0 +1,176 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-

import sys

import six
from pathlib2 import Path

from ..frameworks.base_bind import PatchBaseModelIO
from ..frameworks import _patched_call, WeightsFileHandler, _Empty
from ..import_bind import PostImportHookPatching
from ...model import Framework


class PatchMegEngineModelIO(PatchBaseModelIO):
    __main_task = None
    __patched = None
    # __patched_lightning = None

    @staticmethod
    def update_current_task(task, **_):
        PatchMegEngineModelIO.__main_task = task
        PatchMegEngineModelIO._patch_model_io()
        PostImportHookPatching.add_on_import(
            'megengine', PatchMegEngineModelIO._patch_model_io
        )

    @staticmethod
    def _patch_model_io():
        if PatchMegEngineModelIO.__patched:
            return

        if 'megengine' not in sys.modules:
            return

        PatchMegEngineModelIO.__patched = True

        # noinspection PyBroadException
        try:
            import megengine as mge  # noqa
            mge.save = _patched_call(mge.save, PatchMegEngineModelIO._save)
            mge.load = _patched_call(mge.load, PatchMegEngineModelIO._load)

            # no need to worry about recursive calls, _patched_call takes care of that  # noqa
            if hasattr(mge, 'serialization') and hasattr(mge.serialization, 'save'):  # noqa
                mge.serialization.save = _patched_call(
                    mge.serialization.save, PatchMegEngineModelIO._save
                )
            if hasattr(mge, 'serialization') and hasattr(mge.serialization, 'load'):  # noqa
                mge.serialization.load = _patched_call(
                    mge.serialization.load, PatchMegEngineModelIO._load,
                )
        except ImportError:
            pass
        except Exception:
            pass

    @staticmethod
    def _save(original_fn, obj, f, *args, **kwargs):
        ret = original_fn(obj, f, *args, **kwargs)

        # if there is no main task or this is a nested call
        if not PatchMegEngineModelIO.__main_task:
            return ret

        # noinspection PyBroadException
        try:
            if isinstance(f, six.string_types):
                filename = f
            elif hasattr(f, 'as_posix'):
                filename = f.as_posix()
            elif hasattr(f, 'name'):
                # noinspection PyBroadException
                try:
                    f.flush()
                except Exception:
                    pass

                if not isinstance(f.name, six.string_types):
                    # Probably a BufferedRandom object that has no meaningful name (still no harm flushing)  # noqa
                    return ret

                filename = f.name
            else:
                filename = None
        except Exception:
            filename = None

        # give the model a descriptive name based on the file name
        # noinspection PyBroadException
        try:
            model_name = Path(filename).stem if filename is not None else None
        except Exception:
            model_name = None

        WeightsFileHandler.create_output_model(
            obj, filename, Framework.megengine,
            PatchMegEngineModelIO.__main_task,
            singlefile=True, model_name=model_name,
        )

        return ret

    @staticmethod
    def _load(original_fn, f, *args, **kwargs):
        # if there is no main task or this is a nested call
        if not PatchMegEngineModelIO.__main_task:
            return original_fn(f, *args, **kwargs)

        # noinspection PyBroadException
        try:
            if isinstance(f, six.string_types):
                filename = f
            elif hasattr(f, 'as_posix'):
                filename = f.as_posix()
            elif hasattr(f, 'name'):
                filename = f.name
            else:
                filename = None
        except Exception:
            filename = None

        # register input model
        empty = _Empty()
        # try to load model before registering, in case we fail
        model = original_fn(f, *args, **kwargs)
        WeightsFileHandler.restore_weights_file(
            empty, filename, Framework.megengine,
            PatchMegEngineModelIO.__main_task
        )

        if empty.trains_in_model:
            # noinspection PyBroadException
            try:
                model.trains_in_model = empty.trains_in_model
            except Exception:
                pass

        return model

    @staticmethod
    def _load_from_obj(original_fn, obj, f, *args, **kwargs):
        # if there is no main task or this is a nested call
        if not PatchMegEngineModelIO.__main_task:
            return original_fn(obj, f, *args, **kwargs)

        # noinspection PyBroadException
        try:
            if isinstance(f, six.string_types):
                filename = f
            elif hasattr(f, 'as_posix'):
                filename = f.as_posix()
            elif hasattr(f, 'name'):
                filename = f.name
            else:
                filename = None
        except Exception:
            filename = None

        # register input model
        empty = _Empty()
        # try to load model before registering, in case we fail
        model = original_fn(obj, f, *args, **kwargs)
        WeightsFileHandler.restore_weights_file(
            empty, filename, Framework.megengine,
            PatchMegEngineModelIO.__main_task,
        )

        if empty.trains_in_model:
            # noinspection PyBroadException
            try:
                model.trains_in_model = empty.trains_in_model
            except Exception:
                pass

        return model
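For readers unfamiliar with the wrapper pattern used above: _patched_call(original, patched) replaces a function with a version that routes every call through the patched callback while handing it the original function, so the callback can run the real call and then add side effects such as model registration. The sketch below is an illustrative stand-in, not clearml's actual _patched_call implementation (which, per the comment in the diff, also guards against recursive calls, a detail this sketch omits):

import functools

def patched_call(original_fn, patched_fn):
    # Return a wrapper that forwards to patched_fn with the original function
    # prepended, mirroring how _save/_load above receive `original_fn` first.
    @functools.wraps(original_fn)
    def _wrapper(*args, **kwargs):
        return patched_fn(original_fn, *args, **kwargs)
    return _wrapper

# Example: wrap a plain save function so every call also logs the checkpoint.
def plain_save(obj, path):
    print("saving {} to {}".format(type(obj).__name__, path))

def logging_save(original_fn, obj, path, *args, **kwargs):
    ret = original_fn(obj, path, *args, **kwargs)    # run the real save first
    print("registered checkpoint: {}".format(path))  # then add the side effect
    return ret

plain_save = patched_call(plain_save, logging_save)
plain_save({"w": 1}, "model.pkl")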
@@ -49,6 +49,7 @@ class Framework(Options):
     xgboost = 'XGBoost'
     lightgbm = 'LightGBM'
     parquet = 'Parquet'
+    megengine = 'MegEngine'

     __file_extensions_mapping = {
         '.pb': (tensorflow, tensorflowjs, onnx, ),
@@ -75,7 +76,7 @@ class Framework(Options):
         '.t7': (torch, ),
         '.cfg': (darknet, ),
         '__model__': (paddlepaddle, ),
-        '.pkl': (scikitlearn, keras, xgboost),
+        '.pkl': (scikitlearn, keras, xgboost, megengine),
         '.parquet': (parquet, ),
     }

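The mapping above is what lets the framework of a saved weights file be guessed from its suffix, which is why the commit adds megengine to the .pkl entry. A minimal illustration of that kind of suffix lookup (a hypothetical helper with plain strings, not the actual clearml code path):

from pathlib import Path

file_extensions_mapping = {
    '.pkl': ('ScikitLearn', 'Keras', 'XGBoost', 'MegEngine'),
    '.pb': ('TensorFlow', 'TensorFlow_js', 'ONNX'),
}

def guess_frameworks(weights_file):
    # Look up candidate frameworks by file suffix; empty tuple if unknown.
    return file_extensions_mapping.get(Path(weights_file).suffix, ())

print(guess_frameworks('mnist_net_e5.pkl'))
# ('ScikitLearn', 'Keras', 'XGBoost', 'MegEngine')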
@@ -44,6 +44,7 @@ from .binding.frameworks.lightgbm_bind import PatchLIGHTgbmModelIO
 from .binding.frameworks.pytorch_bind import PatchPyTorchModelIO
 from .binding.frameworks.tensorflow_bind import TensorflowBinding
 from .binding.frameworks.xgboost_bind import PatchXGBoostModelIO
+from .binding.frameworks.megengine_bind import PatchMegEngineModelIO
 from .binding.joblib_bind import PatchedJoblib
 from .binding.matplotlib_bind import PatchedMatplotlib
 from .binding.hydra_bind import PatchHydra
@@ -565,6 +566,8 @@ class Task(_Task):
                 )
             if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('pytorch', True):
                 PatchPyTorchModelIO.update_current_task(task)
+            if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('megengine', True):
+                PatchMegEngineModelIO.update_current_task(task)
             if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('xgboost', True):
                 PatchXGBoostModelIO.update_current_task(task)
             if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('fastai', True):
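As the gating above shows, the new binding obeys the same auto_connect_frameworks switch as the existing ones: it defaults to enabled, and a dict passed to Task.init can turn it off while other frameworks keep their defaults. A small sketch of opting out for a given task (project and task names are illustrative):

from clearml import Task

# MegEngine auto-logging defaults to enabled; pass a dict to
# auto_connect_frameworks to disable just this binding.
task = Task.init(
    project_name='megengine',
    task_name='manual model logging',
    auto_connect_frameworks={'megengine': False},
)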
examples/frameworks/megengine/megengine_mnist.py (new file, 123 lines)
@@ -0,0 +1,123 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-

import argparse
import os
from tempfile import gettempdir
import numpy as np

import megengine as mge
import megengine.module as M
import megengine.functional as F
from megengine.optimizer import SGD
from megengine.autodiff import GradManager

from megengine.data import DataLoader, RandomSampler
from megengine.data.transform import ToMode, Pad, Normalize, Compose
from megengine.data.dataset import MNIST

from tensorboardX import SummaryWriter
from clearml import Task


class Net(M.Module):
    def __init__(self):
        super().__init__()
        self.conv0 = M.Conv2d(1, 20, kernel_size=5, bias=False)
        self.bn0 = M.BatchNorm2d(20)
        self.relu0 = M.ReLU()
        self.pool0 = M.MaxPool2d(2)
        self.conv1 = M.Conv2d(20, 20, kernel_size=5, bias=False)
        self.bn1 = M.BatchNorm2d(20)
        self.relu1 = M.ReLU()
        self.pool1 = M.MaxPool2d(2)
        self.fc0 = M.Linear(500, 64, bias=True)
        self.relu2 = M.ReLU()
        self.fc1 = M.Linear(64, 10, bias=True)

    def forward(self, x):
        x = self.conv0(x)
        x = self.bn0(x)
        x = self.relu0(x)
        x = self.pool0(x)
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)
        x = self.pool1(x)
        x = F.flatten(x, 1)
        x = self.fc0(x)
        x = self.relu2(x)
        x = self.fc1(x)
        return x


def build_dataloader():
    train_dataset = MNIST(root=gettempdir(), train=True, download=True)
    dataloader = DataLoader(
        train_dataset,
        transform=Compose([
            Normalize(mean=0.1307*255, std=0.3081*255),
            Pad(2),
            ToMode('CHW'),
        ]),
        sampler=RandomSampler(dataset=train_dataset, batch_size=64),
    )
    return dataloader


def train(dataloader, args):
    writer = SummaryWriter("runs")
    net = Net()
    net.train()
    optimizer = SGD(
        net.parameters(), lr=args.lr,
        momentum=args.momentum, weight_decay=args.wd
    )
    gm = GradManager().attach(net.parameters())

    epoch_length = len(dataloader)
    for epoch in range(args.epoch):
        for step, (batch_data, batch_label) in enumerate(dataloader):
            batch_label = batch_label.astype(np.int32)
            data, label = mge.tensor(batch_data), mge.tensor(batch_label)
            with gm:
                pred = net(data)
                loss = F.loss.cross_entropy(pred, label)
                gm.backward(loss)
            optimizer.step().clear_grad()

            if step % 50 == 0:
                print("epoch:{}, iter:{}, loss:{}".format(epoch + 1, step, float(loss)))  # noqa
                writer.add_scalar("loss", float(loss), epoch * epoch_length + step)
        if (epoch + 1) % 5 == 0:
            mge.save(net.state_dict(), os.path.join(gettempdir(), f"mnist_net_e{epoch + 1}.pkl"))  # noqa


def main():
    task = Task.init(project_name='megengine', task_name='mge mnist train')  # noqa

    parser = argparse.ArgumentParser(description='MegEngine MNIST Example')
    parser.add_argument(
        '--epoch', type=int, default=10,
        help='number of training epochs (default: 10)',
    )
    parser.add_argument(
        '--lr', type=float, default=0.01,
        help='learning rate (default: 0.01)'
    )
    parser.add_argument(
        '--momentum', type=float, default=0.9,
        help='SGD momentum (default: 0.9)',
    )
    parser.add_argument(
        '--wd', type=float, default=5e-4,
        help='SGD weight decay (default: 5e-4)',
    )

    args = parser.parse_args()
    dataloader = build_dataloader()
    train(dataloader, args)


if __name__ == "__main__":
    main()
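With megengine, tensorboardX, and clearml installed, the example runs as a normal script; the flags below are the ones its argparse setup defines (values shown are simply the defaults):

    python examples/frameworks/megengine/megengine_mnist.py --epoch 10 --lr 0.01 --momentum 0.9 --wd 5e-4

Every fifth epoch the script calls mge.save(...) into the temp directory, and because Task.init has installed the binding from this commit, each of those .pkl checkpoints is registered as an output model on the 'mge mnist train' task.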