mirror of
https://github.com/clearml/clearml
synced 2025-05-08 14:54:28 +00:00
Fix model names in examples are inconsistent
This commit is contained in:
parent
1ccdff5e77
commit
42320421a2
@ -7,14 +7,17 @@ make sure code doesn't crash, and then move to a stronger machine for the entire
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
from tempfile import gettempdir
|
from tempfile import gettempdir
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import torch.optim as optim
|
import torch.optim as optim
|
||||||
from torchvision import datasets, transforms
|
from torchvision import datasets, transforms
|
||||||
|
|
||||||
from clearml import Task, Logger
|
from clearml import Task, Logger
|
||||||
|
|
||||||
|
|
||||||
@ -51,7 +54,7 @@ def train(args, model, device, train_loader, optimizer, epoch):
|
|||||||
"train", "loss", iteration=(epoch * len(train_loader) + batch_idx), value=loss.item())
|
"train", "loss", iteration=(epoch * len(train_loader) + batch_idx), value=loss.item())
|
||||||
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
|
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
|
||||||
epoch, batch_idx * len(data), len(train_loader.dataset),
|
epoch, batch_idx * len(data), len(train_loader.dataset),
|
||||||
100. * batch_idx / len(train_loader), loss.item()))
|
100. * batch_idx / len(train_loader), loss.item()))
|
||||||
|
|
||||||
|
|
||||||
def test(args, model, device, test_loader, epoch):
|
def test(args, model, device, test_loader, epoch):
|
||||||
@ -127,8 +130,9 @@ def main():
|
|||||||
task.execute_remotely(queue_name="default")
|
task.execute_remotely(queue_name="default")
|
||||||
train(args, model, device, train_loader, optimizer, epoch)
|
train(args, model, device, train_loader, optimizer, epoch)
|
||||||
test(args, model, device, test_loader, epoch)
|
test(args, model, device, test_loader, epoch)
|
||||||
if (args.save_model):
|
|
||||||
torch.save(model.state_dict(), os.path.join(gettempdir(), "mnist_cnn.pt"))
|
if args.save_model:
|
||||||
|
torch.save(model.state_dict(), os.path.join(gettempdir(), "mnist_cnn_remote.pt"))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
@ -156,7 +156,7 @@ def main(_):
|
|||||||
test(FLAGS, model, device, test_loader, epoch)
|
test(FLAGS, model, device, test_loader, epoch)
|
||||||
|
|
||||||
if FLAGS.save_model:
|
if FLAGS.save_model:
|
||||||
torch.save(model.state_dict(), os.path.join(gettempdir(), "mnist_cnn.pt"))
|
torch.save(model.state_dict(), os.path.join(gettempdir(), "mnist_cnn_abseil.pt"))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
# ClearML - Example of Pytorch mnist training integration
|
# ClearML - Example of Pytorch mnist training integration
|
||||||
#
|
#
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
from tempfile import gettempdir
|
from tempfile import gettempdir
|
||||||
@ -47,7 +48,7 @@ def train(args, model, device, train_loader, optimizer, epoch):
|
|||||||
"train", "loss", iteration=(epoch * len(train_loader) + batch_idx), value=loss.item())
|
"train", "loss", iteration=(epoch * len(train_loader) + batch_idx), value=loss.item())
|
||||||
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
|
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
|
||||||
epoch, batch_idx * len(data), len(train_loader.dataset),
|
epoch, batch_idx * len(data), len(train_loader.dataset),
|
||||||
100. * batch_idx / len(train_loader), loss.item()))
|
100. * batch_idx / len(train_loader), loss.item()))
|
||||||
|
|
||||||
|
|
||||||
def test(args, model, device, test_loader, epoch):
|
def test(args, model, device, test_loader, epoch):
|
||||||
@ -128,7 +129,7 @@ def main():
|
|||||||
train(args, model, device, train_loader, optimizer, epoch)
|
train(args, model, device, train_loader, optimizer, epoch)
|
||||||
test(args, model, device, test_loader, epoch)
|
test(args, model, device, test_loader, epoch)
|
||||||
|
|
||||||
if (args.save_model):
|
if args.save_model:
|
||||||
torch.save(model.state_dict(), os.path.join(gettempdir(), "mnist_cnn.pt"))
|
torch.save(model.state_dict(), os.path.join(gettempdir(), "mnist_cnn.pt"))
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user