mirror of
https://github.com/clearml/clearml
synced 2025-06-04 03:47:57 +00:00
Refactor fastai bind
This commit is contained in:
parent
00ccadf591
commit
093477cb35
@ -1,4 +1,3 @@
|
|||||||
import statistics
|
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -10,16 +9,14 @@ from ...debugging.log import LoggerRoot
|
|||||||
|
|
||||||
|
|
||||||
class PatchFastai(object):
|
class PatchFastai(object):
|
||||||
__metrics_names = None
|
__metrics_names = None # TODO: STORE ON OBJECT OR IN LOOKUP BASED ON OBJECT ID
|
||||||
__main_task = None
|
__main_task = None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def update_current_task(task, **kwargs):
|
def update_current_task(task, **_):
|
||||||
PatchFastai.__main_task = task
|
PatchFastai.__main_task = task
|
||||||
PatchFastai._patch_model_callback()
|
PatchFastai._patch_model_callback()
|
||||||
PostImportHookPatching.add_on_import(
|
PostImportHookPatching.add_on_import("fastai", PatchFastai._patch_model_callback)
|
||||||
"fastai", PatchFastai._patch_model_callback
|
|
||||||
)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _patch_model_callback():
|
def _patch_model_callback():
|
||||||
@ -27,19 +24,10 @@ class PatchFastai(object):
|
|||||||
try:
|
try:
|
||||||
from fastai.basic_train import Recorder
|
from fastai.basic_train import Recorder
|
||||||
|
|
||||||
Recorder.on_batch_end = _patched_call(
|
Recorder.on_batch_end = _patched_call(Recorder.on_batch_end, PatchFastai._on_batch_end)
|
||||||
Recorder.on_batch_end, PatchFastai._on_batch_end
|
Recorder.on_backward_end = _patched_call(Recorder.on_backward_end, PatchFastai._on_backward_end)
|
||||||
)
|
Recorder.on_epoch_end = _patched_call(Recorder.on_epoch_end, PatchFastai._on_epoch_end)
|
||||||
Recorder.on_backward_end = _patched_call(
|
Recorder.on_train_begin = _patched_call(Recorder.on_train_begin, PatchFastai._on_train_begin)
|
||||||
Recorder.on_backward_end, PatchFastai._on_backward_end
|
|
||||||
)
|
|
||||||
Recorder.on_epoch_end = _patched_call(
|
|
||||||
Recorder.on_epoch_end, PatchFastai._on_epoch_end
|
|
||||||
)
|
|
||||||
Recorder.on_train_begin = _patched_call(
|
|
||||||
Recorder.on_train_begin, PatchFastai._on_train_begin
|
|
||||||
)
|
|
||||||
|
|
||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
@ -48,76 +36,93 @@ class PatchFastai(object):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def _on_train_begin(original_fn, recorder, *args, **kwargs):
|
def _on_train_begin(original_fn, recorder, *args, **kwargs):
|
||||||
original_fn(recorder, *args, **kwargs)
|
original_fn(recorder, *args, **kwargs)
|
||||||
PatchFastai.__metrics_names = (
|
if not PatchFastai.__main_task:
|
||||||
["train_loss"] if recorder.no_val else ["train_loss", "valid_loss"]
|
return
|
||||||
)
|
# noinspection PyBroadException
|
||||||
|
try:
|
||||||
|
PatchFastai.__metrics_names = ["train_loss"] if recorder.no_val else ["train_loss", "valid_loss"]
|
||||||
PatchFastai.__metrics_names += recorder.metrics_names
|
PatchFastai.__metrics_names += recorder.metrics_names
|
||||||
|
except Exception as ex:
|
||||||
|
pass
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _on_backward_end(original_fn, recorder, *args, **kwargs):
|
def _on_backward_end(original_fn, recorder, *args, **kwargs):
|
||||||
def report_model_stats(series, value):
|
def count_zeros(gradient):
|
||||||
logger.report_scalar("model_stats_gradients", series, value, iteration)
|
n = gradient.data.data.cpu().numpy()
|
||||||
|
return n.size - n.count_nonzero()
|
||||||
|
|
||||||
original_fn(recorder, *args, **kwargs)
|
original_fn(recorder, *args, **kwargs)
|
||||||
|
|
||||||
|
if not PatchFastai.__main_task:
|
||||||
|
return
|
||||||
|
|
||||||
|
# noinspection PyBroadException
|
||||||
|
try:
|
||||||
gradients = [
|
gradients = [
|
||||||
x.grad.clone().detach().cpu()
|
x.grad.clone().detach().cpu() for x in recorder.learn.model.parameters() if x.grad is not None
|
||||||
for x in recorder.learn.model.parameters()
|
|
||||||
if x.grad is not None
|
|
||||||
]
|
]
|
||||||
if len(gradients) == 0:
|
if len(gradients) == 0:
|
||||||
return
|
return
|
||||||
iteration = kwargs.get("iteration")
|
|
||||||
norms = [x.data.norm() for x in gradients]
|
# TODO: Check computation!
|
||||||
|
gradient_stats = np.array([
|
||||||
|
(x.data.norm(), count_zeros(x), x.data.mean(), x.data.median(), x.data.max(), x.data.min())
|
||||||
|
for x in gradients])
|
||||||
|
stats_report = dict(
|
||||||
|
avg_norm=np.mean(gradient_stats[:, 0]),
|
||||||
|
median_norm=np.median(gradient_stats[:, 0]),
|
||||||
|
max_norm=np.max(gradient_stats[:, 0]),
|
||||||
|
min_norm=np.min(gradient_stats[:, 0]),
|
||||||
|
num_zeros=gradient_stats[:, 1].sum(),
|
||||||
|
avg_gradient=gradient_stats[:, 2].mean(),
|
||||||
|
median_gradient=gradient_stats[:, 3].median(),
|
||||||
|
max_gradient=gradient_stats[:, 4].max(),
|
||||||
|
min_gradient=gradient_stats[:, 5].min(),
|
||||||
|
)
|
||||||
|
|
||||||
logger = PatchFastai.__main_task.get_logger()
|
logger = PatchFastai.__main_task.get_logger()
|
||||||
for name, val in zip(
|
iteration = kwargs.get("iteration", 0)
|
||||||
[
|
for name, val in stats_report.items():
|
||||||
"avg_norm",
|
logger.report_scalar(title="model_stats_gradients", series=name, value=val, iteration=iteration)
|
||||||
"median_norm",
|
except Exception as ex:
|
||||||
"max_norm",
|
pass
|
||||||
"min_norm",
|
|
||||||
"num_zeros",
|
|
||||||
"avg_gradient",
|
|
||||||
"median_gradient",
|
|
||||||
"max_gradient",
|
|
||||||
"min_gradient",
|
|
||||||
],
|
|
||||||
[
|
|
||||||
sum(norms) / len(gradients),
|
|
||||||
statistics.median(norms),
|
|
||||||
max(norms),
|
|
||||||
min(norms),
|
|
||||||
sum(
|
|
||||||
(np.asarray(x) == 0.0).sum()
|
|
||||||
for x in [x.data.data.cpu().numpy() for x in gradients]
|
|
||||||
),
|
|
||||||
sum(x.data.mean() for x in gradients) / len(gradients),
|
|
||||||
statistics.median(x.data.median() for x in gradients),
|
|
||||||
max(x.data.max() for x in gradients),
|
|
||||||
min(x.data.min() for x in gradients),
|
|
||||||
],
|
|
||||||
):
|
|
||||||
report_model_stats(name, val)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _on_epoch_end(original_fn, recorder, *args, **kwargs):
|
def _on_epoch_end(original_fn, recorder, *args, **kwargs):
|
||||||
original_fn(recorder, *args, **kwargs)
|
original_fn(recorder, *args, **kwargs)
|
||||||
|
if not PatchFastai.__main_task:
|
||||||
|
return
|
||||||
|
|
||||||
|
# noinspection PyBroadException
|
||||||
|
try:
|
||||||
logger = PatchFastai.__main_task.get_logger()
|
logger = PatchFastai.__main_task.get_logger()
|
||||||
iteration = kwargs.get("iteration")
|
iteration = kwargs.get("iteration")
|
||||||
for series, value in zip(
|
for series, value in zip(
|
||||||
PatchFastai.__metrics_names,
|
PatchFastai.__metrics_names,
|
||||||
[kwargs.get("smooth_loss")] + kwargs.get("last_metrics", []),
|
[kwargs.get("smooth_loss")] + kwargs.get("last_metrics", []),
|
||||||
):
|
):
|
||||||
logger.report_scalar("metrics", series, value, iteration)
|
logger.report_scalar(title="metrics", series=series, value=value, iteration=iteration)
|
||||||
PatchFastai.__main_task.flush()
|
PatchFastai.__main_task.flush()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _on_batch_end(original_fn, recorder, *args, **kwargs):
|
def _on_batch_end(original_fn, recorder, *args, **kwargs):
|
||||||
original_fn(recorder, *args, **kwargs)
|
original_fn(recorder, *args, **kwargs)
|
||||||
|
if not PatchFastai.__main_task:
|
||||||
|
return
|
||||||
|
|
||||||
|
# noinspection PyBroadException
|
||||||
|
try:
|
||||||
if kwargs.get("iteration") == 0 or not kwargs.get("train"):
|
if kwargs.get("iteration") == 0 or not kwargs.get("train"):
|
||||||
return
|
return
|
||||||
|
|
||||||
logger = PatchFastai.__main_task.get_logger()
|
logger = PatchFastai.__main_task.get_logger()
|
||||||
logger.report_scalar(
|
logger.report_scalar(
|
||||||
"metrics", "train_loss", kwargs.get("last_loss"), kwargs.get("iteration")
|
title="metrics",
|
||||||
|
series="train_loss",
|
||||||
|
value=kwargs.get("last_loss", 0),
|
||||||
|
iteration=kwargs.get("iteration", 0)
|
||||||
)
|
)
|
||||||
gradient_hist_helper = WeightsGradientHistHelper(logger)
|
gradient_hist_helper = WeightsGradientHistHelper(logger)
|
||||||
iteration = kwargs.get("iteration")
|
iteration = kwargs.get("iteration")
|
||||||
@ -132,3 +137,5 @@ class PatchFastai(object):
|
|||||||
step=iteration,
|
step=iteration,
|
||||||
hist_data=values,
|
hist_data=values,
|
||||||
)
|
)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
Loading…
Reference in New Issue
Block a user