mirror of
https://github.com/clearml/clearml
synced 2025-05-08 22:59:24 +00:00
Fix model names in examples are inconsistent
This commit is contained in:
parent
1ccdff5e77
commit
42320421a2
@ -7,14 +7,17 @@ make sure code doesn't crash, and then move to a stronger machine for the entire
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
from tempfile import gettempdir
|
from tempfile import gettempdir
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import torch.optim as optim
|
import torch.optim as optim
|
||||||
from torchvision import datasets, transforms
|
from torchvision import datasets, transforms
|
||||||
|
|
||||||
from clearml import Task, Logger
|
from clearml import Task, Logger
|
||||||
|
|
||||||
|
|
||||||
@ -127,8 +130,9 @@ def main():
|
|||||||
task.execute_remotely(queue_name="default")
|
task.execute_remotely(queue_name="default")
|
||||||
train(args, model, device, train_loader, optimizer, epoch)
|
train(args, model, device, train_loader, optimizer, epoch)
|
||||||
test(args, model, device, test_loader, epoch)
|
test(args, model, device, test_loader, epoch)
|
||||||
if (args.save_model):
|
|
||||||
torch.save(model.state_dict(), os.path.join(gettempdir(), "mnist_cnn.pt"))
|
if args.save_model:
|
||||||
|
torch.save(model.state_dict(), os.path.join(gettempdir(), "mnist_cnn_remote.pt"))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
@ -156,7 +156,7 @@ def main(_):
|
|||||||
test(FLAGS, model, device, test_loader, epoch)
|
test(FLAGS, model, device, test_loader, epoch)
|
||||||
|
|
||||||
if FLAGS.save_model:
|
if FLAGS.save_model:
|
||||||
torch.save(model.state_dict(), os.path.join(gettempdir(), "mnist_cnn.pt"))
|
torch.save(model.state_dict(), os.path.join(gettempdir(), "mnist_cnn_abseil.pt"))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
# ClearML - Example of Pytorch mnist training integration
|
# ClearML - Example of Pytorch mnist training integration
|
||||||
#
|
#
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
from tempfile import gettempdir
|
from tempfile import gettempdir
|
||||||
@ -128,7 +129,7 @@ def main():
|
|||||||
train(args, model, device, train_loader, optimizer, epoch)
|
train(args, model, device, train_loader, optimizer, epoch)
|
||||||
test(args, model, device, test_loader, epoch)
|
test(args, model, device, test_loader, epoch)
|
||||||
|
|
||||||
if (args.save_model):
|
if args.save_model:
|
||||||
torch.save(model.state_dict(), os.path.join(gettempdir(), "mnist_cnn.pt"))
|
torch.save(model.state_dict(), os.path.join(gettempdir(), "mnist_cnn.pt"))
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user