Fix inconsistent model names in examples

This commit is contained in:
allegroai 2023-06-11 13:58:01 +03:00
parent 1ccdff5e77
commit 42320421a2
3 changed files with 11 additions and 6 deletions

View File

@ -7,14 +7,17 @@ make sure code doesn't crash, and then move to a stronger machine for the entire
""" """
from __future__ import print_function from __future__ import print_function
import argparse import argparse
import os import os
from tempfile import gettempdir from tempfile import gettempdir
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import torch.optim as optim import torch.optim as optim
from torchvision import datasets, transforms from torchvision import datasets, transforms
from clearml import Task, Logger from clearml import Task, Logger
@ -51,7 +54,7 @@ def train(args, model, device, train_loader, optimizer, epoch):
"train", "loss", iteration=(epoch * len(train_loader) + batch_idx), value=loss.item()) "train", "loss", iteration=(epoch * len(train_loader) + batch_idx), value=loss.item())
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(train_loader.dataset), epoch, batch_idx * len(data), len(train_loader.dataset),
100. * batch_idx / len(train_loader), loss.item())) 100. * batch_idx / len(train_loader), loss.item()))
def test(args, model, device, test_loader, epoch): def test(args, model, device, test_loader, epoch):
@ -127,8 +130,9 @@ def main():
task.execute_remotely(queue_name="default") task.execute_remotely(queue_name="default")
train(args, model, device, train_loader, optimizer, epoch) train(args, model, device, train_loader, optimizer, epoch)
test(args, model, device, test_loader, epoch) test(args, model, device, test_loader, epoch)
if (args.save_model):
torch.save(model.state_dict(), os.path.join(gettempdir(), "mnist_cnn.pt")) if args.save_model:
torch.save(model.state_dict(), os.path.join(gettempdir(), "mnist_cnn_remote.pt"))
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -156,7 +156,7 @@ def main(_):
test(FLAGS, model, device, test_loader, epoch) test(FLAGS, model, device, test_loader, epoch)
if FLAGS.save_model: if FLAGS.save_model:
torch.save(model.state_dict(), os.path.join(gettempdir(), "mnist_cnn.pt")) torch.save(model.state_dict(), os.path.join(gettempdir(), "mnist_cnn_abseil.pt"))
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -1,6 +1,7 @@
# ClearML - Example of Pytorch mnist training integration # ClearML - Example of Pytorch mnist training integration
# #
from __future__ import print_function from __future__ import print_function
import argparse import argparse
import os import os
from tempfile import gettempdir from tempfile import gettempdir
@ -47,7 +48,7 @@ def train(args, model, device, train_loader, optimizer, epoch):
"train", "loss", iteration=(epoch * len(train_loader) + batch_idx), value=loss.item()) "train", "loss", iteration=(epoch * len(train_loader) + batch_idx), value=loss.item())
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(train_loader.dataset), epoch, batch_idx * len(data), len(train_loader.dataset),
100. * batch_idx / len(train_loader), loss.item())) 100. * batch_idx / len(train_loader), loss.item()))
def test(args, model, device, test_loader, epoch): def test(args, model, device, test_loader, epoch):
@ -128,7 +129,7 @@ def main():
train(args, model, device, train_loader, optimizer, epoch) train(args, model, device, train_loader, optimizer, epoch)
test(args, model, device, test_loader, epoch) test(args, model, device, test_loader, epoch)
if (args.save_model): if args.save_model:
torch.save(model.state_dict(), os.path.join(gettempdir(), "mnist_cnn.pt")) torch.save(model.state_dict(), os.path.join(gettempdir(), "mnist_cnn.pt"))