Add cifar ignite example and add auto extract of tar.gz files when using storagemanager (#237)

erezalg 2020-11-11 16:35:23 +02:00 committed by GitHub
parent 6dd7b4e02e
commit 95ba6bab78
2 changed files with 217 additions and 14 deletions
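In practice, the StorageManager change means a .tar.gz archive fetched with get_local_copy() is now cached and auto-extracted the same way .zip files already were. A minimal sketch of the intended usage, calling get_local_copy() the same way the example below does (the print is illustrative only):

from trains import StorageManager

# Fetch a remote tar.gz archive; it is cached locally and auto-extracted,
# and the returned path points at the folder holding the extracted content.
local_folder = StorageManager().get_local_copy(
    remote_url="https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz")
print(local_folder)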


@@ -0,0 +1,190 @@
from pathlib import Path
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from ignite.contrib.handlers import TensorboardLogger
from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
from ignite.handlers import global_step_from_engine
from ignite.metrics import Accuracy, Loss, Recall
from ignite.utils import setup_logger
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from trains import Task, StorageManager
# Trains Initializations
task = Task.init(project_name='Image Example', task_name='image classification CIFAR10')
params = {'number_of_epochs': 20, 'batch_size': 64, 'dropout': 0.25, 'base_lr': 0.001, 'momentum': 0.9, 'loss_report': 100}
params = task.connect(params) # enabling configuration override by trains
print(params) # printing actual configuration (after override in remote mode)
manager = StorageManager()
dataset_path = Path(manager.get_local_copy(remote_url="https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"))
# Dataset and Dataloader initializations
transform = transforms.Compose([transforms.ToTensor()])
trainset = datasets.CIFAR10(root=dataset_path,
                            train=True,
                            download=False,
                            transform=transform)
trainloader = torch.utils.data.DataLoader(trainset,
                                          batch_size=params.get('batch_size', 4),
                                          shuffle=True,
                                          num_workers=10)
testset = datasets.CIFAR10(root=dataset_path,
                           train=False,
                           download=False,
                           transform=transform)
testloader = torch.utils.data.DataLoader(testset,
                                         batch_size=params.get('batch_size', 4),
                                         shuffle=False,
                                         num_workers=10)
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
tb_logger = TensorboardLogger(log_dir="cifar-output")
# Helper function to store predictions and scores using matplotlib
def predictions_gt_images_handler(engine, logger, *args, **kwargs):
    x, _ = engine.state.batch
    y_pred, y = engine.state.output

    num_x = num_y = 4
    le = num_x * num_y
    fig = plt.figure(figsize=(20, 20))
    trans = transforms.ToPILImage()
    for idx in range(le):
        preds = torch.argmax(F.softmax(y_pred[idx], dim=0))
        probs = torch.max(F.softmax(y_pred[idx], dim=0))
        ax = fig.add_subplot(num_x, num_y, idx + 1, xticks=[], yticks=[])
        ax.imshow(trans(x[idx]))
        ax.set_title("{0} {1:.1f}% (label: {2})".format(
            classes[preds],
            probs * 100,
            classes[y[idx]]),
            color=("green" if preds == y[idx] else "red")
        )
    logger.writer.add_figure('predictions vs actuals', figure=fig, global_step=engine.state.epoch)
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 3)
        self.conv2 = nn.Conv2d(6, 16, 3)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(16 * 6 * 6, 120)
        self.fc2 = nn.Linear(120, 84)
        self.dropout = nn.Dropout(p=params.get('dropout', 0.25))
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 6 * 6)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(self.dropout(x))
        return x
# Training
def run(epochs, lr, momentum, log_interval):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    net = Net().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=lr, momentum=momentum)
    trainer = create_supervised_trainer(net, optimizer, criterion, device=device)
    trainer.logger = setup_logger("trainer")

    val_metrics = {"accuracy": Accuracy(), "loss": Loss(criterion), "recall": Recall()}
    evaluator = create_supervised_evaluator(net, metrics=val_metrics, device=device)
    evaluator.logger = setup_logger("evaluator")

    # Attach handler to plot trainer's loss every 100 iterations
    tb_logger.attach_output_handler(
        trainer,
        event_name=Events.ITERATION_COMPLETED(every=params.get('loss_report')),
        tag="training",
        output_transform=lambda loss: {"loss": loss},
    )

    # Attach handler to dump evaluator's metrics every epoch completed
    for tag, evaluator in [("training", trainer), ("validation", evaluator)]:
        tb_logger.attach_output_handler(
            evaluator,
            event_name=Events.EPOCH_COMPLETED,
            tag=tag,
            metric_names="all",
            global_step_transform=global_step_from_engine(trainer),
        )

    # Attach function to build debug images and report every epoch end
    tb_logger.attach(
        evaluator,
        log_handler=predictions_gt_images_handler,
        event_name=Events.EPOCH_COMPLETED(once=1),
    )

    desc = "ITERATION - loss: {:.2f}"
    pbar = tqdm(initial=0, leave=False, total=len(trainloader), desc=desc.format(0))

    @trainer.on(Events.ITERATION_COMPLETED(every=log_interval))
    def log_training_loss(engine):
        pbar.desc = desc.format(engine.state.output)
        pbar.update(log_interval)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(engine):
        pbar.refresh()
        evaluator.run(trainloader)
        metrics = evaluator.state.metrics
        avg_accuracy = metrics["accuracy"]
        avg_nll = metrics["loss"]
        tqdm.write(
            "Training Results - Epoch: {} Avg accuracy: {:.2f} Avg loss: {:.2f}".format(
                engine.state.epoch, avg_accuracy, avg_nll
            )
        )

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        evaluator.run(testloader)
        metrics = evaluator.state.metrics
        avg_accuracy = metrics["accuracy"]
        avg_nll = metrics["loss"]
        tqdm.write(
            "Validation Results - Epoch: {} Avg accuracy: {:.2f} Avg loss: {:.2f}".format(
                engine.state.epoch, avg_accuracy, avg_nll
            )
        )
        pbar.n = pbar.last_print_n = 0

    @trainer.on(Events.EPOCH_COMPLETED | Events.COMPLETED)
    def log_time():
        tqdm.write(
            "{} took {} seconds".format(trainer.last_event_name.name, trainer.state.times[trainer.last_event_name.name])
        )

    trainer.run(trainloader, max_epochs=epochs)
    pbar.close()

    PATH = './cifar_net.pth'
    torch.save(net.state_dict(), PATH)

    print('Finished Training')
    print('Task ID number is: {}'.format(task.id))


run(params.get('number_of_epochs'), params.get('base_lr'), params.get('momentum'), 10)
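The example above leans on ignite's event filters (Events.ITERATION_COMPLETED(every=...)) to throttle loss reporting. A standalone sketch of that mechanism, assuming only pytorch-ignite is installed (the engine and handler names here are made up for illustration):

from ignite.engine import Engine, Events

def process(engine, batch):
    return batch  # trivial step; a real process function would train on the batch

engine = Engine(process)

@engine.on(Events.ITERATION_COMPLETED(every=100))
def report(engine):
    # fires only on iterations 100, 200, 300, ...
    print("iteration", engine.state.iteration)

engine.run(range(300), max_epochs=1)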


@@ -1,5 +1,5 @@
-import os
 import shutil
+import tarfile
 from random import random
 from time import time
 from typing import Optional
@@ -7,9 +7,9 @@ from zipfile import ZipFile
 
 from pathlib2 import Path
 
+from .cache import CacheManager
 from .util import encode_string_to_filename
 from ..debugging.log import LoggerRoot
-from .cache import CacheManager
 
 
 class StorageManager(object):
@@ -91,19 +91,27 @@ class StorageManager(object):
     @classmethod
     def _extract_to_cache(cls, cached_file, name):
         """
-        Extract cached file zip file to cache folder
+        Extract cached file to cache folder
         :param str cached_file: local copy of archive file
         :param str name: cache context
         :return: cached folder containing the extracted archive content
         """
-        # only zip files
-        if not cached_file or not str(cached_file).lower().endswith('.zip'):
+        if not cached_file:
             return cached_file
 
-        cached_folder = Path(cached_file).parent
-        archive_suffix = cached_file.rpartition(".")[0]
-        name = encode_string_to_filename(name)
-        target_folder = Path("{0}_artifacts_archive_{1}".format(archive_suffix, name))
+        cached_file = Path(cached_file)
+
+        # we support zip and tar.gz files auto-extraction
+        if (
+            not cached_file.suffix == ".zip"
+            and not cached_file.suffixes[-2:] == [".tar", ".gz"]
+        ):
+            return str(cached_file)
+
+        cached_folder = cached_file.parent
+        name = encode_string_to_filename(name) if name else name
+        target_folder = Path("{0}/{1}_artifacts_archive_{2}".format(cached_folder, cached_file.stem, name))
+
         if target_folder.exists():
             # noinspection PyBroadException
             try:
@@ -117,11 +125,16 @@ class StorageManager(object):
             temp_target_folder = cached_folder / "{0}_{1}_{2}".format(
                 target_folder.name, time() * 1000, str(random()).replace('.', ''))
             temp_target_folder.mkdir(parents=True, exist_ok=True)
-            ZipFile(cached_file).extractall(path=temp_target_folder.as_posix())
-            # we assume we will have such folder if we already extract the zip file
+            if cached_file.suffix == ".zip":
+                ZipFile(cached_file).extractall(path=temp_target_folder.as_posix())
+            elif cached_file.suffixes[-2:] == [".tar", ".gz"]:
+                with tarfile.open(cached_file) as file:
+                    file.extractall(temp_target_folder)
+
+            # we assume we will have such folder if we already extract the file
             # noinspection PyBroadException
             try:
-                # if rename fails, it means that someone else already managed to extract the zip, delete the current
+                # if rename fails, it means that someone else already managed to extract the file, delete the current
                 # folder and return the already existing cached zip folder
                 shutil.move(temp_target_folder.as_posix(), target_folder.as_posix())
             except Exception:
@@ -142,9 +155,9 @@ class StorageManager(object):
                 )
             )
         except Exception as ex:
-            # failed extracting zip file:
+            # failed extracting the file:
             base_logger.warning(
-                "Exception {}\nFailed extracting zip file {}".format(ex, cached_file)
+                "Exception {}\nFailed extracting zip file {}".format(ex, str(cached_file))
             )
             # noinspection PyBroadException
             try:
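For reference, the suffix checks driving the new auto-extraction can be exercised in isolation. A quick sketch using the standard-library pathlib (the trains code itself uses pathlib2, which exposes the same suffix/suffixes interface; the file names are made up):

from pathlib import Path

for name in ("data.zip", "model.tar.gz", "notes.txt"):
    p = Path(name)
    # same condition as _extract_to_cache: extract .zip and .tar.gz archives only
    will_extract = p.suffix == ".zip" or p.suffixes[-2:] == [".tar", ".gz"]
    print(name, will_extract)  # data.zip True, model.tar.gz True, notes.txt False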