mirror of
https://github.com/clearml/clearml
synced 2025-03-03 02:32:11 +00:00
Fixed requirements, refactored and formatted code in some examples (#567)
This commit is contained in:
parent
172c3e44f1
commit
eb5350f551
@ -1,23 +1,28 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
from tempfile import gettempdir
|
||||
|
||||
import numpy as np
|
||||
|
||||
import megengine as mge
|
||||
import megengine.module as M
|
||||
import megengine.functional as F
|
||||
from megengine.optimizer import SGD
|
||||
from megengine.autodiff import GradManager
|
||||
try:
|
||||
import megengine as mge
|
||||
import megengine.functional as F
|
||||
import megengine.module as M
|
||||
from megengine.autodiff import GradManager
|
||||
from megengine.data import DataLoader, RandomSampler
|
||||
from megengine.data.dataset import MNIST
|
||||
from megengine.data.transform import Compose, Normalize, Pad, ToMode
|
||||
from megengine.optimizer import SGD
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"megengine package is missing, you can install it using pip: pip install megengine"
|
||||
if sys.version_info.minor <= 8
|
||||
else "MegEngine does not support python version >= 3.9"
|
||||
)
|
||||
|
||||
from megengine.data import DataLoader, RandomSampler
|
||||
from megengine.data.transform import ToMode, Pad, Normalize, Compose
|
||||
from megengine.data.dataset import MNIST
|
||||
|
||||
from tensorboardX import SummaryWriter
|
||||
from clearml import Task
|
||||
from tensorboardX import SummaryWriter
|
||||
|
||||
|
||||
class Net(M.Module):
|
||||
@ -55,11 +60,7 @@ def build_dataloader():
|
||||
train_dataset = MNIST(root=gettempdir(), train=True, download=True)
|
||||
dataloader = DataLoader(
|
||||
train_dataset,
|
||||
transform=Compose([
|
||||
Normalize(mean=0.1307*255, std=0.3081*255),
|
||||
Pad(2),
|
||||
ToMode('CHW'),
|
||||
]),
|
||||
transform=Compose([Normalize(mean=0.1307 * 255, std=0.3081 * 255), Pad(2), ToMode("CHW"),]),
|
||||
sampler=RandomSampler(dataset=train_dataset, batch_size=64),
|
||||
)
|
||||
return dataloader
|
||||
@ -69,10 +70,7 @@ def train(dataloader, args):
|
||||
writer = SummaryWriter("runs")
|
||||
net = Net()
|
||||
net.train()
|
||||
optimizer = SGD(
|
||||
net.parameters(), lr=args.lr,
|
||||
momentum=args.momentum, weight_decay=args.wd
|
||||
)
|
||||
optimizer = SGD(net.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.wd)
|
||||
gm = GradManager().attach(net.parameters())
|
||||
|
||||
epoch_length = len(dataloader)
|
||||
@ -90,28 +88,24 @@ def train(dataloader, args):
|
||||
print("epoch:{}, iter:{}, loss:{}".format(epoch + 1, step, float(loss))) # noqa
|
||||
writer.add_scalar("loss", float(loss), epoch * epoch_length + step)
|
||||
if (epoch + 1) % 5 == 0:
|
||||
mge.save(net.state_dict(), os.path.join(gettempdir(), f"mnist_net_e{epoch + 1}.pkl")) # noqa
|
||||
mge.save(
|
||||
net.state_dict(), os.path.join(gettempdir(), f"mnist_net_e{epoch + 1}.pkl"),
|
||||
) # noqa
|
||||
|
||||
|
||||
def main():
|
||||
task = Task.init(project_name='examples', task_name='megengine mnist train') # noqa
|
||||
task = Task.init(project_name="examples", task_name="megengine mnist train") # noqa
|
||||
|
||||
parser = argparse.ArgumentParser(description='MegEngine MNIST Example')
|
||||
parser = argparse.ArgumentParser(description="MegEngine MNIST Example")
|
||||
parser.add_argument(
|
||||
'--epoch', type=int, default=10,
|
||||
help='number of training epoch(default: 10)',
|
||||
"--epoch", type=int, default=10, help="number of training epoch(default: 10)",
|
||||
)
|
||||
parser.add_argument("--lr", type=float, default=0.01, help="learning rate(default: 0.01)")
|
||||
parser.add_argument(
|
||||
"--momentum", type=float, default=0.9, help="SGD momentum (default: 0.9)",
|
||||
)
|
||||
parser.add_argument(
|
||||
'--lr', type=float, default=0.01,
|
||||
help='learning rate(default: 0.01)'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--momentum', type=float, default=0.9,
|
||||
help='SGD momentum (default: 0.9)',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--wd', type=float, default=5e-4,
|
||||
help='SGD weight decay(default: 5e-4)',
|
||||
"--wd", type=float, default=5e-4, help="SGD weight decay(default: 5e-4)",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
@ -1,3 +1,3 @@
|
||||
MegEngine
|
||||
MegEngine ; python_version < '3.9'
|
||||
tensorboardX
|
||||
clearml
|
||||
clearml
|
||||
|
@ -3,4 +3,5 @@ tensorboardX
|
||||
tensorboard>=1.14.0
|
||||
torch>=1.1.0
|
||||
torchvision>=0.3.0
|
||||
tqdm
|
||||
clearml
|
Loading…
Reference in New Issue
Block a user