Refactor fastai bind

allegroai 2020-07-30 14:54:37 +03:00
parent 00ccadf591
commit 093477cb35

@@ -1,4 +1,3 @@
-import statistics
 import sys
 import numpy as np
@@ -10,16 +9,14 @@ from ...debugging.log import LoggerRoot
 class PatchFastai(object):
-    __metrics_names = None
+    __metrics_names = None  # TODO: STORE ON OBJECT OR IN LOOKUP BASED ON OBJECT ID
     __main_task = None
 
     @staticmethod
-    def update_current_task(task, **kwargs):
+    def update_current_task(task, **_):
         PatchFastai.__main_task = task
         PatchFastai._patch_model_callback()
-        PostImportHookPatching.add_on_import(
-            "fastai", PatchFastai._patch_model_callback
-        )
+        PostImportHookPatching.add_on_import("fastai", PatchFastai._patch_model_callback)
 
     @staticmethod
     def _patch_model_callback():
@@ -27,19 +24,10 @@ class PatchFastai(object):
         try:
             from fastai.basic_train import Recorder
 
-            Recorder.on_batch_end = _patched_call(
-                Recorder.on_batch_end, PatchFastai._on_batch_end
-            )
-            Recorder.on_backward_end = _patched_call(
-                Recorder.on_backward_end, PatchFastai._on_backward_end
-            )
-            Recorder.on_epoch_end = _patched_call(
-                Recorder.on_epoch_end, PatchFastai._on_epoch_end
-            )
-            Recorder.on_train_begin = _patched_call(
-                Recorder.on_train_begin, PatchFastai._on_train_begin
-            )
+            Recorder.on_batch_end = _patched_call(Recorder.on_batch_end, PatchFastai._on_batch_end)
+            Recorder.on_backward_end = _patched_call(Recorder.on_backward_end, PatchFastai._on_backward_end)
+            Recorder.on_epoch_end = _patched_call(Recorder.on_epoch_end, PatchFastai._on_epoch_end)
+            Recorder.on_train_begin = _patched_call(Recorder.on_train_begin, PatchFastai._on_train_begin)
         except ImportError:
             pass
         except Exception as ex:
@@ -48,87 +36,106 @@ class PatchFastai(object):
     @staticmethod
     def _on_train_begin(original_fn, recorder, *args, **kwargs):
         original_fn(recorder, *args, **kwargs)
-        PatchFastai.__metrics_names = (
-            ["train_loss"] if recorder.no_val else ["train_loss", "valid_loss"]
-        )
-        PatchFastai.__metrics_names += recorder.metrics_names
+        if not PatchFastai.__main_task:
+            return
+        # noinspection PyBroadException
+        try:
+            PatchFastai.__metrics_names = ["train_loss"] if recorder.no_val else ["train_loss", "valid_loss"]
+            PatchFastai.__metrics_names += recorder.metrics_names
+        except Exception:
+            pass
 
     @staticmethod
     def _on_backward_end(original_fn, recorder, *args, **kwargs):
-        def report_model_stats(series, value):
-            logger.report_scalar("model_stats_gradients", series, value, iteration)
+        def count_zeros(gradient):
+            n = gradient.data.cpu().numpy()
+            return n.size - np.count_nonzero(n)
 
         original_fn(recorder, *args, **kwargs)
-        gradients = [
-            x.grad.clone().detach().cpu()
-            for x in recorder.learn.model.parameters()
-            if x.grad is not None
-        ]
-        if len(gradients) == 0:
-            return
-        iteration = kwargs.get("iteration")
-        norms = [x.data.norm() for x in gradients]
-        logger = PatchFastai.__main_task.get_logger()
-        for name, val in zip(
-            [
-                "avg_norm",
-                "median_norm",
-                "max_norm",
-                "min_norm",
-                "num_zeros",
-                "avg_gradient",
-                "median_gradient",
-                "max_gradient",
-                "min_gradient",
-            ],
-            [
-                sum(norms) / len(gradients),
-                statistics.median(norms),
-                max(norms),
-                min(norms),
-                sum(
-                    (np.asarray(x) == 0.0).sum()
-                    for x in [x.data.data.cpu().numpy() for x in gradients]
-                ),
-                sum(x.data.mean() for x in gradients) / len(gradients),
-                statistics.median(x.data.median() for x in gradients),
-                max(x.data.max() for x in gradients),
-                min(x.data.min() for x in gradients),
-            ],
-        ):
-            report_model_stats(name, val)
+
+        if not PatchFastai.__main_task:
+            return
+
+        # noinspection PyBroadException
+        try:
+            gradients = [
+                x.grad.clone().detach().cpu() for x in recorder.learn.model.parameters() if x.grad is not None
+            ]
+            if len(gradients) == 0:
+                return
+
+            # one row per parameter: (norm, zero count, mean, median, max, min)
+            gradient_stats = np.array([
+                (x.data.norm(), count_zeros(x), x.data.mean(), x.data.median(), x.data.max(), x.data.min())
+                for x in gradients])
+            stats_report = dict(
+                avg_norm=np.mean(gradient_stats[:, 0]),
+                median_norm=np.median(gradient_stats[:, 0]),
+                max_norm=np.max(gradient_stats[:, 0]),
+                min_norm=np.min(gradient_stats[:, 0]),
+                num_zeros=gradient_stats[:, 1].sum(),
+                avg_gradient=gradient_stats[:, 2].mean(),
+                median_gradient=np.median(gradient_stats[:, 3]),
+                max_gradient=gradient_stats[:, 4].max(),
+                min_gradient=gradient_stats[:, 5].min(),
+            )
+
+            logger = PatchFastai.__main_task.get_logger()
+            iteration = kwargs.get("iteration", 0)
+            for name, val in stats_report.items():
+                logger.report_scalar(title="model_stats_gradients", series=name, value=val, iteration=iteration)
+        except Exception:
+            pass
 
     @staticmethod
     def _on_epoch_end(original_fn, recorder, *args, **kwargs):
         original_fn(recorder, *args, **kwargs)
-        logger = PatchFastai.__main_task.get_logger()
-        iteration = kwargs.get("iteration")
-        for series, value in zip(
-            PatchFastai.__metrics_names,
-            [kwargs.get("smooth_loss")] + kwargs.get("last_metrics", []),
-        ):
-            logger.report_scalar("metrics", series, value, iteration)
-        PatchFastai.__main_task.flush()
+        if not PatchFastai.__main_task:
+            return
+
+        # noinspection PyBroadException
+        try:
+            logger = PatchFastai.__main_task.get_logger()
+            iteration = kwargs.get("iteration")
+            for series, value in zip(
+                PatchFastai.__metrics_names,
+                [kwargs.get("smooth_loss")] + kwargs.get("last_metrics", []),
+            ):
+                logger.report_scalar(title="metrics", series=series, value=value, iteration=iteration)
+            PatchFastai.__main_task.flush()
+        except Exception:
+            pass
 
     @staticmethod
     def _on_batch_end(original_fn, recorder, *args, **kwargs):
         original_fn(recorder, *args, **kwargs)
-        if kwargs.get("iteration") == 0 or not kwargs.get("train"):
-            return
-        logger = PatchFastai.__main_task.get_logger()
-        logger.report_scalar(
-            "metrics", "train_loss", kwargs.get("last_loss"), kwargs.get("iteration")
-        )
-        gradient_hist_helper = WeightsGradientHistHelper(logger)
-        iteration = kwargs.get("iteration")
-        params = [
-            (name, values.clone().detach().cpu())
-            for (name, values) in recorder.model.named_parameters()
-        ]
-        for (name, values) in params:
-            gradient_hist_helper.add_histogram(
-                title="model_weights",
-                series="model_weights/" + name,
-                step=iteration,
-                hist_data=values,
-            )
+        if not PatchFastai.__main_task:
+            return
+
+        # noinspection PyBroadException
+        try:
+            if kwargs.get("iteration") == 0 or not kwargs.get("train"):
+                return
+
+            logger = PatchFastai.__main_task.get_logger()
+            logger.report_scalar(
+                title="metrics",
+                series="train_loss",
+                value=kwargs.get("last_loss", 0),
+                iteration=kwargs.get("iteration", 0),
+            )
+            gradient_hist_helper = WeightsGradientHistHelper(logger)
+            iteration = kwargs.get("iteration")
+            params = [
+                (name, values.clone().detach().cpu())
+                for (name, values) in recorder.model.named_parameters()
+            ]
+            for (name, values) in params:
+                gradient_hist_helper.add_histogram(
+                    title="model_weights",
+                    series="model_weights/" + name,
+                    step=iteration,
+                    hist_data=values,
+                )
+        except Exception:
+            pass
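
Note: _patched_call and PostImportHookPatching are internal helpers not shown in this diff. A minimal sketch of what a _patched_call-style wrapper presumably does, inferred from how the callbacks above are defined with the signature (original_fn, recorder, *args, **kwargs); the real helper lives elsewhere in the codebase and may differ:

def _patched_call(original_fn, patched_fn):
    # Hypothetical illustration, not the library's actual helper.
    # Returns a replacement method: when fastai invokes the patched
    # Recorder callback, the replacement hands the original method to
    # patched_fn, which delegates to it and then does its own reporting.
    def _inner_patch(self, *args, **kwargs):
        return patched_fn(original_fn, self, *args, **kwargs)
    return _inner_patch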
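The net effect of the binding, as a hedged usage sketch (assumes trains and fastai v1 are installed; project and task names are placeholders):

from trains import Task
from fastai.vision import *  # fastai v1 star-import style

task = Task.init(project_name="examples", task_name="fastai auto-logging")

path = untar_data(URLs.MNIST_SAMPLE)
data = ImageDataBunch.from_folder(path)
learn = cnn_learner(data, models.resnet18, metrics=accuracy)
# Recorder is patched on import, so train/valid losses, metrics, gradient
# statistics, and per-layer weight histograms are reported to the task
# automatically during training.
learn.fit_one_cycle(1)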