clearml/trains/binding/frameworks/fastai_bind.py
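
"""ClearML (trains) binding that monkey-patches the fastai v1 ``Recorder``
callbacks so training runs report scalars, gradient statistics, and weight
histograms to the current main ``Task``."""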

import sys

import numpy as np

from . import _patched_call
from .tensorflow_bind import WeightsGradientHistHelper
from ..import_bind import PostImportHookPatching
from ...debugging.log import LoggerRoot

class PatchFastai(object):
    __metrics_names = None  # TODO: STORE ON OBJECT OR IN LOOKUP BASED ON OBJECT ID
    __main_task = None

    @staticmethod
    def update_current_task(task, **_):
        PatchFastai.__main_task = task
        PatchFastai._patch_model_callback()
        PostImportHookPatching.add_on_import("fastai", PatchFastai._patch_model_callback)
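
    # fastai v1 exposes its training hooks on fastai.basic_train.Recorder;
    # each hook is wrapped with _patched_call so the original callback runs
    # first and the reporting code runs on top of it. On fastai versions
    # without basic_train the ImportError is swallowed and patching is skipped.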
    @staticmethod
    def _patch_model_callback():
        if "fastai" in sys.modules:
            try:
                from fastai.basic_train import Recorder

                Recorder.on_batch_end = _patched_call(Recorder.on_batch_end, PatchFastai._on_batch_end)
                Recorder.on_backward_end = _patched_call(Recorder.on_backward_end, PatchFastai._on_backward_end)
                Recorder.on_epoch_end = _patched_call(Recorder.on_epoch_end, PatchFastai._on_epoch_end)
                Recorder.on_train_begin = _patched_call(Recorder.on_train_begin, PatchFastai._on_train_begin)
            except ImportError:
                pass
            except Exception as ex:
                LoggerRoot.get_base_logger(PatchFastai).debug(str(ex))
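
    # Wrapped hooks receive the original (unpatched) function as their first
    # argument, then the Recorder instance and the original call arguments;
    # every hook invokes the original before doing any reporting.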
    @staticmethod
    def _on_train_begin(original_fn, recorder, *args, **kwargs):
        original_fn(recorder, *args, **kwargs)
        if not PatchFastai.__main_task:
            return
        # noinspection PyBroadException
        try:
            PatchFastai.__metrics_names = ["train_loss"] if recorder.no_val else ["train_loss", "valid_loss"]
            PatchFastai.__metrics_names += recorder.metrics_names
        except Exception:
            pass
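
    # After the backward pass: collect per-parameter gradient statistics
    # (norm, zero count, mean, median, max, min), aggregate them across all
    # parameters, and report each aggregate as a scalar series.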
    @staticmethod
    def _on_backward_end(original_fn, recorder, *args, **kwargs):
        def count_zeros(gradient):
            n = gradient.data.cpu().numpy()
            # ndarray has no count_nonzero() method, use the numpy function
            return n.size - np.count_nonzero(n)

        original_fn(recorder, *args, **kwargs)
        if not PatchFastai.__main_task:
            return
        # noinspection PyBroadException
        try:
            gradients = [
                x.grad.clone().detach().cpu() for x in recorder.learn.model.parameters() if x.grad is not None
            ]
            if len(gradients) == 0:
                return
            # TODO: Check computation!
            gradient_stats = np.array([
                (x.data.norm(), count_zeros(x), x.data.mean(), x.data.median(), x.data.max(), x.data.min())
                for x in gradients])
            stats_report = dict(
                avg_norm=np.mean(gradient_stats[:, 0]),
                median_norm=np.median(gradient_stats[:, 0]),
                max_norm=np.max(gradient_stats[:, 0]),
                min_norm=np.min(gradient_stats[:, 0]),
                num_zeros=gradient_stats[:, 1].sum(),
                avg_gradient=gradient_stats[:, 2].mean(),
                # ndarray has no median() method either, use np.median
                median_gradient=np.median(gradient_stats[:, 3]),
                max_gradient=gradient_stats[:, 4].max(),
                min_gradient=gradient_stats[:, 5].min(),
            )
            logger = PatchFastai.__main_task.get_logger()
            iteration = kwargs.get("iteration", 0)
            for name, val in stats_report.items():
                logger.report_scalar(title="model_stats_gradients", series=name, value=val, iteration=iteration)
        except Exception:
            pass
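
    # At epoch end: pair the series names captured in _on_train_begin with
    # the smoothed training loss and last validation metrics, then flush so
    # the epoch's scalars reach the server promptly.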
    @staticmethod
    def _on_epoch_end(original_fn, recorder, *args, **kwargs):
        original_fn(recorder, *args, **kwargs)
        if not PatchFastai.__main_task:
            return
        # noinspection PyBroadException
        try:
            logger = PatchFastai.__main_task.get_logger()
            iteration = kwargs.get("iteration")
            for series, value in zip(
                PatchFastai.__metrics_names,
                [kwargs.get("smooth_loss")] + kwargs.get("last_metrics", []),
            ):
                logger.report_scalar(title="metrics", series=series, value=value, iteration=iteration)
            PatchFastai.__main_task.flush()
        except Exception:
            pass
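
    # Per training batch: report the last loss as "metrics/train_loss" and
    # log a histogram of every model weight tensor; validation batches
    # (train=False) and iteration 0 are skipped.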
    @staticmethod
    def _on_batch_end(original_fn, recorder, *args, **kwargs):
        original_fn(recorder, *args, **kwargs)
        if not PatchFastai.__main_task:
            return
        # noinspection PyBroadException
        try:
            if kwargs.get("iteration") == 0 or not kwargs.get("train"):
                return
            logger = PatchFastai.__main_task.get_logger()
            logger.report_scalar(
                title="metrics",
                series="train_loss",
                value=kwargs.get("last_loss", 0),
                iteration=kwargs.get("iteration", 0),
            )
            gradient_hist_helper = WeightsGradientHistHelper(logger)
            iteration = kwargs.get("iteration")
            params = [
                (name, values.clone().detach().cpu())
                for (name, values) in recorder.model.named_parameters()
            ]
            for (name, values) in params:
                gradient_hist_helper.add_histogram(
                    title="model_weights",
                    series="model_weights/" + name,
                    step=iteration,
                    hist_data=values,
                )
        except Exception:
            pass
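
# Usage sketch (illustrative only, not part of this module): the binding is
# activated automatically when a main Task is created, roughly:
#
#   from clearml import Task
#   task = Task.init(project_name="examples", task_name="fastai v1 run")
#   # build and fit a fastai v1 Learner as usual; the patched Recorder hooks
#   # now stream scalars, gradient stats, and weight histograms to the task.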