Support fastai v2 (#571)

parent 16b009f1ff
commit a8fbe51231
@@ -2,50 +2,77 @@ import sys
 
 import numpy as np
+from packaging import version
+
 from . import _patched_call
 from .tensorflow_bind import WeightsGradientHistHelper
 from ..import_bind import PostImportHookPatching
 from ...debugging.log import LoggerRoot
 
+try:
+    import fastai
+except ImportError:
+    fastai = None
+
 
 class PatchFastai(object):
-    __metrics_names = None  # TODO: STORE ON OBJECT OR IN LOOKUP BASED ON OBJECT ID
-    __main_task = None
-
     @staticmethod
     def update_current_task(task, **_):
-        PatchFastai.__main_task = task
-        PatchFastai._patch_model_callback()
-        PostImportHookPatching.add_on_import("fastai", PatchFastai._patch_model_callback)
+        if fastai is None:
+            return
+
+        # noinspection PyBroadException
+        try:
+            if version.parse(fastai.__version__) < version.parse("2.0.0"):
+                PatchFastaiV1.update_current_task(task)
+                PatchFastaiV1.patch_model_callback()
+                PostImportHookPatching.add_on_import("fastai", PatchFastaiV1.patch_model_callback)
+            else:
+                PatchFastaiV2.update_current_task(task)
+                PatchFastaiV2.patch_model_callback()
+                PostImportHookPatching.add_on_import("fastai", PatchFastaiV2.patch_model_callback)
+        except Exception:
+            pass
+
+
+class PatchFastaiV1(object):
+    __metrics_names = {}
+    __gradient_hist_helpers = {}
+    _main_task = None
+
+    @staticmethod
+    def update_current_task(task, **_):
+        PatchFastaiV1._main_task = task
 
     @staticmethod
-    def _patch_model_callback():
-        # if you have tensroboard, we assume you use TesnorboardLogger, which we catch, so no need to patch.
+    def patch_model_callback():
+        # if you have tensorboard, we assume you use TensorboardLogger, which we catch, so no need to patch.
         if "tensorboard" in sys.modules:
             return
 
         if "fastai" in sys.modules:
             try:
                 from fastai.basic_train import Recorder
 
-                Recorder.on_batch_end = _patched_call(Recorder.on_batch_end, PatchFastai._on_batch_end)
-                Recorder.on_backward_end = _patched_call(Recorder.on_backward_end, PatchFastai._on_backward_end)
-                Recorder.on_epoch_end = _patched_call(Recorder.on_epoch_end, PatchFastai._on_epoch_end)
-                Recorder.on_train_begin = _patched_call(Recorder.on_train_begin, PatchFastai._on_train_begin)
+                Recorder.on_batch_end = _patched_call(Recorder.on_batch_end, PatchFastaiV1._on_batch_end)
+                Recorder.on_backward_end = _patched_call(Recorder.on_backward_end, PatchFastaiV1._on_backward_end)
+                Recorder.on_epoch_end = _patched_call(Recorder.on_epoch_end, PatchFastaiV1._on_epoch_end)
+                Recorder.on_train_begin = _patched_call(Recorder.on_train_begin, PatchFastaiV1._on_train_begin)
             except ImportError:
                 pass
             except Exception as ex:
-                LoggerRoot.get_base_logger(PatchFastai).debug(str(ex))
+                LoggerRoot.get_base_logger(PatchFastaiV1).debug(str(ex))
 
     @staticmethod
     def _on_train_begin(original_fn, recorder, *args, **kwargs):
         original_fn(recorder, *args, **kwargs)
-        if not PatchFastai.__main_task:
+        if not PatchFastaiV1._main_task:
             return
         # noinspection PyBroadException
         try:
-            PatchFastai.__metrics_names = ["train_loss"] if recorder.no_val else ["train_loss", "valid_loss"]
-            PatchFastai.__metrics_names += recorder.metrics_names
+            PatchFastaiV1.__metrics_names[id(recorder)] = (
+                ["train_loss"] if recorder.no_val else ["train_loss", "valid_loss"]
+            )
+            PatchFastaiV1.__metrics_names[id(recorder)] += recorder.metrics_names
        except Exception:
            pass
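The new dispatcher above picks the v1 or v2 patcher by parsing fastai.__version__, then reroutes the relevant callbacks through _patched_call. That helper's implementation is not part of this diff; judging by how the patched methods receive original_fn as their first argument, it follows the usual wrap-and-delegate pattern, roughly as in this illustrative sketch (names hypothetical, the real helper may differ):

def _patched_call_sketch(original_fn, patched_fn):
    def _inner(*args, **kwargs):
        # The patch receives the original function first, so it can log
        # around the call and still delegate to fastai's own callback.
        return patched_fn(original_fn, *args, **kwargs)
    return _inner

# Usage mirroring the hunk above:
# Recorder.on_batch_end = _patched_call_sketch(Recorder.on_batch_end, PatchFastaiV1._on_batch_end)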
@@ -53,25 +80,26 @@ class PatchFastai(object):
     def _on_backward_end(original_fn, recorder, *args, **kwargs):
         def count_zeros(gradient):
             n = gradient.data.data.cpu().numpy()
-            return n.size - n.count_nonzero()
+            return n.size - np.count_nonzero(n)
 
         original_fn(recorder, *args, **kwargs)
 
-        if not PatchFastai.__main_task:
+        if not PatchFastaiV1._main_task:
             return
 
         # noinspection PyBroadException
         try:
-            gradients = [
-                x.grad.clone().detach().cpu() for x in recorder.learn.model.parameters() if x.grad is not None
-            ]
+            gradients = [x.grad.clone().detach().cpu() for x in recorder.learn.model.parameters() if x.grad is not None]
             if len(gradients) == 0:
                 return
 
             # TODO: Check computation!
-            gradient_stats = np.array([
-                (x.data.norm(), count_zeros(x), x.data.mean(), x.data.median(), x.data.max(), x.data.min())
-                for x in gradients])
+            gradient_stats = np.array(
+                [
+                    (x.data.norm(), count_zeros(x), x.data.mean(), np.median(x.data), x.data.max(), x.data.min())
+                    for x in gradients
+                ]
+            )
             stats_report = dict(
                 avg_norm=np.mean(gradient_stats[:, 0]),
                 median_norm=np.median(gradient_stats[:, 0]),
@@ -79,12 +107,12 @@ class PatchFastai(object):
                 min_norm=np.min(gradient_stats[:, 0]),
                 num_zeros=gradient_stats[:, 1].sum(),
                 avg_gradient=gradient_stats[:, 2].mean(),
-                median_gradient=gradient_stats[:, 3].median(),
+                median_gradient=np.median(gradient_stats[:, 3]),
                 max_gradient=gradient_stats[:, 4].max(),
                 min_gradient=gradient_stats[:, 5].min(),
             )
 
-            logger = PatchFastai.__main_task.get_logger()
+            logger = PatchFastaiV1._main_task.get_logger()
             iteration = kwargs.get("iteration", 0)
             for name, val in stats_report.items():
                 logger.report_scalar(title="model_stats_gradients", series=name, value=val, iteration=iteration)
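Two genuine fixes ride along with the rename in these hunks: n.count_nonzero() becomes np.count_nonzero(n), and gradient_stats[:, 3].median() becomes np.median(gradient_stats[:, 3]). NumPy arrays expose mean, max, and min as methods but provide neither count_nonzero nor median that way, so both old calls raised AttributeError, which the surrounding broad except then swallowed silently. A standalone NumPy sketch of the statistics being assembled (toy data, illustrative names only):

import numpy as np

# Two fake per-parameter gradient tensors (stand-ins for x.grad values)
gradients = [np.array([0.0, 0.5, -0.5]), np.array([1.0, 0.0, 0.0, 0.0])]

# One row per tensor: (norm, zero-count, mean, median, max, min)
gradient_stats = np.array(
    [
        (np.linalg.norm(g), g.size - np.count_nonzero(g), g.mean(), np.median(g), g.max(), g.min())
        for g in gradients
    ]
)

print(np.median(gradient_stats[:, 0]))  # median_norm across parameters
print(gradient_stats[:, 1].sum())       # num_zeros, as in stats_report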
@@ -94,52 +122,193 @@ class PatchFastai(object):
     @staticmethod
     def _on_epoch_end(original_fn, recorder, *args, **kwargs):
         original_fn(recorder, *args, **kwargs)
-        if not PatchFastai.__main_task:
+        if not PatchFastaiV1._main_task:
             return
 
         # noinspection PyBroadException
         try:
-            logger = PatchFastai.__main_task.get_logger()
+            logger = PatchFastaiV1._main_task.get_logger()
             iteration = kwargs.get("iteration")
             for series, value in zip(
-                PatchFastai.__metrics_names,
-                [kwargs.get("smooth_loss")] + kwargs.get("last_metrics", []),
+                PatchFastaiV1.__metrics_names[id(recorder)],
+                [kwargs.get("smooth_loss")] + kwargs.get("last_metrics", []),
             ):
                 logger.report_scalar(title="metrics", series=series, value=value, iteration=iteration)
-            PatchFastai.__main_task.flush()
+            PatchFastaiV1._main_task.flush()
         except Exception:
             pass
 
     @staticmethod
     def _on_batch_end(original_fn, recorder, *args, **kwargs):
         original_fn(recorder, *args, **kwargs)
-        if not PatchFastai.__main_task:
+        if not PatchFastaiV1._main_task:
             return
 
         # noinspection PyBroadException
         try:
-            if kwargs.get("iteration") == 0 or not kwargs.get("train"):
+            iteration = kwargs.get("iteration", 0)
+            if iteration == 0 or not kwargs.get("train"):
                 return
 
-            logger = PatchFastai.__main_task.get_logger()
+            logger = PatchFastaiV1._main_task.get_logger()
             logger.report_scalar(
                 title="metrics",
                 series="train_loss",
                 value=kwargs.get("last_loss", 0),
-                iteration=kwargs.get("iteration", 0)
+                iteration=iteration,
             )
-            gradient_hist_helper = WeightsGradientHistHelper(logger)
-            iteration = kwargs.get("iteration")
-            params = [
-                (name, values.clone().detach().cpu())
-                for (name, values) in recorder.model.named_parameters()
-            ]
+            params = [(name, values.clone().detach().cpu()) for (name, values) in recorder.model.named_parameters()]
+            if (
+                id(recorder) not in PatchFastaiV1.__gradient_hist_helpers
+                or PatchFastaiV1.__gradient_hist_helpers[id(recorder)].logger is not logger
+            ):
+                PatchFastaiV1.__gradient_hist_helpers[id(recorder)] = WeightsGradientHistHelper(logger)
+            histograms = []
             for (name, values) in params:
-                gradient_hist_helper.add_histogram(
-                    title="model_weights",
-                    series="model_weights/" + name,
-                    step=iteration,
-                    hist_data=values,
+                histograms.append(
+                    dict(title="model_weights", series="model_weights/" + name, step=iteration, hist_data=values)
                 )
+            PatchFastaiV1.__gradient_hist_helpers[id(recorder)].add_histograms(histograms)
         except Exception:
             pass
+
+
+class PatchFastaiV2(object):
+    _main_task = None
+
+    @staticmethod
+    def update_current_task(task, **_):
+        PatchFastaiV2._main_task = task
+
+    @staticmethod
+    def patch_model_callback():
+        if "tensorboard" in sys.modules:
+            return
+
+        # noinspection PyBroadException
+        try:
+            fastai.learner.Learner.fit = _patched_call(fastai.learner.Learner.fit, PatchFastaiV2._insert_callbacks)
+        except Exception:
+            pass
+
+    try:
+        from fastai.learner import Recorder
+
+        __patch_fastai_callbacks_base = Recorder
+    except ImportError:
+        __patch_fastai_callbacks_base = object
+
+    class PatchFastaiCallbacks(__patch_fastai_callbacks_base):
+        __id = 0
+
+        def __init__(self, *args, **kwargs):
+            kwargs["train_metrics"] = True
+            super().__init__(*args, **kwargs)
+            self.__train_iter = 0
+
+            def noop(*_, **__):
+                pass
+
+            self.logger = noop
+            self.__id = str(PatchFastaiV2.PatchFastaiCallbacks.__id)
+            PatchFastaiV2.PatchFastaiCallbacks.__id += 1
+            self.__gradient_hist_helper = WeightsGradientHistHelper(PatchFastaiV2._main_task.get_logger())
+
+        def after_batch(self):
+            # noinspection PyBroadException
+            try:
+                super().after_batch()  # noqa
+                logger = PatchFastaiV2._main_task.get_logger()
+                if not self.training:  # noqa
+                    return
+                self.__train_iter += 1
+                for metric in self._train_mets:  # noqa
+                    logger.report_scalar(
+                        title="metrics_" + self.__id,
+                        series="train_" + metric.name,
+                        value=metric.value,
+                        iteration=self.__train_iter,
+                    )
+                for k, v in self.opt.hypers[-1].items():  # noqa
+                    logger.report_scalar(title=k + "_" + self.__id, series=k, value=v, iteration=self.__train_iter)
+                params = [
+                    (name, values.clone().detach().cpu()) for (name, values) in self.model.named_parameters()
+                ]  # noqa
+                if self.__gradient_hist_helper.logger is not logger:
+                    self.__gradient_hist_helper = WeightsGradientHistHelper(logger)
+                histograms = []
+                for (name, values) in params:
+                    histograms.append(
+                        dict(
+                            title="model_weights_" + self.__id,
+                            series="model_weights/" + name,
+                            step=self.__train_iter,
+                            hist_data=values,
+                        )
+                    )
+                self.__gradient_hist_helper.add_histograms(histograms)
+            except Exception:
+                pass
+
+        def after_epoch(self):
+            # noinspection PyBroadException
+            try:
+                super().after_epoch()  # noqa
+                logger = PatchFastaiV2._main_task.get_logger()
+                for metric in self._valid_mets:  # noqa
+                    logger.report_scalar(
+                        title="metrics_" + self.__id,
+                        series="valid_" + metric.name,
+                        value=metric.value,
+                        iteration=self.__train_iter,
+                    )
+            except Exception:
+                pass
+
+        def before_step(self):
+            # noinspection PyBroadException
+            try:
+                if hasattr(fastai.learner.Recorder, "before_step"):
+                    super().before_step()  # noqa
+                logger = PatchFastaiV2._main_task.get_logger()
+                gradients = [
+                    x.grad.clone().detach().cpu() for x in self.learn.model.parameters() if x.grad is not None
+                ]  # noqa
+                if len(gradients) == 0:
+                    return
+
+                def count_zeros(gradient):
+                    n = gradient.data.data.cpu().numpy()
+                    return n.size - np.count_nonzero(n)
+
+                gradient_stats = np.array(
+                    [
+                        (x.data.norm(), count_zeros(x), x.data.mean(), np.median(x.data), x.data.max(), x.data.min())
+                        for x in gradients
+                    ]
+                )
+                # TODO: Check computation!
+                stats_report = dict(
+                    avg_norm=np.mean(gradient_stats[:, 0]),
+                    median_norm=np.median(gradient_stats[:, 0]),
+                    max_norm=np.max(gradient_stats[:, 0]),
+                    min_norm=np.min(gradient_stats[:, 0]),
+                    num_zeros=gradient_stats[:, 1].sum(),
+                    avg_gradient=gradient_stats[:, 2].mean(),
+                    median_gradient=np.median(gradient_stats[:, 3]),
+                    max_gradient=gradient_stats[:, 4].max(),
+                    min_gradient=gradient_stats[:, 5].min(),
+                )
+                for name, val in stats_report.items():
+                    if name != "num_zeros":
+                        title = "model_stats_gradients_" + self.__id
+                    else:
+                        title = "model_stats_gradients_num_zeros_" + self.__id
+                    logger.report_scalar(title=title, series=name, value=val, iteration=self.__train_iter)
+            except Exception:
+                pass
+
+    @staticmethod
+    def _insert_callbacks(original_fn, obj, *args, **kwargs):
+        obj.add_cb(PatchFastaiV2.PatchFastaiCallbacks)
+        return original_fn(obj, *args, **kwargs)
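PatchFastaiV2 hooks in differently from v1: instead of patching Recorder methods, it wraps Learner.fit so that _insert_callbacks can add_cb the PatchFastaiCallbacks recorder on every fit call. The class-level try/except that chooses the base class keeps the module importable when fastai v2 is absent; the same trick in isolation (a sketch, independent of clearml):

try:
    from fastai.learner import Recorder  # resolves only when fastai v2 is installed
    _callback_base = Recorder
except ImportError:
    _callback_base = object  # fall back so the module still imports cleanly

class LoggingCallbacks(_callback_base):
    # Subclassing Recorder means fastai drives after_batch/after_epoch;
    # with the object fallback, the class is inert but importable.
    pass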
@@ -72,7 +72,7 @@ class IsTensorboardInit(object):
 # noinspection PyProtectedMember
 class WeightsGradientHistHelper(object):
     def __init__(self, logger, report_freq=100, histogram_update_freq_multiplier=10, histogram_granularity=50):
-        self._logger = logger
+        self.logger = logger
         self.report_freq = report_freq
         self._histogram_granularity = histogram_granularity
         self._histogram_update_freq_multiplier = histogram_update_freq_multiplier
@@ -98,12 +98,24 @@ class WeightsGradientHistHelper(object):
         _cur_idx = np.unique(np.sort(np.concatenate((cur_idx_below, cur_idx_above)).astype(np.int)))
         return _cur_idx
 
-    def add_histogram(self, title, series, step, hist_data):
+    def add_histograms(self, histograms):
+        for index, histogram in enumerate(histograms):
+            self.add_histogram(
+                histogram.get("title"),
+                histogram.get("series"),
+                histogram.get("step"),
+                histogram.get("hist_data"),
+                increase_histogram_update_call_counter=(index == len(histograms) - 1),
+            )
+
+    def add_histogram(self, title, series, step, hist_data, increase_histogram_update_call_counter=True):
         # only collect histogram every specific interval
-        self._histogram_update_call_counter += 1
-        if self._histogram_update_call_counter % self.report_freq != 0 or \
-                self._histogram_update_call_counter < self.report_freq - 1:
-            return None
+        offset = 1 if increase_histogram_update_call_counter else 0
+        if (self._histogram_update_call_counter + offset) % self.report_freq != 0 or (
+            self._histogram_update_call_counter + offset
+        ) < self.report_freq - 1:
+            self._histogram_update_call_counter += offset
+            return
 
         if isinstance(hist_data, dict):
             pass
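add_histograms batches a list of histogram dicts into a single reporting event: only the last element is allowed to advance _histogram_update_call_counter, so report_freq now counts training steps rather than individual parameters. A hypothetical call site, matching how fastai_bind uses it above (logger, params, and step are placeholders, not part of the diff):

def report_weight_histograms(logger, params, step):
    # params: iterable of (name, tensor-like) pairs; logger: a clearml Logger.
    helper = WeightsGradientHistHelper(logger, report_freq=100)
    # One dict per parameter; the helper advances its report-frequency
    # counter once for the whole batch, not once per histogram.
    helper.add_histograms(
        [
            dict(title="model_weights", series="model_weights/" + name, step=step, hist_data=values)
            for name, values in params
        ]
    )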
@@ -112,21 +124,29 @@ class WeightsGradientHistHelper(object):
             # hist_data['bucketLimit'] is the histogram bucket right side limit, meaning X axis
             # hist_data['bucket'] is the histogram height, meaning the Y axis
             # notice hist_data[:, 1] is the right side limit, for backwards compatibility we take the left side
-            hist_data = {'bucketLimit': hist_data[:, 0].tolist(), 'bucket': hist_data[:, 2].tolist()}
+            hist_data = {"bucketLimit": hist_data[:, 0].tolist(), "bucket": hist_data[:, 2].tolist()}
         else:
             # assume we have to do the histogram on the data
             hist_data = np.histogram(hist_data, bins=32)
-            hist_data = {'bucketLimit': hist_data[1].tolist(), 'bucket': hist_data[0].tolist()}
+            hist_data = {"bucketLimit": hist_data[1].tolist(), "bucket": hist_data[0].tolist()}
 
-        self._add_histogram(title=title, series=series, step=step, hist_data=hist_data)
+        self._add_histogram(
+            title=title,
+            series=series,
+            step=step,
+            hist_data=hist_data,
+            increase_histogram_update_call_counter=increase_histogram_update_call_counter,
+        )
 
-    def _add_histogram(self, title, series, step, hist_data):
+    def _add_histogram(self, title, series, step, hist_data, increase_histogram_update_call_counter=True):
         # only collect histogram every specific interval
-        self._histogram_update_call_counter += 1
-        if self._histogram_update_call_counter % self.report_freq != 0 or \
-                self._histogram_update_call_counter < self.report_freq - 1:
-            return None
+        if increase_histogram_update_call_counter:
+            self._histogram_update_call_counter += 1
+        if (
+            self._histogram_update_call_counter % self.report_freq != 0
+            or self._histogram_update_call_counter < self.report_freq - 1
+        ):
+            return
 
         # generate forward matrix of the histograms
         # Y-axis (rows) is iteration (from 0 to current Step)
         # X-axis averaged bins (conformed sample 'bucketLimit')
@@ -166,7 +186,7 @@ class WeightsGradientHistHelper(object):
         # only report histogram every specific interval, but do report the first few, so you know there are histograms
         if hist_iters.size < 1 or (hist_iters.size >= self._histogram_update_freq_multiplier and
                                    hist_iters.size % self._histogram_update_freq_multiplier != 0):
-            return None
+            return
 
         # resample histograms on a unified bin axis +- epsilon
         _epsilon = abs((minmax[1] - minmax[0])/float(self._hist_x_granularity))
@@ -193,7 +213,7 @@ class WeightsGradientHistHelper(object):
         skipy = max(1, int(yedges.size / 10))
         xlabels = ['%.2f' % v if i % skipx == 0 else '' for i, v in enumerate(xedges[:-1])]
         ylabels = [str(int(v)) if i % skipy == 0 else '' for i, v in enumerate(yedges)]
-        self._logger.report_surface(
+        self.logger.report_surface(
             title=title,
             series=series,
             iteration=0,
@@ -0,0 +1,33 @@
+# ClearML - Fastai example code, automatic logging of the model and scalars
+#
+import argparse
+
+from clearml import Task
+
+import fastai
+
+try:
+    from fastai.vision import untar_data, URLs, ImageDataBunch, rand_pad, imagenet_stats, cnn_learner, models, accuracy
+except ImportError:
+    raise ImportError("FastAI version %s imported, but this example is for FastAI v1." % fastai.__version__)
+
+
+def main(epochs):
+    Task.init(project_name="examples", task_name="fastai v1")
+
+    path = untar_data(URLs.MNIST_SAMPLE)
+
+    data = ImageDataBunch.from_folder(path, ds_tfms=(rand_pad(2, 28), []), bs=64, num_workers=0)
+    data.normalize(imagenet_stats)
+
+    learn = cnn_learner(data, models.resnet18, metrics=accuracy)
+
+    accuracy(*learn.get_preds())
+    learn.fit_one_cycle(epochs, 0.01)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--epochs", type=int, default=3)
+    args = parser.parse_args()
+    main(args.epochs)
@@ -0,0 +1,2 @@
+fastai < 2.0.0
+clearml
@@ -0,0 +1,47 @@
+# ClearML - Fastai with Tensorboard example code, automatic logging of the model and Tensorboard outputs
+#
+import argparse
+
+from clearml import Task
+
+import fastai
+
+try:
+    from fastai.vision import (
+        untar_data,
+        URLs,
+        ImageDataBunch,
+        rand_pad,
+        imagenet_stats,
+        cnn_learner,
+        models,
+        accuracy,
+        Path,
+        partial,
+    )
+    from fastai.callbacks.tensorboard import LearnerTensorboardWriter
+except ImportError:
+    raise ImportError("FastAI version %s imported, but this example is for FastAI v1." % fastai.__version__)
+
+
+def main(epochs):
+    Task.init(project_name="examples", task_name="fastai with tensorboard callback")
+
+    path = untar_data(URLs.MNIST_SAMPLE)
+
+    data = ImageDataBunch.from_folder(path, ds_tfms=(rand_pad(2, 28), []), bs=64, num_workers=0)
+    data.normalize(imagenet_stats)
+
+    learn = cnn_learner(data, models.resnet18, metrics=accuracy)
+    tboard_path = Path("data/tensorboard/project1")
+    learn.callback_fns.append(partial(LearnerTensorboardWriter, base_dir=tboard_path, name="run0"))
+
+    accuracy(*learn.get_preds())
+    learn.fit_one_cycle(epochs, 0.01)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--epochs", type=int, default=3)
+    args = parser.parse_args()
+    main(args.epochs)
@@ -1,4 +1,4 @@
-fastai
+fastai < 2.0.0
 tensorboard
 tensorboardX
 clearml
@@ -0,0 +1,45 @@
+# ClearML - Fastai v2 example code, automatic logging of the model and various scalars
+#
+import argparse
+
+from clearml import Task
+
+import fastai
+
+try:
+    from fastai.vision.all import (
+        untar_data,
+        URLs,
+        get_image_files,
+        ImageDataLoaders,
+        Resize,
+        cnn_learner,
+        resnet34,
+        error_rate,
+    )
+except ImportError:
+    raise ImportError("FastAI version %s imported, but this example is for FastAI v2." % fastai.__version__)
+
+
+def label_func(f):
+    return f[0].isupper()
+
+
+def main(epochs):
+    Task.init(project_name="examples", task_name="fastai v2")
+
+    path = untar_data(URLs.PETS)
+    files = get_image_files(path / "images")
+
+    dls = ImageDataLoaders.from_name_func(path, files, label_func, item_tfms=Resize(224), num_workers=0)
+    dls.show_batch()
+    learn = cnn_learner(dls, resnet34, metrics=error_rate)
+    learn.fine_tune(epochs)
+    learn.show_results()
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--epochs", type=int, default=3)
+    args = parser.parse_args()
+    main(args.epochs)
@@ -0,0 +1,2 @@
+fastai >= 2.0.0
+clearml
@@ -0,0 +1,46 @@
+# ClearML - Fastai v2 with tensorboard callbacks example code, automatic logging of the model and various scalars
+#
+import argparse
+
+from clearml import Task
+
+import fastai
+
+try:
+    from fastai.vision.all import (
+        untar_data,
+        URLs,
+        get_image_files,
+        ImageDataLoaders,
+        Resize,
+        cnn_learner,
+        resnet34,
+        error_rate,
+    )
+    from fastai.callback.tensorboard import TensorBoardCallback
+except ImportError:
+    raise ImportError("FastAI version %s imported, but this example is for FastAI v2." % fastai.__version__)
+
+
+def label_func(f):
+    return f[0].isupper()
+
+
+def main(epochs):
+    Task.init(project_name="examples", task_name="fastai v2 with tensorboard callback")
+
+    path = untar_data(URLs.PETS)
+    files = get_image_files(path / "images")
+
+    dls = ImageDataLoaders.from_name_func(path, files, label_func, item_tfms=Resize(224), num_workers=0)
+    dls.show_batch()
+    learn = cnn_learner(dls, resnet34, metrics=error_rate)
+    learn.fine_tune(epochs, cbs=[TensorBoardCallback()])
+    learn.show_results()
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--epochs", type=int, default=3)
+    args = parser.parse_args()
+    main(args.epochs)
@@ -0,0 +1,4 @@
+fastai >= 2.0.0
+tensorboard
+tensorboardX
+clearml