From 093477cb35de55d7504722df03626a48298d9c23 Mon Sep 17 00:00:00 2001
From: allegroai <>
Date: Thu, 30 Jul 2020 14:54:37 +0300
Subject: [PATCH] Refactor fastai bind

---
 trains/binding/frameworks/fastai_bind.py | 181 ++++++++++++-----------
 1 file changed, 94 insertions(+), 87 deletions(-)

diff --git a/trains/binding/frameworks/fastai_bind.py b/trains/binding/frameworks/fastai_bind.py
index 389304a6..1f51ee84 100644
--- a/trains/binding/frameworks/fastai_bind.py
+++ b/trains/binding/frameworks/fastai_bind.py
@@ -1,4 +1,3 @@
-import statistics
 import sys
 
 import numpy as np
@@ -10,16 +9,14 @@ from ...debugging.log import LoggerRoot
 
 
 class PatchFastai(object):
-    __metrics_names = None
+    __metrics_names = None  # TODO: STORE ON OBJECT OR IN LOOKUP BASED ON OBJECT ID
     __main_task = None
 
     @staticmethod
-    def update_current_task(task, **kwargs):
+    def update_current_task(task, **_):
         PatchFastai.__main_task = task
         PatchFastai._patch_model_callback()
-        PostImportHookPatching.add_on_import(
-            "fastai", PatchFastai._patch_model_callback
-        )
+        PostImportHookPatching.add_on_import("fastai", PatchFastai._patch_model_callback)
 
     @staticmethod
     def _patch_model_callback():
@@ -27,19 +24,10 @@ class PatchFastai(object):
             try:
                 from fastai.basic_train import Recorder
 
-                Recorder.on_batch_end = _patched_call(
-                    Recorder.on_batch_end, PatchFastai._on_batch_end
-                )
-                Recorder.on_backward_end = _patched_call(
-                    Recorder.on_backward_end, PatchFastai._on_backward_end
-                )
-                Recorder.on_epoch_end = _patched_call(
-                    Recorder.on_epoch_end, PatchFastai._on_epoch_end
-                )
-                Recorder.on_train_begin = _patched_call(
-                    Recorder.on_train_begin, PatchFastai._on_train_begin
-                )
-
+                Recorder.on_batch_end = _patched_call(Recorder.on_batch_end, PatchFastai._on_batch_end)
+                Recorder.on_backward_end = _patched_call(Recorder.on_backward_end, PatchFastai._on_backward_end)
+                Recorder.on_epoch_end = _patched_call(Recorder.on_epoch_end, PatchFastai._on_epoch_end)
+                Recorder.on_train_begin = _patched_call(Recorder.on_train_begin, PatchFastai._on_train_begin)
             except ImportError:
                 pass
             except Exception as ex:
@@ -48,87 +36,106 @@ class PatchFastai(object):
     @staticmethod
     def _on_train_begin(original_fn, recorder, *args, **kwargs):
         original_fn(recorder, *args, **kwargs)
-        PatchFastai.__metrics_names = (
-            ["train_loss"] if recorder.no_val else ["train_loss", "valid_loss"]
-        )
-        PatchFastai.__metrics_names += recorder.metrics_names
+        if not PatchFastai.__main_task:
+            return
+        # noinspection PyBroadException
+        try:
+            PatchFastai.__metrics_names = ["train_loss"] if recorder.no_val else ["train_loss", "valid_loss"]
+            PatchFastai.__metrics_names += recorder.metrics_names
+        except Exception:
+            pass
 
     @staticmethod
     def _on_backward_end(original_fn, recorder, *args, **kwargs):
-        def report_model_stats(series, value):
-            logger.report_scalar("model_stats_gradients", series, value, iteration)
+        def count_zeros(gradient):
+            n = gradient.data.data.cpu().numpy()
+            return n.size - np.count_nonzero(n)
 
         original_fn(recorder, *args, **kwargs)
-        gradients = [
-            x.grad.clone().detach().cpu()
-            for x in recorder.learn.model.parameters()
-            if x.grad is not None
-        ]
-        if len(gradients) == 0:
+
+        if not PatchFastai.__main_task:
             return
-        iteration = kwargs.get("iteration")
-        norms = [x.data.norm() for x in gradients]
-        logger = PatchFastai.__main_task.get_logger()
-        for name, val in zip(
-            [
-                "avg_norm",
-                "median_norm",
-                "max_norm",
-                "min_norm",
-                "num_zeros",
-                "avg_gradient",
-                "median_gradient",
-                "max_gradient",
-                "min_gradient",
-            ],
-            [
-                sum(norms) / len(gradients),
-                statistics.median(norms),
-                max(norms),
-                min(norms),
-                sum(
-                    (np.asarray(x) == 0.0).sum()
-                    for x in [x.data.data.cpu().numpy() for x in gradients]
-                ),
-                sum(x.data.mean() for x in gradients) / len(gradients),
-                statistics.median(x.data.median() for x in gradients),
-                max(x.data.max() for x in gradients),
-                min(x.data.min() for x in gradients),
-            ],
-        ):
-            report_model_stats(name, val)
+
+        # noinspection PyBroadException
+        try:
+            gradients = [
+                x.grad.clone().detach().cpu() for x in recorder.learn.model.parameters() if x.grad is not None
+            ]
+            if len(gradients) == 0:
+                return
+
+            # TODO: Check computation!
+            gradient_stats = np.array([
+                (x.data.norm(), count_zeros(x), x.data.mean(), x.data.median(), x.data.max(), x.data.min())
+                for x in gradients])
+            stats_report = dict(
+                avg_norm=np.mean(gradient_stats[:, 0]),
+                median_norm=np.median(gradient_stats[:, 0]),
+                max_norm=np.max(gradient_stats[:, 0]),
+                min_norm=np.min(gradient_stats[:, 0]),
+                num_zeros=gradient_stats[:, 1].sum(),
+                avg_gradient=gradient_stats[:, 2].mean(),
+                median_gradient=np.median(gradient_stats[:, 3]),
+                max_gradient=gradient_stats[:, 4].max(),
+                min_gradient=gradient_stats[:, 5].min(),
+            )
+
+            logger = PatchFastai.__main_task.get_logger()
+            iteration = kwargs.get("iteration", 0)
+            for name, val in stats_report.items():
+                logger.report_scalar(title="model_stats_gradients", series=name, value=val, iteration=iteration)
+        except Exception:
+            pass
 
     @staticmethod
     def _on_epoch_end(original_fn, recorder, *args, **kwargs):
         original_fn(recorder, *args, **kwargs)
-        logger = PatchFastai.__main_task.get_logger()
-        iteration = kwargs.get("iteration")
-        for series, value in zip(
-            PatchFastai.__metrics_names,
-            [kwargs.get("smooth_loss")] + kwargs.get("last_metrics", []),
-        ):
-            logger.report_scalar("metrics", series, value, iteration)
-        PatchFastai.__main_task.flush()
+        if not PatchFastai.__main_task:
+            return
+
+        # noinspection PyBroadException
+        try:
+            logger = PatchFastai.__main_task.get_logger()
+            iteration = kwargs.get("iteration")
+            for series, value in zip(
+                PatchFastai.__metrics_names,
+                [kwargs.get("smooth_loss")] + kwargs.get("last_metrics", []),
+            ):
+                logger.report_scalar(title="metrics", series=series, value=value, iteration=iteration)
+            PatchFastai.__main_task.flush()
+        except Exception:
+            pass
 
     @staticmethod
     def _on_batch_end(original_fn, recorder, *args, **kwargs):
         original_fn(recorder, *args, **kwargs)
-        if kwargs.get("iteration") == 0 or not kwargs.get("train"):
+        if not PatchFastai.__main_task:
             return
-        logger = PatchFastai.__main_task.get_logger()
-        logger.report_scalar(
-            "metrics", "train_loss", kwargs.get("last_loss"), kwargs.get("iteration")
-        )
-        gradient_hist_helper = WeightsGradientHistHelper(logger)
-        iteration = kwargs.get("iteration")
-        params = [
-            (name, values.clone().detach().cpu())
-            for (name, values) in recorder.model.named_parameters()
-        ]
-        for (name, values) in params:
-            gradient_hist_helper.add_histogram(
-                title="model_weights",
-                series="model_weights/" + name,
-                step=iteration,
-                hist_data=values,
+
+        # noinspection PyBroadException
+        try:
+            if kwargs.get("iteration") == 0 or not kwargs.get("train"):
+                return
+
+            logger = PatchFastai.__main_task.get_logger()
+            logger.report_scalar(
+                title="metrics",
+                series="train_loss",
+                value=kwargs.get("last_loss", 0),
+                iteration=kwargs.get("iteration", 0)
             )
+            gradient_hist_helper = WeightsGradientHistHelper(logger)
+            iteration = kwargs.get("iteration")
+            params = [
+                (name, values.clone().detach().cpu())
+                for (name, values) in recorder.model.named_parameters()
+            ]
+            for (name, values) in params:
+                gradient_hist_helper.add_histogram(
+                    title="model_weights",
+                    series="model_weights/" + name,
+                    step=iteration,
+                    hist_data=values,
+                )
+        except Exception:
+            pass
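
Note on the mechanics: every hook in this patch is registered through _patched_call,
imported at the top of fastai_bind.py. A simplified sketch of its contract, assuming
the shipped helper in trains/binding/frameworks (which also adds recursion guarding;
the name _inner here is illustrative, not the shipped code):

    def _patched_call(original_fn, patched_fn):
        def _inner(*args, **kwargs):
            # The hook receives the original implementation as its first
            # argument, which is why each _on_* method above accepts
            # original_fn and calls it before doing its own reporting.
            return patched_fn(original_fn, *args, **kwargs)
        return _inner

For reference, a minimal fastai v1 script that would exercise these callbacks. This is
an illustrative sketch, not shipped code: it assumes the trains SDK is installed and
that Task.init() invokes PatchFastai.update_current_task(); the project/task names,
dataset, and model below are placeholders:

    from fastai.vision import URLs, untar_data, ImageDataBunch, cnn_learner, models, accuracy
    from trains import Task

    # Initializing the task installs the patched Recorder callbacks
    task = Task.init(project_name="examples", task_name="fastai example")
    data = ImageDataBunch.from_folder(untar_data(URLs.MNIST_SAMPLE))
    learn = cnn_learner(data, models.resnet18, metrics=accuracy)
    learn.fit_one_cycle(1)  # train_loss, gradient stats and weight histograms get reported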