diff --git a/examples/frameworks/megengine/megengine_mnist.py b/examples/frameworks/megengine/megengine_mnist.py index 3f5f9d58..83255222 100644 --- a/examples/frameworks/megengine/megengine_mnist.py +++ b/examples/frameworks/megengine/megengine_mnist.py @@ -1,23 +1,28 @@ -#!/usr/bin/env python3 -# -*- coding:utf-8 -*- - import argparse import os +import sys from tempfile import gettempdir + import numpy as np -import megengine as mge -import megengine.module as M -import megengine.functional as F -from megengine.optimizer import SGD -from megengine.autodiff import GradManager +try: + import megengine as mge + import megengine.functional as F + import megengine.module as M + from megengine.autodiff import GradManager + from megengine.data import DataLoader, RandomSampler + from megengine.data.dataset import MNIST + from megengine.data.transform import Compose, Normalize, Pad, ToMode + from megengine.optimizer import SGD +except ImportError: + raise ImportError( + "megengine package is missing, you can install it using pip: pip install megengine" + if sys.version_info.minor <= 8 + else "MegEngine does not support python version >= 3.9" + ) -from megengine.data import DataLoader, RandomSampler -from megengine.data.transform import ToMode, Pad, Normalize, Compose -from megengine.data.dataset import MNIST - -from tensorboardX import SummaryWriter from clearml import Task +from tensorboardX import SummaryWriter class Net(M.Module): @@ -55,11 +60,7 @@ def build_dataloader(): train_dataset = MNIST(root=gettempdir(), train=True, download=True) dataloader = DataLoader( train_dataset, - transform=Compose([ - Normalize(mean=0.1307*255, std=0.3081*255), - Pad(2), - ToMode('CHW'), - ]), + transform=Compose([Normalize(mean=0.1307 * 255, std=0.3081 * 255), Pad(2), ToMode("CHW"),]), sampler=RandomSampler(dataset=train_dataset, batch_size=64), ) return dataloader @@ -69,10 +70,7 @@ def train(dataloader, args): writer = SummaryWriter("runs") net = Net() net.train() - optimizer = SGD( - net.parameters(), lr=args.lr, - momentum=args.momentum, weight_decay=args.wd - ) + optimizer = SGD(net.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.wd) gm = GradManager().attach(net.parameters()) epoch_length = len(dataloader) @@ -90,28 +88,24 @@ def train(dataloader, args): print("epoch:{}, iter:{}, loss:{}".format(epoch + 1, step, float(loss))) # noqa writer.add_scalar("loss", float(loss), epoch * epoch_length + step) if (epoch + 1) % 5 == 0: - mge.save(net.state_dict(), os.path.join(gettempdir(), f"mnist_net_e{epoch + 1}.pkl")) # noqa + mge.save( + net.state_dict(), os.path.join(gettempdir(), f"mnist_net_e{epoch + 1}.pkl"), + ) # noqa def main(): - task = Task.init(project_name='examples', task_name='megengine mnist train') # noqa + task = Task.init(project_name="examples", task_name="megengine mnist train") # noqa - parser = argparse.ArgumentParser(description='MegEngine MNIST Example') + parser = argparse.ArgumentParser(description="MegEngine MNIST Example") parser.add_argument( - '--epoch', type=int, default=10, - help='number of training epoch(default: 10)', + "--epoch", type=int, default=10, help="number of training epoch(default: 10)", + ) + parser.add_argument("--lr", type=float, default=0.01, help="learning rate(default: 0.01)") + parser.add_argument( + "--momentum", type=float, default=0.9, help="SGD momentum (default: 0.9)", ) parser.add_argument( - '--lr', type=float, default=0.01, - help='learning rate(default: 0.01)' - ) - parser.add_argument( - '--momentum', type=float, default=0.9, - help='SGD momentum (default: 0.9)', - ) - parser.add_argument( - '--wd', type=float, default=5e-4, - help='SGD weight decay(default: 5e-4)', + "--wd", type=float, default=5e-4, help="SGD weight decay(default: 5e-4)", ) args = parser.parse_args() diff --git a/examples/frameworks/megengine/requirements.txt b/examples/frameworks/megengine/requirements.txt index e3c916c0..75687027 100644 --- a/examples/frameworks/megengine/requirements.txt +++ b/examples/frameworks/megengine/requirements.txt @@ -1,3 +1,3 @@ -MegEngine +MegEngine ; python_version < '3.9' tensorboardX -clearml \ No newline at end of file +clearml diff --git a/examples/reporting/model_update_pytorch.py b/examples/frameworks/pytorch/pytorch_model_update.py similarity index 100% rename from examples/reporting/model_update_pytorch.py rename to examples/frameworks/pytorch/pytorch_model_update.py diff --git a/examples/frameworks/pytorch/requirements.txt b/examples/frameworks/pytorch/requirements.txt index d888329b..3a716022 100644 --- a/examples/frameworks/pytorch/requirements.txt +++ b/examples/frameworks/pytorch/requirements.txt @@ -3,4 +3,5 @@ tensorboardX tensorboard>=1.14.0 torch>=1.1.0 torchvision>=0.3.0 +tqdm clearml \ No newline at end of file