mirror of
https://github.com/clearml/clearml
synced 2025-06-08 23:57:16 +00:00
Fixed requirements, refactored and formatted code in some examples (#567)
This commit is contained in:
parent
172c3e44f1
commit
eb5350f551
@ -1,23 +1,28 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# -*- coding:utf-8 -*-
|
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
|
import sys
|
||||||
from tempfile import gettempdir
|
from tempfile import gettempdir
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
import megengine as mge
|
try:
|
||||||
import megengine.module as M
|
import megengine as mge
|
||||||
import megengine.functional as F
|
import megengine.functional as F
|
||||||
from megengine.optimizer import SGD
|
import megengine.module as M
|
||||||
from megengine.autodiff import GradManager
|
from megengine.autodiff import GradManager
|
||||||
|
from megengine.data import DataLoader, RandomSampler
|
||||||
|
from megengine.data.dataset import MNIST
|
||||||
|
from megengine.data.transform import Compose, Normalize, Pad, ToMode
|
||||||
|
from megengine.optimizer import SGD
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(
|
||||||
|
"megengine package is missing, you can install it using pip: pip install megengine"
|
||||||
|
if sys.version_info.minor <= 8
|
||||||
|
else "MegEngine does not support python version >= 3.9"
|
||||||
|
)
|
||||||
|
|
||||||
from megengine.data import DataLoader, RandomSampler
|
|
||||||
from megengine.data.transform import ToMode, Pad, Normalize, Compose
|
|
||||||
from megengine.data.dataset import MNIST
|
|
||||||
|
|
||||||
from tensorboardX import SummaryWriter
|
|
||||||
from clearml import Task
|
from clearml import Task
|
||||||
|
from tensorboardX import SummaryWriter
|
||||||
|
|
||||||
|
|
||||||
class Net(M.Module):
|
class Net(M.Module):
|
||||||
@ -55,11 +60,7 @@ def build_dataloader():
|
|||||||
train_dataset = MNIST(root=gettempdir(), train=True, download=True)
|
train_dataset = MNIST(root=gettempdir(), train=True, download=True)
|
||||||
dataloader = DataLoader(
|
dataloader = DataLoader(
|
||||||
train_dataset,
|
train_dataset,
|
||||||
transform=Compose([
|
transform=Compose([Normalize(mean=0.1307 * 255, std=0.3081 * 255), Pad(2), ToMode("CHW"),]),
|
||||||
Normalize(mean=0.1307*255, std=0.3081*255),
|
|
||||||
Pad(2),
|
|
||||||
ToMode('CHW'),
|
|
||||||
]),
|
|
||||||
sampler=RandomSampler(dataset=train_dataset, batch_size=64),
|
sampler=RandomSampler(dataset=train_dataset, batch_size=64),
|
||||||
)
|
)
|
||||||
return dataloader
|
return dataloader
|
||||||
@ -69,10 +70,7 @@ def train(dataloader, args):
|
|||||||
writer = SummaryWriter("runs")
|
writer = SummaryWriter("runs")
|
||||||
net = Net()
|
net = Net()
|
||||||
net.train()
|
net.train()
|
||||||
optimizer = SGD(
|
optimizer = SGD(net.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.wd)
|
||||||
net.parameters(), lr=args.lr,
|
|
||||||
momentum=args.momentum, weight_decay=args.wd
|
|
||||||
)
|
|
||||||
gm = GradManager().attach(net.parameters())
|
gm = GradManager().attach(net.parameters())
|
||||||
|
|
||||||
epoch_length = len(dataloader)
|
epoch_length = len(dataloader)
|
||||||
@ -90,28 +88,24 @@ def train(dataloader, args):
|
|||||||
print("epoch:{}, iter:{}, loss:{}".format(epoch + 1, step, float(loss))) # noqa
|
print("epoch:{}, iter:{}, loss:{}".format(epoch + 1, step, float(loss))) # noqa
|
||||||
writer.add_scalar("loss", float(loss), epoch * epoch_length + step)
|
writer.add_scalar("loss", float(loss), epoch * epoch_length + step)
|
||||||
if (epoch + 1) % 5 == 0:
|
if (epoch + 1) % 5 == 0:
|
||||||
mge.save(net.state_dict(), os.path.join(gettempdir(), f"mnist_net_e{epoch + 1}.pkl")) # noqa
|
mge.save(
|
||||||
|
net.state_dict(), os.path.join(gettempdir(), f"mnist_net_e{epoch + 1}.pkl"),
|
||||||
|
) # noqa
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
task = Task.init(project_name='examples', task_name='megengine mnist train') # noqa
|
task = Task.init(project_name="examples", task_name="megengine mnist train") # noqa
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description='MegEngine MNIST Example')
|
parser = argparse.ArgumentParser(description="MegEngine MNIST Example")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--epoch', type=int, default=10,
|
"--epoch", type=int, default=10, help="number of training epoch(default: 10)",
|
||||||
help='number of training epoch(default: 10)',
|
)
|
||||||
|
parser.add_argument("--lr", type=float, default=0.01, help="learning rate(default: 0.01)")
|
||||||
|
parser.add_argument(
|
||||||
|
"--momentum", type=float, default=0.9, help="SGD momentum (default: 0.9)",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--lr', type=float, default=0.01,
|
"--wd", type=float, default=5e-4, help="SGD weight decay(default: 5e-4)",
|
||||||
help='learning rate(default: 0.01)'
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
'--momentum', type=float, default=0.9,
|
|
||||||
help='SGD momentum (default: 0.9)',
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
'--wd', type=float, default=5e-4,
|
|
||||||
help='SGD weight decay(default: 5e-4)',
|
|
||||||
)
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
@ -1,3 +1,3 @@
|
|||||||
MegEngine
|
MegEngine ; python_version < '3.9'
|
||||||
tensorboardX
|
tensorboardX
|
||||||
clearml
|
clearml
|
@ -3,4 +3,5 @@ tensorboardX
|
|||||||
tensorboard>=1.14.0
|
tensorboard>=1.14.0
|
||||||
torch>=1.1.0
|
torch>=1.1.0
|
||||||
torchvision>=0.3.0
|
torchvision>=0.3.0
|
||||||
|
tqdm
|
||||||
clearml
|
clearml
|
Loading…
Reference in New Issue
Block a user