Mirror of https://github.com/clearml/clearml (synced 2025-06-26 18:16:07 +00:00)
Add cifar ignite example and add auto extract of tar.gz files when using storagemanager (#237)
This commit is contained in:
parent 6dd7b4e02e
commit 95ba6bab78
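
Beyond the new example, this commit teaches StorageManager.get_local_copy() to auto-extract .tar.gz archives the same way it already handled .zip files (see the manager diff below). A minimal usage sketch, mirroring the call the example makes; the returned path is the extracted cache folder, not the archive itself:

    from trains import StorageManager

    # the first call downloads the archive into the local cache and extracts it;
    # subsequent calls return the already-extracted cached folder
    local_folder = StorageManager().get_local_copy(
        remote_url="https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz")
    print(local_folder)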
examples/frameworks/ignite/cifar_ignite.py (new file, 190 lines)

@@ -0,0 +1,190 @@
from pathlib import Path

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from ignite.contrib.handlers import TensorboardLogger
from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
from ignite.handlers import global_step_from_engine
from ignite.metrics import Accuracy, Loss, Recall
from ignite.utils import setup_logger
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from trains import Task, StorageManager

# Trains Initializations
task = Task.init(project_name='Image Example', task_name='image classification CIFAR10')
params = {'number_of_epochs': 20, 'batch_size': 64, 'dropout': 0.25, 'base_lr': 0.001, 'momentum': 0.9, 'loss_report': 100}
params = task.connect(params)  # enabling configuration override by trains
print(params)  # printing actual configuration (after override in remote mode)

manager = StorageManager()
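
# get_local_copy() downloads the remote file into the local cache and, as of
# this commit, auto-extracts .tar.gz archives (in addition to .zip), returning
# the path of the extracted folder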
dataset_path = Path(manager.get_local_copy(remote_url="https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"))

# Dataset and Dataloader initializations
transform = transforms.Compose([transforms.ToTensor()])

trainset = datasets.CIFAR10(root=dataset_path,
                            train=True,
                            download=False,
                            transform=transform)
trainloader = torch.utils.data.DataLoader(trainset,
                                          batch_size=params.get('batch_size', 4),
                                          shuffle=True,
                                          num_workers=10)

testset = datasets.CIFAR10(root=dataset_path,
                           train=False,
                           download=False,
                           transform=transform)
testloader = torch.utils.data.DataLoader(testset,
                                         batch_size=params.get('batch_size', 4),
                                         shuffle=False,
                                         num_workers=10)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

tb_logger = TensorboardLogger(log_dir="cifar-output")


# Helper function to store predictions and scores using matplotlib
def predictions_gt_images_handler(engine, logger, *args, **kwargs):
    x, _ = engine.state.batch
    y_pred, y = engine.state.output

    num_x = num_y = 4
    le = num_x * num_y
    fig = plt.figure(figsize=(20, 20))
    trans = transforms.ToPILImage()
    for idx in range(le):
        preds = torch.argmax(F.softmax(y_pred[idx], dim=0))
        probs = torch.max(F.softmax(y_pred[idx], dim=0))
        ax = fig.add_subplot(num_x, num_y, idx + 1, xticks=[], yticks=[])
        ax.imshow(trans(x[idx]))
        ax.set_title("{0} {1:.1f}% (label: {2})".format(
            classes[preds],
            probs * 100,
            classes[y[idx]]),
            color=("green" if preds == y[idx] else "red")
        )
    logger.writer.add_figure('predictions vs actuals', figure=fig, global_step=engine.state.epoch)


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 3)
        self.conv2 = nn.Conv2d(6, 16, 3)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(16 * 6 * 6, 120)
        self.fc2 = nn.Linear(120, 84)
        self.dropout = nn.Dropout(p=params.get('dropout', 0.25))
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 6 * 6)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(self.dropout(x))
        return x


# Training
def run(epochs, lr, momentum, log_interval):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    net = Net().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=lr, momentum=momentum)

    trainer = create_supervised_trainer(net, optimizer, criterion, device=device)
    trainer.logger = setup_logger("trainer")

    val_metrics = {"accuracy": Accuracy(), "loss": Loss(criterion), "recall": Recall()}
    evaluator = create_supervised_evaluator(net, metrics=val_metrics, device=device)
    evaluator.logger = setup_logger("evaluator")

    # Attach handler to plot trainer's loss every 100 iterations
    tb_logger.attach_output_handler(
        trainer,
        event_name=Events.ITERATION_COMPLETED(every=params.get('loss_report')),
        tag="training",
        output_transform=lambda loss: {"loss": loss},
    )

    # Attach handler to dump evaluator's metrics every epoch completed
    for tag, evaluator in [("training", trainer), ("validation", evaluator)]:
        tb_logger.attach_output_handler(
            evaluator,
            event_name=Events.EPOCH_COMPLETED,
            tag=tag,
            metric_names="all",
            global_step_transform=global_step_from_engine(trainer),
        )

    # Attach function to build debug images and report every epoch end
    tb_logger.attach(
        evaluator,
        log_handler=predictions_gt_images_handler,
        event_name=Events.EPOCH_COMPLETED(once=1),
    )

    desc = "ITERATION - loss: {:.2f}"
    pbar = tqdm(initial=0, leave=False, total=len(trainloader), desc=desc.format(0))

    @trainer.on(Events.ITERATION_COMPLETED(every=log_interval))
    def log_training_loss(engine):
        pbar.desc = desc.format(engine.state.output)
        pbar.update(log_interval)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(engine):
        pbar.refresh()
        evaluator.run(trainloader)
        metrics = evaluator.state.metrics
        avg_accuracy = metrics["accuracy"]
        avg_nll = metrics["loss"]
        tqdm.write(
            "Training Results - Epoch: {} Avg accuracy: {:.2f} Avg loss: {:.2f}".format(
                engine.state.epoch, avg_accuracy, avg_nll
            )
        )

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        evaluator.run(testloader)
        metrics = evaluator.state.metrics
        avg_accuracy = metrics["accuracy"]
        avg_nll = metrics["loss"]
        tqdm.write(
            "Validation Results - Epoch: {} Avg accuracy: {:.2f} Avg loss: {:.2f}".format(
                engine.state.epoch, avg_accuracy, avg_nll
            )
        )

        pbar.n = pbar.last_print_n = 0

    @trainer.on(Events.EPOCH_COMPLETED | Events.COMPLETED)
    def log_time():
        tqdm.write(
            "{} took {} seconds".format(trainer.last_event_name.name, trainer.state.times[trainer.last_event_name.name])
        )

    trainer.run(trainloader, max_epochs=epochs)
    pbar.close()

    PATH = './cifar_net.pth'
    torch.save(net.state_dict(), PATH)

    print('Finished Training')
    print('Task ID number is: {}'.format(task.id))


run(params.get('number_of_epochs'), params.get('base_lr'), params.get('momentum'), 10)
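
The example persists the trained weights to ./cifar_net.pth with torch.save(net.state_dict(), PATH). A possible follow-up, not part of this commit, for loading that checkpoint back for inference with the same Net class:

    net = Net()
    net.load_state_dict(torch.load('./cifar_net.pth'))
    net.eval()  # disables the dropout layer for inference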

trains/storage/manager.py (modified)

@@ -1,5 +1,5 @@
-import os
 import shutil
+import tarfile
 from random import random
 from time import time
 from typing import Optional
@@ -7,9 +7,9 @@ from zipfile import ZipFile
 
 from pathlib2 import Path
 
+from .cache import CacheManager
 from .util import encode_string_to_filename
 from ..debugging.log import LoggerRoot
-from .cache import CacheManager
 
 
 class StorageManager(object):
@@ -91,19 +91,27 @@ class StorageManager(object):
     @classmethod
     def _extract_to_cache(cls, cached_file, name):
         """
-        Extract cached file zip file to cache folder
+        Extract cached file to cache folder
         :param str cached_file: local copy of archive file
         :param str name: cache context
         :return: cached folder containing the extracted archive content
        """
-        # only zip files
-        if not cached_file or not str(cached_file).lower().endswith('.zip'):
+        if not cached_file:
             return cached_file
 
-        cached_folder = Path(cached_file).parent
-        archive_suffix = cached_file.rpartition(".")[0]
-        name = encode_string_to_filename(name)
-        target_folder = Path("{0}_artifacts_archive_{1}".format(archive_suffix, name))
+        cached_file = Path(cached_file)
+
+        # we support zip and tar.gz files auto-extraction
+        if (
+            not cached_file.suffix == ".zip"
+            and not cached_file.suffixes[-2:] == [".tar", ".gz"]
+        ):
+            return str(cached_file)
+
+        cached_folder = cached_file.parent
+
+        name = encode_string_to_filename(name) if name else name
+        target_folder = Path("{0}/{1}_artifacts_archive_{2}".format(cached_folder, cached_file.stem, name))
         if target_folder.exists():
             # noinspection PyBroadException
             try:
@@ -117,11 +125,16 @@ class StorageManager(object):
             temp_target_folder = cached_folder / "{0}_{1}_{2}".format(
                 target_folder.name, time() * 1000, str(random()).replace('.', ''))
             temp_target_folder.mkdir(parents=True, exist_ok=True)
-            ZipFile(cached_file).extractall(path=temp_target_folder.as_posix())
-            # we assume we will have such folder if we already extract the zip file
+            if cached_file.suffix == ".zip":
+                ZipFile(cached_file).extractall(path=temp_target_folder.as_posix())
+            elif cached_file.suffixes[-2:] == [".tar", ".gz"]:
+                with tarfile.open(cached_file) as file:
+                    file.extractall(temp_target_folder)
+
+            # we assume we will have such folder if we already extract the file
             # noinspection PyBroadException
             try:
-                # if rename fails, it means that someone else already managed to extract the zip, delete the current
+                # if rename fails, it means that someone else already managed to extract the file, delete the current
                 # folder and return the already existing cached zip folder
                 shutil.move(temp_target_folder.as_posix(), target_folder.as_posix())
             except Exception:
@@ -142,9 +155,9 @@ class StorageManager(object):
                 )
             )
         except Exception as ex:
-            # failed extracting zip file:
+            # failed extracting the file:
             base_logger.warning(
-                "Exception {}\nFailed extracting zip file {}".format(ex, cached_file)
+                "Exception {}\nFailed extracting zip file {}".format(ex, str(cached_file))
             )
             # noinspection PyBroadException
             try:
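
The new suffix checks in _extract_to_cache() rely on pathlib-style suffix handling (the module uses pathlib2, whose API matches the stdlib), which is worth seeing in isolation. A small illustrative snippet with hypothetical file names:

    from pathlib import Path

    print(Path("cifar-10-python.tar.gz").suffixes[-2:])  # ['.tar', '.gz']
    print(Path("dataset.zip").suffix)  # '.zip'
    print(Path("dataset.zip").stem)    # 'dataset' (used to name the extraction folder)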