Support fastai v2 (#571)

This commit is contained in:
eugen-ajechiloae-clearml 2022-02-17 00:01:55 +02:00 committed by GitHub
parent 16b009f1ff
commit a8fbe51231
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 437 additions and 69 deletions

View File

@ -2,50 +2,77 @@ import sys
import numpy as np
from packaging import version
from . import _patched_call
from .tensorflow_bind import WeightsGradientHistHelper
from ..import_bind import PostImportHookPatching
from ...debugging.log import LoggerRoot
try:
import fastai
except ImportError:
fastai = None
class PatchFastai(object):
    """Entry point that routes fastai patching to the v1 or the v2 patcher."""

    __metrics_names = None  # TODO: STORE ON OBJECT OR IN LOOKUP BASED ON OBJECT ID
    __main_task = None

    @staticmethod
    def update_current_task(task, **_):
        """Bind ``task`` to the patcher that matches the imported fastai version.

        Nothing happens when fastai could not be imported. Versions below 2.0.0 are
        handled by ``PatchFastaiV1``, all others by ``PatchFastaiV2``. The chosen
        patcher is applied immediately and also registered with
        ``PostImportHookPatching`` for the "fastai" module.

        :param task: passed on to the chosen patcher's ``update_current_task``
        :param _: extra keyword arguments are accepted and ignored
        """
        if fastai is None:
            return

        # noinspection PyBroadException
        try:
            if version.parse(fastai.__version__) < version.parse("2.0.0"):
                patcher = PatchFastaiV1
            else:
                patcher = PatchFastaiV2
            patcher.update_current_task(task)
            patcher.patch_model_callback()
            PostImportHookPatching.add_on_import("fastai", patcher.patch_model_callback)
        except Exception as ex:
            # Patching stays best-effort (never break the caller), but leave a trace
            # instead of discarding the failure silently.
            LoggerRoot.get_base_logger(PatchFastai).debug(str(ex))
class PatchFastaiV1(object):
__metrics_names = {}
__gradient_hist_helpers = {}
_main_task = None
@staticmethod
def update_current_task(task, **_):
PatchFastai.__main_task = task
PatchFastai._patch_model_callback()
PostImportHookPatching.add_on_import("fastai", PatchFastai._patch_model_callback)
PatchFastaiV1._main_task = task
@staticmethod
def _patch_model_callback():
# if you have tensroboard, we assume you use TesnorboardLogger, which we catch, so no need to patch.
def patch_model_callback():
# if you have tensorboard, we assume you use TensorboardLogger, which we catch, so no need to patch.
if "tensorboard" in sys.modules:
return
if "fastai" in sys.modules:
try:
from fastai.basic_train import Recorder
try:
from fastai.basic_train import Recorder
Recorder.on_batch_end = _patched_call(Recorder.on_batch_end, PatchFastai._on_batch_end)
Recorder.on_backward_end = _patched_call(Recorder.on_backward_end, PatchFastai._on_backward_end)
Recorder.on_epoch_end = _patched_call(Recorder.on_epoch_end, PatchFastai._on_epoch_end)
Recorder.on_train_begin = _patched_call(Recorder.on_train_begin, PatchFastai._on_train_begin)
except ImportError:
pass
except Exception as ex:
LoggerRoot.get_base_logger(PatchFastai).debug(str(ex))
Recorder.on_batch_end = _patched_call(Recorder.on_batch_end, PatchFastaiV1._on_batch_end)
Recorder.on_backward_end = _patched_call(Recorder.on_backward_end, PatchFastaiV1._on_backward_end)
Recorder.on_epoch_end = _patched_call(Recorder.on_epoch_end, PatchFastaiV1._on_epoch_end)
Recorder.on_train_begin = _patched_call(Recorder.on_train_begin, PatchFastaiV1._on_train_begin)
except ImportError:
pass
except Exception as ex:
LoggerRoot.get_base_logger(PatchFastaiV1).debug(str(ex))
@staticmethod
def _on_train_begin(original_fn, recorder, *args, **kwargs):
original_fn(recorder, *args, **kwargs)
if not PatchFastai.__main_task:
if not PatchFastaiV1._main_task:
return
# noinspection PyBroadException
try:
PatchFastai.__metrics_names = ["train_loss"] if recorder.no_val else ["train_loss", "valid_loss"]
PatchFastai.__metrics_names += recorder.metrics_names
PatchFastaiV1.__metrics_names[id(recorder)] = (
["train_loss"] if recorder.no_val else ["train_loss", "valid_loss"]
)
PatchFastaiV1.__metrics_names[id(recorder)] += recorder.metrics_names
except Exception:
pass
@ -53,25 +80,26 @@ class PatchFastai(object):
def _on_backward_end(original_fn, recorder, *args, **kwargs):
def count_zeros(gradient):
n = gradient.data.data.cpu().numpy()
return n.size - n.count_nonzero()
return n.size - np.count_nonzero(n)
original_fn(recorder, *args, **kwargs)
if not PatchFastai.__main_task:
if not PatchFastaiV1._main_task:
return
# noinspection PyBroadException
try:
gradients = [
x.grad.clone().detach().cpu() for x in recorder.learn.model.parameters() if x.grad is not None
]
gradients = [x.grad.clone().detach().cpu() for x in recorder.learn.model.parameters() if x.grad is not None]
if len(gradients) == 0:
return
# TODO: Check computation!
gradient_stats = np.array([
(x.data.norm(), count_zeros(x), x.data.mean(), x.data.median(), x.data.max(), x.data.min())
for x in gradients])
gradient_stats = np.array(
[
(x.data.norm(), count_zeros(x), x.data.mean(), np.median(x.data), x.data.max(), x.data.min())
for x in gradients
]
)
stats_report = dict(
avg_norm=np.mean(gradient_stats[:, 0]),
median_norm=np.median(gradient_stats[:, 0]),
@ -79,12 +107,12 @@ class PatchFastai(object):
min_norm=np.min(gradient_stats[:, 0]),
num_zeros=gradient_stats[:, 1].sum(),
avg_gradient=gradient_stats[:, 2].mean(),
median_gradient=gradient_stats[:, 3].median(),
median_gradient=np.median(gradient_stats[:, 3]),
max_gradient=gradient_stats[:, 4].max(),
min_gradient=gradient_stats[:, 5].min(),
)
logger = PatchFastai.__main_task.get_logger()
logger = PatchFastaiV1._main_task.get_logger()
iteration = kwargs.get("iteration", 0)
for name, val in stats_report.items():
logger.report_scalar(title="model_stats_gradients", series=name, value=val, iteration=iteration)
@ -94,52 +122,193 @@ class PatchFastai(object):
@staticmethod
def _on_epoch_end(original_fn, recorder, *args, **kwargs):
original_fn(recorder, *args, **kwargs)
if not PatchFastai.__main_task:
if not PatchFastaiV1._main_task:
return
# noinspection PyBroadException
try:
logger = PatchFastai.__main_task.get_logger()
logger = PatchFastaiV1._main_task.get_logger()
iteration = kwargs.get("iteration")
for series, value in zip(
PatchFastai.__metrics_names,
[kwargs.get("smooth_loss")] + kwargs.get("last_metrics", []),
PatchFastaiV1.__metrics_names[id(recorder)],
[kwargs.get("smooth_loss")] + kwargs.get("last_metrics", []),
):
logger.report_scalar(title="metrics", series=series, value=value, iteration=iteration)
PatchFastai.__main_task.flush()
PatchFastaiV1._main_task.flush()
except Exception:
pass
@staticmethod
def _on_batch_end(original_fn, recorder, *args, **kwargs):
original_fn(recorder, *args, **kwargs)
if not PatchFastai.__main_task:
if not PatchFastaiV1._main_task:
return
# noinspection PyBroadException
try:
if kwargs.get("iteration") == 0 or not kwargs.get("train"):
iteration = kwargs.get("iteration", 0)
if iteration == 0 or not kwargs.get("train"):
return
logger = PatchFastai.__main_task.get_logger()
logger = PatchFastaiV1._main_task.get_logger()
logger.report_scalar(
title="metrics",
series="train_loss",
value=kwargs.get("last_loss", 0),
iteration=kwargs.get("iteration", 0)
iteration=iteration,
)
gradient_hist_helper = WeightsGradientHistHelper(logger)
iteration = kwargs.get("iteration")
params = [
(name, values.clone().detach().cpu())
for (name, values) in recorder.model.named_parameters()
]
params = [(name, values.clone().detach().cpu()) for (name, values) in recorder.model.named_parameters()]
if (
id(recorder) not in PatchFastaiV1.__gradient_hist_helpers
or PatchFastaiV1.__gradient_hist_helpers[id(recorder)].logger is not logger
):
PatchFastaiV1.__gradient_hist_helpers[id(recorder)] = WeightsGradientHistHelper(logger)
histograms = []
for (name, values) in params:
gradient_hist_helper.add_histogram(
title="model_weights",
series="model_weights/" + name,
step=iteration,
hist_data=values,
histograms.append(
dict(title="model_weights", series="model_weights/" + name, step=iteration, hist_data=values)
)
PatchFastaiV1.__gradient_hist_helpers[id(recorder)].add_histograms(histograms)
except Exception:
pass
class PatchFastaiV2(object):
    """Report fastai v2 training to the current task by adding a reporting callback on ``Learner.fit``."""

    # task whose logger receives every report; assigned by update_current_task
    _main_task = None

    @staticmethod
    def update_current_task(task, **_):
        """Store ``task`` as the task used for reporting.

        :param task: object whose ``get_logger()`` is used by the callbacks below
        :param _: extra keyword arguments are accepted and ignored
        """
        PatchFastaiV2._main_task = task

    @staticmethod
    def patch_model_callback():
        """Replace ``fastai.learner.Learner.fit`` with a ``_patched_call`` wrapper around ``_insert_callbacks``.

        Returns without patching when "tensorboard" is already in ``sys.modules``.
        Any exception raised while patching is swallowed.
        """
        # NOTE(review): presumably tensorboard reporting is captured elsewhere, so patching is skipped - confirm
        if "tensorboard" in sys.modules:
            return

        # noinspection PyBroadException
        try:
            fastai.learner.Learner.fit = _patched_call(fastai.learner.Learner.fit, PatchFastaiV2._insert_callbacks)
        except Exception:
            pass

    # Pick the callback base class at class-creation time: fastai's Recorder when it
    # can be imported, plain ``object`` otherwise so this class body still executes.
    try:
        from fastai.learner import Recorder

        __patch_fastai_callbacks_base = Recorder
    except ImportError:
        __patch_fastai_callbacks_base = object

    class PatchFastaiCallbacks(__patch_fastai_callbacks_base):
        """Callback that reports metrics, hyper-parameters, weight histograms and gradient statistics."""

        # class-level counter; each instance takes the current value as its id suffix
        __id = 0

        def __init__(self, *args, **kwargs):
            """Initialise the base class with ``train_metrics`` forced to True and set up reporting state."""
            kwargs["train_metrics"] = True
            super().__init__(*args, **kwargs)
            # number of training batches seen so far; used as the reported iteration
            self.__train_iter = 0

            def noop(*_, **__):
                pass

            # NOTE(review): presumably silences the base class's own logging output - confirm
            self.logger = noop
            # the id is kept as a string because it is concatenated into report titles
            self.__id = str(PatchFastaiV2.PatchFastaiCallbacks.__id)
            PatchFastaiV2.PatchFastaiCallbacks.__id += 1
            self.__gradient_hist_helper = WeightsGradientHistHelper(PatchFastaiV2._main_task.get_logger())

        def after_batch(self):
            """Report training metrics, optimizer hyper-parameters and weight histograms.

            Runs the base ``after_batch`` first and does nothing more when ``self.training``
            is false. Every exception, including one raised by the base call, is swallowed.
            """
            # noinspection PyBroadException
            try:
                super().after_batch()  # noqa
                logger = PatchFastaiV2._main_task.get_logger()
                if not self.training:  # noqa
                    return
                self.__train_iter += 1
                for metric in self._train_mets:  # noqa
                    logger.report_scalar(
                        title="metrics_" + self.__id,
                        series="train_" + metric.name,
                        value=metric.value,
                        iteration=self.__train_iter,
                    )
                # one scalar per entry of the last hyper-parameter dict of the optimizer
                for k, v in self.opt.hypers[-1].items():  # noqa
                    logger.report_scalar(title=k + "_" + self.__id, series=k, value=v, iteration=self.__train_iter)
                # detached CPU copies of all named parameters
                params = [
                    (name, values.clone().detach().cpu()) for (name, values) in self.model.named_parameters()
                ]  # noqa
                # rebuild the helper when the task's logger object is no longer the one it holds
                if self.__gradient_hist_helper.logger is not logger:
                    self.__gradient_hist_helper = WeightsGradientHistHelper(logger)
                histograms = []
                for (name, values) in params:
                    histograms.append(
                        dict(
                            title="model_weights_" + self.__id,
                            series="model_weights/" + name,
                            step=self.__train_iter,
                            hist_data=values,
                        )
                    )
                self.__gradient_hist_helper.add_histograms(histograms)
            except Exception:
                pass

        def after_epoch(self):
            """Run the base ``after_epoch`` and report every metric in ``self._valid_mets``.

            Values are reported at the current training-iteration counter. Every exception is swallowed.
            """
            # noinspection PyBroadException
            try:
                super().after_epoch()  # noqa
                logger = PatchFastaiV2._main_task.get_logger()
                for metric in self._valid_mets:  # noqa
                    logger.report_scalar(
                        title="metrics_" + self.__id,
                        series="valid_" + metric.name,
                        value=metric.value,
                        iteration=self.__train_iter,
                    )
            except Exception:
                pass

        def before_step(self):
            """Report summary statistics of the gradients currently held by the model parameters.

            Returns early when no parameter has a gradient. Every exception is swallowed.
            """
            # noinspection PyBroadException
            try:
                # the base hook is only called when fastai's Recorder actually defines it
                if hasattr(fastai.learner.Recorder, "before_step"):
                    super().before_step()  # noqa
                logger = PatchFastaiV2._main_task.get_logger()
                # detached CPU copies of every gradient that exists
                gradients = [
                    x.grad.clone().detach().cpu() for x in self.learn.model.parameters() if x.grad is not None
                ]  # noqa
                if len(gradients) == 0:
                    return

                def count_zeros(gradient):
                    # number of elements that are exactly zero
                    n = gradient.data.data.cpu().numpy()
                    return n.size - np.count_nonzero(n)

                # one row per gradient; columns: norm, zero count, mean, median, max, min
                gradient_stats = np.array(
                    [
                        (x.data.norm(), count_zeros(x), x.data.mean(), np.median(x.data), x.data.max(), x.data.min())
                        for x in gradients
                    ]
                )
                # TODO: Check computation!
                stats_report = dict(
                    avg_norm=np.mean(gradient_stats[:, 0]),
                    median_norm=np.median(gradient_stats[:, 0]),
                    max_norm=np.max(gradient_stats[:, 0]),
                    min_norm=np.min(gradient_stats[:, 0]),
                    num_zeros=gradient_stats[:, 1].sum(),
                    avg_gradient=gradient_stats[:, 2].mean(),
                    median_gradient=np.median(gradient_stats[:, 3]),
                    max_gradient=gradient_stats[:, 4].max(),
                    min_gradient=gradient_stats[:, 5].min(),
                )
                for name, val in stats_report.items():
                    # num_zeros is reported under its own title, separate from the other statistics
                    if name != "num_zeros":
                        title = "model_stats_gradients_" + self.__id
                    else:
                        title = "model_stats_gradients_num_zeros_" + self.__id
                    logger.report_scalar(title=title, series=name, value=val, iteration=self.__train_iter)
            except Exception:
                pass

    @staticmethod
    def _insert_callbacks(original_fn, obj, *args, **kwargs):
        """Add ``PatchFastaiCallbacks`` to ``obj`` via ``add_cb`` and then run the wrapped call.

        :param original_fn: the callable being wrapped
        :param obj: object the wrapped call is invoked on; must provide ``add_cb``
        :return: whatever ``original_fn`` returns
        """
        # NOTE(review): this runs on every wrapped call, so repeated fits presumably add repeated callbacks - confirm
        obj.add_cb(PatchFastaiV2.PatchFastaiCallbacks)
        return original_fn(obj, *args, **kwargs)

View File

@ -72,7 +72,7 @@ class IsTensorboardInit(object):
# noinspection PyProtectedMember
class WeightsGradientHistHelper(object):
def __init__(self, logger, report_freq=100, histogram_update_freq_multiplier=10, histogram_granularity=50):
self._logger = logger
self.logger = logger
self.report_freq = report_freq
self._histogram_granularity = histogram_granularity
self._histogram_update_freq_multiplier = histogram_update_freq_multiplier
@ -98,12 +98,24 @@ class WeightsGradientHistHelper(object):
_cur_idx = np.unique(np.sort(np.concatenate((cur_idx_below, cur_idx_above)).astype(np.int)))
return _cur_idx
def add_histogram(self, title, series, step, hist_data):
def add_histograms(self, histograms):
    """Pass a batch of histogram descriptions to ``add_histogram``, one call per entry.

    :param histograms: sized sequence of dicts; "title", "series", "step" and "hist_data"
        are read with ``get``, so a missing key is forwarded as ``None``
    """
    wanted_keys = ("title", "series", "step", "hist_data")
    for position, entry in enumerate(histograms):
        title, series, step, data = [entry.get(key) for key in wanted_keys]
        # only the final entry of the batch asks add_histogram to advance its call counter
        closes_batch = position == len(histograms) - 1
        self.add_histogram(title, series, step, data, increase_histogram_update_call_counter=closes_batch)
def add_histogram(self, title, series, step, hist_data, increase_histogram_update_call_counter=True):
# only collect histogram every specific interval
self._histogram_update_call_counter += 1
if self._histogram_update_call_counter % self.report_freq != 0 or \
self._histogram_update_call_counter < self.report_freq - 1:
return None
offset = 1 if increase_histogram_update_call_counter else 0
if (self._histogram_update_call_counter + offset) % self.report_freq != 0 or (
self._histogram_update_call_counter + offset
) < self.report_freq - 1:
self._histogram_update_call_counter += offset
return
if isinstance(hist_data, dict):
pass
@ -112,21 +124,29 @@ class WeightsGradientHistHelper(object):
# hist_data['bucketLimit'] is the histogram bucket right side limit, meaning X axis
# hist_data['bucket'] is the histogram height, meaning the Y axis
# notice hist_data[:, 1] is the right side limit, for backwards compatibility we take the left side
hist_data = {'bucketLimit': hist_data[:, 0].tolist(), 'bucket': hist_data[:, 2].tolist()}
hist_data = {"bucketLimit": hist_data[:, 0].tolist(), "bucket": hist_data[:, 2].tolist()}
else:
# assume we have to do the histogram on the data
hist_data = np.histogram(hist_data, bins=32)
hist_data = {'bucketLimit': hist_data[1].tolist(), 'bucket': hist_data[0].tolist()}
hist_data = {"bucketLimit": hist_data[1].tolist(), "bucket": hist_data[0].tolist()}
self._add_histogram(title=title, series=series, step=step, hist_data=hist_data)
self._add_histogram(
title=title,
series=series,
step=step,
hist_data=hist_data,
increase_histogram_update_call_counter=increase_histogram_update_call_counter,
)
def _add_histogram(self, title, series, step, hist_data):
def _add_histogram(self, title, series, step, hist_data, increase_histogram_update_call_counter=True):
# only collect histogram every specific interval
self._histogram_update_call_counter += 1
if self._histogram_update_call_counter % self.report_freq != 0 or \
self._histogram_update_call_counter < self.report_freq - 1:
return None
if increase_histogram_update_call_counter:
self._histogram_update_call_counter += 1
if (
self._histogram_update_call_counter % self.report_freq != 0
or self._histogram_update_call_counter < self.report_freq - 1
):
return
# generate forward matrix of the histograms
# Y-axis (rows) is iteration (from 0 to current Step)
# X-axis averaged bins (conformed sample 'bucketLimit')
@ -166,7 +186,7 @@ class WeightsGradientHistHelper(object):
# only report histogram every specific interval, but do report the first few, so you know there are histograms
if hist_iters.size < 1 or (hist_iters.size >= self._histogram_update_freq_multiplier and
hist_iters.size % self._histogram_update_freq_multiplier != 0):
return None
return
# resample histograms on a unified bin axis +- epsilon
_epsilon = abs((minmax[1] - minmax[0])/float(self._hist_x_granularity))
@ -193,7 +213,7 @@ class WeightsGradientHistHelper(object):
skipy = max(1, int(yedges.size / 10))
xlabels = ['%.2f' % v if i % skipx == 0 else '' for i, v in enumerate(xedges[:-1])]
ylabels = [str(int(v)) if i % skipy == 0 else '' for i, v in enumerate(yedges)]
self._logger.report_surface(
self.logger.report_surface(
title=title,
series=series,
iteration=0,

View File

@ -0,0 +1,33 @@
# ClearML - Fastai example code, automatically logging the model and scalars
#
import argparse
from clearml import Task
import fastai
try:
from fastai.vision import untar_data, URLs, ImageDataBunch, rand_pad, imagenet_stats, cnn_learner, models, accuracy
except ImportError:
raise ImportError("FastAI version %s imported, but this example is for FastAI v1." % fastai.__version__)
def main(epochs):
    """Train a ResNet-18 on the MNIST sample dataset with fastai v1 inside a ClearML task.

    :param epochs: number of epochs handed to ``fit_one_cycle``
    """
    # initialising the task is enough here; the returned object is not used
    Task.init(project_name="examples", task_name="fastai v1")

    dataset_dir = untar_data(URLs.MNIST_SAMPLE)
    databunch = ImageDataBunch.from_folder(
        dataset_dir,
        ds_tfms=(rand_pad(2, 28), []),
        bs=64,
        num_workers=0,
    )
    databunch.normalize(imagenet_stats)

    learner = cnn_learner(databunch, models.resnet18, metrics=accuracy)
    # accuracy before training; the value is computed but not kept
    accuracy(*learner.get_preds())
    learner.fit_one_cycle(epochs, 0.01)
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # argparse hands back strings unless a type is given; the training call needs an int,
    # so without type=int any value passed on the command line would break the run.
    parser.add_argument("--epochs", type=int, default=3, help="number of training epochs")
    args = parser.parse_args()
    main(args.epochs)

View File

@ -0,0 +1,2 @@
fastai < 2.0.0
clearml

View File

@ -0,0 +1,47 @@
# ClearML - Fastai with Tensorboard example code, automatically logging the model and Tensorboard outputs
#
import argparse
from clearml import Task
import fastai
try:
from fastai.vision import (
untar_data,
URLs,
ImageDataBunch,
rand_pad,
imagenet_stats,
cnn_learner,
models,
accuracy,
Path,
partial,
)
from fastai.callbacks.tensorboard import LearnerTensorboardWriter
except ImportError:
raise ImportError("FastAI version %s imported, but this example is for FastAI v1." % fastai.__version__)
def main(epochs):
    """Train a ResNet-18 on the MNIST sample dataset with fastai v1 and a Tensorboard writer callback.

    :param epochs: number of epochs handed to ``fit_one_cycle``
    """
    # initialising the task is enough here; the returned object is not used
    Task.init(project_name="examples", task_name="fastai with tensorboard callback")

    dataset_dir = untar_data(URLs.MNIST_SAMPLE)
    databunch = ImageDataBunch.from_folder(
        dataset_dir,
        ds_tfms=(rand_pad(2, 28), []),
        bs=64,
        num_workers=0,
    )
    databunch.normalize(imagenet_stats)
    learner = cnn_learner(databunch, models.resnet18, metrics=accuracy)

    # register a factory for the Tensorboard writer, pointed at a fixed output directory
    tensorboard_dir = Path("data/tensorboard/project1")
    writer_factory = partial(LearnerTensorboardWriter, base_dir=tensorboard_dir, name="run0")
    learner.callback_fns.append(writer_factory)

    # accuracy before training; the value is computed but not kept
    accuracy(*learner.get_preds())
    learner.fit_one_cycle(epochs, 0.01)
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # argparse hands back strings unless a type is given; the training call needs an int,
    # so without type=int any value passed on the command line would break the run.
    parser.add_argument("--epochs", type=int, default=3, help="number of training epochs")
    args = parser.parse_args()
    main(args.epochs)

View File

@ -1,4 +1,4 @@
fastai
fastai < 2.0.0
tensorboard
tensorboardX
clearml

View File

@ -0,0 +1,45 @@
# ClearML - Fastai v2 example code, automatically logging the model and various scalars
#
import argparse
from clearml import Task
import fastai
try:
from fastai.vision.all import (
untar_data,
URLs,
get_image_files,
ImageDataLoaders,
Resize,
cnn_learner,
resnet34,
error_rate,
)
except ImportError:
raise ImportError("FastAI version %s imported, but this example is for FastAI v2." % fastai.__version__)
def label_func(f):
    """Return whether the first character of ``f`` is uppercase.

    :param f: non-empty string (presumably an image file name - confirm against the caller)
    """
    leading_char = f[0]
    return leading_char.isupper()
def main(epochs):
    """Fine-tune a ResNet-34 on the PETS dataset with fastai v2 inside a ClearML task.

    :param epochs: number of epochs handed to ``fine_tune``
    """
    # initialising the task is enough here; the returned object is not used
    Task.init(project_name="examples", task_name="fastai v2")

    dataset_root = untar_data(URLs.PETS)
    image_files = get_image_files(dataset_root / "images")
    loaders = ImageDataLoaders.from_name_func(
        dataset_root,
        image_files,
        label_func,
        item_tfms=Resize(224),
        num_workers=0,
    )
    loaders.show_batch()

    learner = cnn_learner(loaders, resnet34, metrics=error_rate)
    learner.fine_tune(epochs)
    learner.show_results()
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # argparse hands back strings unless a type is given; the training call needs an int,
    # so without type=int any value passed on the command line would break the run.
    parser.add_argument("--epochs", type=int, default=3, help="number of training epochs")
    args = parser.parse_args()
    main(args.epochs)

View File

@ -0,0 +1,2 @@
fastai >= 2.0.0
clearml

View File

@ -0,0 +1,46 @@
# ClearML - Fastai v2 with tensorboard callbacks example code, automatically logging the model and various scalars
#
import argparse
from clearml import Task
import fastai
try:
from fastai.vision.all import (
untar_data,
URLs,
get_image_files,
ImageDataLoaders,
Resize,
cnn_learner,
resnet34,
error_rate,
)
from fastai.callback.tensorboard import TensorBoardCallback
except ImportError:
raise ImportError("FastAI version %s imported, but this example is for FastAI v2." % fastai.__version__)
def label_func(f):
    """Return whether the first character of ``f`` is uppercase.

    :param f: non-empty string (presumably an image file name - confirm against the caller)
    """
    leading_char = f[0]
    return leading_char.isupper()
def main(epochs):
    """Fine-tune a ResNet-34 on the PETS dataset with fastai v2 and a TensorBoard callback.

    :param epochs: number of epochs handed to ``fine_tune``
    """
    # initialising the task is enough here; the returned object is not used
    Task.init(project_name="examples", task_name="fastai v2 with tensorboard callback")

    dataset_root = untar_data(URLs.PETS)
    image_files = get_image_files(dataset_root / "images")
    loaders = ImageDataLoaders.from_name_func(
        dataset_root,
        image_files,
        label_func,
        item_tfms=Resize(224),
        num_workers=0,
    )
    loaders.show_batch()

    learner = cnn_learner(loaders, resnet34, metrics=error_rate)
    training_callbacks = [TensorBoardCallback()]
    learner.fine_tune(epochs, cbs=training_callbacks)
    learner.show_results()
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # argparse hands back strings unless a type is given; the training call needs an int,
    # so without type=int any value passed on the command line would break the run.
    parser.add_argument("--epochs", type=int, default=3, help="number of training epochs")
    args = parser.parse_args()
    main(args.epochs)

View File

@ -0,0 +1,4 @@
fastai >= 2.0.0
tensorboard
tensorboardX
clearml