Fix model names in examples are inconsistent

This commit is contained in:
allegroai 2023-06-11 13:58:01 +03:00
parent 1ccdff5e77
commit 42320421a2
3 changed files with 11 additions and 6 deletions

View File

@ -7,14 +7,17 @@ make sure code doesn't crash, and then move to a stronger machine for the entire
"""
from __future__ import print_function
import argparse
import os
from tempfile import gettempdir
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from clearml import Task, Logger
@ -127,8 +130,9 @@ def main():
task.execute_remotely(queue_name="default")
train(args, model, device, train_loader, optimizer, epoch)
test(args, model, device, test_loader, epoch)
if (args.save_model):
torch.save(model.state_dict(), os.path.join(gettempdir(), "mnist_cnn.pt"))
if args.save_model:
torch.save(model.state_dict(), os.path.join(gettempdir(), "mnist_cnn_remote.pt"))
if __name__ == '__main__':

View File

@ -156,7 +156,7 @@ def main(_):
test(FLAGS, model, device, test_loader, epoch)
if FLAGS.save_model:
torch.save(model.state_dict(), os.path.join(gettempdir(), "mnist_cnn.pt"))
torch.save(model.state_dict(), os.path.join(gettempdir(), "mnist_cnn_abseil.pt"))
if __name__ == "__main__":

View File

@ -1,6 +1,7 @@
# ClearML - Example of Pytorch mnist training integration
#
from __future__ import print_function
import argparse
import os
from tempfile import gettempdir
@ -128,7 +129,7 @@ def main():
train(args, model, device, train_loader, optimizer, epoch)
test(args, model, device, test_loader, epoch)
if (args.save_model):
if args.save_model:
torch.save(model.state_dict(), os.path.join(gettempdir(), "mnist_cnn.pt"))