diff --git a/clearml/binding/frameworks/megengine_bind.py b/clearml/binding/frameworks/megengine_bind.py
new file mode 100644
index 00000000..15fe422f
--- /dev/null
+++ b/clearml/binding/frameworks/megengine_bind.py
@@ -0,0 +1,200 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+"""Automatic model logging for MegEngine ``save`` / ``load`` calls."""
+
+import sys
+
+import six
+from pathlib2 import Path
+
+from ..frameworks.base_bind import PatchBaseModelIO
+from ..frameworks import _patched_call, WeightsFileHandler, _Empty
+from ..import_bind import PostImportHookPatching
+from ...model import Framework
+
+
+class PatchMegEngineModelIO(PatchBaseModelIO):
+    """Report models saved / loaded through MegEngine to the current task.
+
+    ``megengine.save`` / ``megengine.load`` (and their counterparts under
+    ``megengine.serialization``, when present) are wrapped with
+    ``_patched_call``. Saved files go to
+    ``WeightsFileHandler.create_output_model`` and loaded files to
+    ``WeightsFileHandler.restore_weights_file``, tagged ``Framework.megengine``.
+    """
+
+    __main_task = None
+    __patched = None
+
+    @staticmethod
+    def update_current_task(task, **_):
+        """Store ``task`` as the reporting target and install the patches.
+
+        :param task: task that saved / loaded models are reported to.
+        """
+        PatchMegEngineModelIO.__main_task = task
+        # patch now (returns early unless megengine is already imported) ...
+        PatchMegEngineModelIO._patch_model_io()
+        # ... and register the same patch with the 'megengine' import hook
+        PostImportHookPatching.add_on_import(
+            'megengine', PatchMegEngineModelIO._patch_model_io
+        )
+
+    @staticmethod
+    def _patch_model_io():
+        """Wrap the MegEngine save / load entry points, at most once.
+
+        Returns without patching while ``megengine`` is not in
+        ``sys.modules``. Patching is best effort: failures are swallowed so
+        the binding never breaks the user's script.
+        """
+        if PatchMegEngineModelIO.__patched:
+            return
+
+        if 'megengine' not in sys.modules:
+            return
+
+        PatchMegEngineModelIO.__patched = True
+
+        # noinspection PyBroadException
+        try:
+            import megengine as mge  # noqa
+
+            mge.save = _patched_call(mge.save, PatchMegEngineModelIO._save)
+            mge.load = _patched_call(mge.load, PatchMegEngineModelIO._load)
+
+            # no need to worry about recursive calls, _patched_call takes care of that  # noqa
+            serialization = getattr(mge, 'serialization', None)
+            if hasattr(serialization, 'save'):
+                serialization.save = _patched_call(
+                    serialization.save, PatchMegEngineModelIO._save
+                )
+            if hasattr(serialization, 'load'):
+                serialization.load = _patched_call(
+                    serialization.load, PatchMegEngineModelIO._load
+                )
+        except Exception:
+            # covers ImportError as well; patching is optional
+            pass
+
+    @staticmethod
+    def _get_filename(f, flush=False):
+        """Best-effort conversion of a save / load target to a path string.
+
+        :param f: path string, object exposing ``as_posix()`` or file-like
+            object exposing ``name``.
+        :param flush: flush a file-like ``f`` before reading its name.
+        :return: the path string, or ``None`` when it cannot be determined
+            (no usable attribute, non-string ``name`` or any error).
+        """
+        # noinspection PyBroadException
+        try:
+            if isinstance(f, six.string_types):
+                return f
+            if hasattr(f, 'as_posix'):
+                return f.as_posix()
+            if hasattr(f, 'name'):
+                if flush:
+                    # noinspection PyBroadException
+                    try:
+                        f.flush()
+                    except Exception:
+                        pass
+                # Probably a BufferedRandom object that has no meaningful name  # noqa
+                if isinstance(f.name, six.string_types):
+                    return f.name
+        except Exception:
+            pass
+        return None
+
+    @staticmethod
+    def _save(original_fn, obj, f, *args, **kwargs):
+        """Run the original save, then register the file as an output model.
+
+        :param original_fn: the unpatched save function.
+        :param obj: object being saved.
+        :param f: save target (see ``_get_filename``).
+        :return: whatever ``original_fn`` returned.
+        """
+        ret = original_fn(obj, f, *args, **kwargs)
+
+        # nothing to report to without a main task
+        if not PatchMegEngineModelIO.__main_task:
+            return ret
+
+        filename = PatchMegEngineModelIO._get_filename(f, flush=True)
+        if filename is None:
+            # no file we can point at, so there is nothing to register
+            return ret
+
+        # give the model a descriptive name based on the file name
+        # noinspection PyBroadException
+        try:
+            model_name = Path(filename).stem
+        except Exception:
+            model_name = None
+
+        WeightsFileHandler.create_output_model(
+            obj, filename, Framework.megengine,
+            PatchMegEngineModelIO.__main_task,
+            singlefile=True, model_name=model_name,
+        )
+
+        return ret
+
+    @staticmethod
+    def _load(original_fn, f, *args, **kwargs):
+        """Run the original load, then register ``f`` as an input model.
+
+        :param original_fn: the unpatched load function.
+        :param f: load source (see ``_get_filename``).
+        :return: the object returned by ``original_fn``.
+        """
+        return PatchMegEngineModelIO._load_and_register(
+            f, lambda: original_fn(f, *args, **kwargs)
+        )
+
+    @staticmethod
+    def _load_from_obj(original_fn, obj, f, *args, **kwargs):
+        """Variant of ``_load`` for loaders called as ``fn(obj, f, ...)``.
+
+        :param original_fn: the unpatched load function.
+        :param obj: first positional argument forwarded to ``original_fn``.
+        :param f: load source (see ``_get_filename``).
+        :return: the object returned by ``original_fn``.
+        """
+        return PatchMegEngineModelIO._load_and_register(
+            f, lambda: original_fn(obj, f, *args, **kwargs)
+        )
+
+    @staticmethod
+    def _load_and_register(f, loader):
+        """Shared body of ``_load`` and ``_load_from_obj``.
+
+        :param f: load source used to work out the file name.
+        :param loader: zero-argument callable performing the original load.
+        :return: whatever ``loader`` returned.
+        """
+        # nothing to report to without a main task
+        if not PatchMegEngineModelIO.__main_task:
+            return loader()
+
+        filename = PatchMegEngineModelIO._get_filename(f)
+
+        # register input model
+        empty = _Empty()
+        # try to load model before registering, in case we fail
+        model = loader()
+        WeightsFileHandler.restore_weights_file(
+            empty, filename, Framework.megengine,
+            PatchMegEngineModelIO.__main_task,
+        )
+
+        if empty.trains_in_model:
+            # noinspection PyBroadException
+            try:
+                model.trains_in_model = empty.trains_in_model
+            except Exception:
+                pass
+
+        return model
diff --git a/clearml/model.py b/clearml/model.py
index 33e7f5a6..9471472d 100644
--- a/clearml/model.py
+++ b/clearml/model.py
@@ -49,6 +49,7 @@ class Framework(Options):
     xgboost = 'XGBoost'
     lightgbm = 'LightGBM'
     parquet = 'Parquet'
+    megengine = 'MegEngine'
 
     __file_extensions_mapping = {
         '.pb': (tensorflow, tensorflowjs, onnx, ),
@@ -75,7 +76,7 @@ class Framework(Options):
         '.t7': (torch, ),
         '.cfg': (darknet, ),
         '__model__': (paddlepaddle, ),
-        '.pkl': (scikitlearn, keras, xgboost),
+        '.pkl': (scikitlearn, keras, xgboost, megengine),
         '.parquet': (parquet, ),
     }
 
diff --git a/clearml/task.py b/clearml/task.py
index 75f485cb..93fd14cf 100644
--- a/clearml/task.py
+++ b/clearml/task.py
@@ -44,6 +44,7 @@ from .binding.frameworks.lightgbm_bind import PatchLIGHTgbmModelIO
 from .binding.frameworks.pytorch_bind import PatchPyTorchModelIO
 from .binding.frameworks.tensorflow_bind import TensorflowBinding
 from .binding.frameworks.xgboost_bind import PatchXGBoostModelIO
+from .binding.frameworks.megengine_bind import PatchMegEngineModelIO
 from .binding.joblib_bind import PatchedJoblib
 from .binding.matplotlib_bind import PatchedMatplotlib
 from .binding.hydra_bind import PatchHydra
@@ -565,6 +566,8 @@ class Task(_Task):
                     )
                 if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('pytorch', True):
                     PatchPyTorchModelIO.update_current_task(task)
+                if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('megengine', True):
+                    PatchMegEngineModelIO.update_current_task(task)
                 if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('xgboost', True):
                     PatchXGBoostModelIO.update_current_task(task)
                 if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('fastai', True):
diff --git a/examples/frameworks/megengine/megengine_mnist.py b/examples/frameworks/megengine/megengine_mnist.py
new file mode 100644
index 00000000..b1836d93
--- /dev/null
+++ b/examples/frameworks/megengine/megengine_mnist.py
@@ -0,0 +1,142 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+"""Train a small CNN on MNIST with MegEngine while ClearML logs the run."""
+
+import argparse
+import os
+from tempfile import gettempdir
+import numpy as np
+
+import megengine as mge
+import megengine.module as M
+import megengine.functional as F
+from megengine.optimizer import SGD
+from megengine.autodiff import GradManager
+
+from megengine.data import DataLoader, RandomSampler
+from megengine.data.transform import ToMode, Pad, Normalize, Compose
+from megengine.data.dataset import MNIST
+
+from tensorboardX import SummaryWriter
+from clearml import Task
+
+
+class Net(M.Module):
+    """Two conv/BN/ReLU/max-pool stages followed by two linear layers."""
+
+    def __init__(self):
+        super().__init__()
+        self.conv0 = M.Conv2d(1, 20, kernel_size=5, bias=False)
+        self.bn0 = M.BatchNorm2d(20)
+        self.relu0 = M.ReLU()
+        self.pool0 = M.MaxPool2d(2)
+        self.conv1 = M.Conv2d(20, 20, kernel_size=5, bias=False)
+        self.bn1 = M.BatchNorm2d(20)
+        self.relu1 = M.ReLU()
+        self.pool1 = M.MaxPool2d(2)
+        # 500 = 20 channels * 5 * 5, assuming 32x32 inputs (28x28 padded by 2)
+        self.fc0 = M.Linear(500, 64, bias=True)
+        self.relu2 = M.ReLU()
+        self.fc1 = M.Linear(64, 10, bias=True)
+
+    def forward(self, x):
+        """Return the 10 output scores for a batch of images ``x``."""
+        x = self.conv0(x)
+        x = self.bn0(x)
+        x = self.relu0(x)
+        x = self.pool0(x)
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.relu1(x)
+        x = self.pool1(x)
+        x = F.flatten(x, 1)
+        x = self.fc0(x)
+        x = self.relu2(x)
+        x = self.fc1(x)
+        return x
+
+
+def build_dataloader():
+    """Return a randomly sampled MNIST training loader (batch size 64)."""
+    train_dataset = MNIST(root=gettempdir(), train=True, download=True)
+    dataloader = DataLoader(
+        train_dataset,
+        transform=Compose([
+            Normalize(mean=0.1307 * 255, std=0.3081 * 255),
+            Pad(2),
+            ToMode('CHW'),
+        ]),
+        sampler=RandomSampler(dataset=train_dataset, batch_size=64),
+    )
+    return dataloader
+
+
+def train(dataloader, args):
+    """Train ``Net`` with SGD, logging the loss and saving checkpoints.
+
+    :param dataloader: iterable of ``(batch_data, batch_label)`` pairs.
+    :param args: parsed arguments providing ``lr``, ``momentum``, ``wd``
+        and ``epoch``.
+    """
+    net = Net()
+    net.train()
+    optimizer = SGD(
+        net.parameters(), lr=args.lr,
+        momentum=args.momentum, weight_decay=args.wd
+    )
+    gm = GradManager().attach(net.parameters())
+
+    epoch_length = len(dataloader)
+    writer = SummaryWriter("runs")
+    try:
+        for epoch in range(args.epoch):
+            for step, (batch_data, batch_label) in enumerate(dataloader):
+                batch_label = batch_label.astype(np.int32)
+                data, label = mge.tensor(batch_data), mge.tensor(batch_label)
+                with gm:
+                    pred = net(data)
+                    loss = F.loss.cross_entropy(pred, label)
+                    gm.backward(loss)
+                optimizer.step().clear_grad()
+
+                loss_value = float(loss)
+                if step % 50 == 0:
+                    print("epoch:{}, iter:{}, loss:{}".format(epoch + 1, step, loss_value))  # noqa
+                writer.add_scalar("loss", loss_value, epoch * epoch_length + step)
+            if (epoch + 1) % 5 == 0:
+                # checkpoint every 5th epoch; mge.save is what the ClearML binding hooks
+                mge.save(net.state_dict(), os.path.join(gettempdir(), f"mnist_net_e{epoch + 1}.pkl"))  # noqa
+    finally:
+        # close the writer even if training fails part way through
+        writer.close()
+
+
+def main():
+    """Create the ClearML task, parse the CLI options and run training."""
+    task = Task.init(project_name='megengine', task_name='mge mnist train')  # noqa
+
+    parser = argparse.ArgumentParser(description='MegEngine MNIST Example')
+    parser.add_argument(
+        '--epoch', type=int, default=10,
+        help='number of training epoch(default: 10)',
+    )
+    parser.add_argument(
+        '--lr', type=float, default=0.01,
+        help='learning rate(default: 0.01)'
+    )
+    parser.add_argument(
+        '--momentum', type=float, default=0.9,
+        help='SGD momentum (default: 0.9)',
+    )
+    parser.add_argument(
+        '--wd', type=float, default=5e-4,
+        help='SGD weight decay(default: 5e-4)',
+    )
+
+    args = parser.parse_args()
+    dataloader = build_dataloader()
+    train(dataloader, args)
+
+
+if __name__ == "__main__":
+    main()