clearml/trains/binding/frameworks/fastai_bind.py

import statistics
import sys

import numpy as np

from . import _patched_call
from .tensorflow_bind import WeightsGradientHistHelper
from ..import_bind import PostImportHookPatching
from ...debugging.log import LoggerRoot


class PatchFastai(object):
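    """Patch fastai (v1) so training runs are reported to the current task.

    Wraps ``fastai.basic_train.Recorder`` callbacks and forwards losses,
    metrics, gradient statistics and weight histograms to the task's logger.
    """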
    __metrics_names = None
    __main_task = None

    @staticmethod
    def update_current_task(task, **kwargs):
        PatchFastai.__main_task = task
        PatchFastai._patch_model_callback()
        PostImportHookPatching.add_on_import(
            "fastai", PatchFastai._patch_model_callback
        )
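
    # Wrap fastai's Recorder callbacks (on_train_begin, on_batch_end,
    # on_backward_end, on_epoch_end). Each hook below receives the original
    # callback, calls it first, then reports to the task. If fastai is not
    # importable nothing happens; other patching errors are logged at debug level.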
    @staticmethod
    def _patch_model_callback():
        if "fastai" in sys.modules:
            try:
                from fastai.basic_train import Recorder

                Recorder.on_batch_end = _patched_call(
                    Recorder.on_batch_end, PatchFastai._on_batch_end
                )
                Recorder.on_backward_end = _patched_call(
                    Recorder.on_backward_end, PatchFastai._on_backward_end
                )
                Recorder.on_epoch_end = _patched_call(
                    Recorder.on_epoch_end, PatchFastai._on_epoch_end
                )
                Recorder.on_train_begin = _patched_call(
                    Recorder.on_train_begin, PatchFastai._on_train_begin
                )
            except ImportError:
                pass
            except Exception as ex:
                LoggerRoot.get_base_logger(PatchFastai).debug(str(ex))
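
    # Training start: cache the scalar series names ("train_loss", optionally
    # "valid_loss", plus the user metrics) that _on_epoch_end will report.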
    @staticmethod
    def _on_train_begin(original_fn, recorder, *args, **kwargs):
        original_fn(recorder, *args, **kwargs)
        PatchFastai.__metrics_names = (
            ["train_loss"] if recorder.no_val else ["train_loss", "valid_loss"]
        )
        PatchFastai.__metrics_names += recorder.metrics_names
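
    # After each backward pass: collect all parameter gradients and report
    # summary statistics (avg/median/max/min norm, number of zero entries, and
    # avg/median/max/min gradient value) under the "model_stats_gradients" title.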
    @staticmethod
    def _on_backward_end(original_fn, recorder, *args, **kwargs):
        def report_model_stats(series, value):
            logger.report_scalar("model_stats_gradients", series, value, iteration)

        original_fn(recorder, *args, **kwargs)
        gradients = [
            x.grad.clone().detach().cpu()
            for x in recorder.learn.model.parameters()
            if x.grad is not None
        ]
        if len(gradients) == 0:
            return
        iteration = kwargs.get("iteration")
        norms = [x.data.norm() for x in gradients]
        logger = PatchFastai.__main_task.get_logger()
        for name, val in zip(
            [
                "avg_norm",
                "median_norm",
                "max_norm",
                "min_norm",
                "num_zeros",
                "avg_gradient",
                "median_gradient",
                "max_gradient",
                "min_gradient",
            ],
            [
                sum(norms) / len(gradients),
                statistics.median(norms),
                max(norms),
                min(norms),
                sum(
                    (np.asarray(x) == 0.0).sum()
                    for x in [x.data.data.cpu().numpy() for x in gradients]
                ),
                sum(x.data.mean() for x in gradients) / len(gradients),
                statistics.median(x.data.median() for x in gradients),
                max(x.data.max() for x in gradients),
                min(x.data.min() for x in gradients),
            ],
        ):
            report_model_stats(name, val)
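
    # End of every epoch: report the smoothed training loss and the epoch's
    # validation loss/metrics under the "metrics" title, then flush the task.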
    @staticmethod
    def _on_epoch_end(original_fn, recorder, *args, **kwargs):
        original_fn(recorder, *args, **kwargs)
        logger = PatchFastai.__main_task.get_logger()
        iteration = kwargs.get("iteration")
        for series, value in zip(
            PatchFastai.__metrics_names,
            [kwargs.get("smooth_loss")] + kwargs.get("last_metrics", []),
        ):
            logger.report_scalar("metrics", series, value, iteration)
        PatchFastai.__main_task.flush()
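
    # After every batch: report the raw training loss for this iteration and a
    # histogram of every model weight tensor. Skipped for validation batches and
    # for iteration 0.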
    @staticmethod
    def _on_batch_end(original_fn, recorder, *args, **kwargs):
        original_fn(recorder, *args, **kwargs)
        if kwargs.get("iteration") == 0 or not kwargs.get("train"):
            return
        logger = PatchFastai.__main_task.get_logger()
        logger.report_scalar(
            "metrics", "train_loss", kwargs.get("last_loss"), kwargs.get("iteration")
        )
        gradient_hist_helper = WeightsGradientHistHelper(logger)
        iteration = kwargs.get("iteration")
        params = [
            (name, values.clone().detach().cpu())
            for (name, values) in recorder.model.named_parameters()
        ]
        for (name, values) in params:
            gradient_hist_helper.add_histogram(
                title="model_weights",
                series="model_weights/" + name,
                step=iteration,
                hist_data=values,
            )
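

# Usage sketch (not part of the original module; assumes the standard
# Trains/ClearML auto-binding installs this patch when a task is initialized).
# A regular fastai v1 training loop is then enough for reporting to kick in:
#
#     from trains import Task
#     task = Task.init(project_name="examples", task_name="fastai example")
#     # ... build a fastai v1 Learner and train as usual, e.g. learn.fit(3);
#     # losses, metrics, gradient stats and weight histograms stream to the task.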