diff --git a/trains/binding/frameworks/fastai_bind.py b/trains/binding/frameworks/fastai_bind.py
new file mode 100644
index 00000000..389304a6
--- /dev/null
+++ b/trains/binding/frameworks/fastai_bind.py
@@ -0,0 +1,134 @@
+import statistics
+import sys
+
+import numpy as np
+
+from . import _patched_call
+from .tensorflow_bind import WeightsGradientHistHelper
+from ..import_bind import PostImportHookPatching
+from ...debugging.log import LoggerRoot
+
+
+class PatchFastai(object):
+    __metrics_names = None
+    __main_task = None
+
+    @staticmethod
+    def update_current_task(task, **kwargs):
+        PatchFastai.__main_task = task
+        PatchFastai._patch_model_callback()
+        PostImportHookPatching.add_on_import(
+            "fastai", PatchFastai._patch_model_callback
+        )
+
+    @staticmethod
+    def _patch_model_callback():
+        if "fastai" in sys.modules:
+            try:
+                from fastai.basic_train import Recorder
+
+                Recorder.on_batch_end = _patched_call(
+                    Recorder.on_batch_end, PatchFastai._on_batch_end
+                )
+                Recorder.on_backward_end = _patched_call(
+                    Recorder.on_backward_end, PatchFastai._on_backward_end
+                )
+                Recorder.on_epoch_end = _patched_call(
+                    Recorder.on_epoch_end, PatchFastai._on_epoch_end
+                )
+                Recorder.on_train_begin = _patched_call(
+                    Recorder.on_train_begin, PatchFastai._on_train_begin
+                )
+
+            except ImportError:
+                pass
+            except Exception as ex:
+                LoggerRoot.get_base_logger(PatchFastai).debug(str(ex))
+
+    @staticmethod
+    def _on_train_begin(original_fn, recorder, *args, **kwargs):
+        original_fn(recorder, *args, **kwargs)
+        PatchFastai.__metrics_names = (
+            ["train_loss"] if recorder.no_val else ["train_loss", "valid_loss"]
+        )
+        PatchFastai.__metrics_names += recorder.metrics_names
+
+    @staticmethod
+    def _on_backward_end(original_fn, recorder, *args, **kwargs):
+        def report_model_stats(series, value):
+            logger.report_scalar("model_stats_gradients", series, value, iteration)
+
+        original_fn(recorder, *args, **kwargs)
+        gradients = [
+            x.grad.clone().detach().cpu()
+            for x in recorder.learn.model.parameters()
+            if x.grad is not None
+        ]
+        if len(gradients) == 0:
+            return
+        iteration = kwargs.get("iteration")
+        norms = [x.data.norm() for x in gradients]
+        logger = PatchFastai.__main_task.get_logger()
+        for name, val in zip(
+            [
+                "avg_norm",
+                "median_norm",
+                "max_norm",
+                "min_norm",
+                "num_zeros",
+                "avg_gradient",
+                "median_gradient",
+                "max_gradient",
+                "min_gradient",
+            ],
+            [
+                sum(norms) / len(gradients),
+                statistics.median(norms),
+                max(norms),
+                min(norms),
+                sum(
+                    (np.asarray(x) == 0.0).sum()
+                    for x in [x.data.data.cpu().numpy() for x in gradients]
+                ),
+                sum(x.data.mean() for x in gradients) / len(gradients),
+                statistics.median(x.data.median() for x in gradients),
+                max(x.data.max() for x in gradients),
+                min(x.data.min() for x in gradients),
+            ],
+        ):
+            report_model_stats(name, val)
+
+    @staticmethod
+    def _on_epoch_end(original_fn, recorder, *args, **kwargs):
+        original_fn(recorder, *args, **kwargs)
+        logger = PatchFastai.__main_task.get_logger()
+        iteration = kwargs.get("iteration")
+        for series, value in zip(
+            PatchFastai.__metrics_names,
+            [kwargs.get("smooth_loss")] + kwargs.get("last_metrics", []),
+        ):
+            logger.report_scalar("metrics", series, value, iteration)
+        PatchFastai.__main_task.flush()
+
+    @staticmethod
+    def _on_batch_end(original_fn, recorder, *args, **kwargs):
+        original_fn(recorder, *args, **kwargs)
+        if kwargs.get("iteration") == 0 or not kwargs.get("train"):
+            return
+        logger = PatchFastai.__main_task.get_logger()
+        logger.report_scalar(
+            "metrics", "train_loss", kwargs.get("last_loss"), kwargs.get("iteration")
+        )
+        gradient_hist_helper = WeightsGradientHistHelper(logger)
+        iteration = kwargs.get("iteration")
+        params = [
+            (name, values.clone().detach().cpu())
+            for (name, values) in recorder.model.named_parameters()
+        ]
+        for (name, values) in params:
+            gradient_hist_helper.add_histogram(
+                title="model_weights",
+                series="model_weights/" + name,
+                step=iteration,
+                hist_data=values,
+            )
diff --git a/trains/task.py b/trains/task.py
index 0cf33db6..1705f302 100644
--- a/trains/task.py
+++ b/trains/task.py
@@ -7,6 +7,7 @@ import time
 from argparse import ArgumentParser
 from tempfile import mkstemp
+
 
 try:
     # noinspection PyCompatibility
     from collections.abc import Callable, Sequence as CollectionsSequence
@@ -30,6 +31,7 @@ from .backend_interface.util import get_single_result, exact_match_regex, make_m
 from .binding.absl_bind import PatchAbsl
 from .binding.artifacts import Artifacts, Artifact
 from .binding.environ_bind import EnvironmentBind, PatchOsFork
+from .binding.frameworks.fastai_bind import PatchFastai
 from .binding.frameworks.pytorch_bind import PatchPyTorchModelIO
 from .binding.frameworks.tensorflow_bind import TensorflowBinding
 from .binding.frameworks.xgboost_bind import PatchXGBoostModelIO
@@ -469,6 +471,8 @@ class Task(_Task):
                     PatchPyTorchModelIO.update_current_task(task)
                 if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('xgboost', True):
                     PatchXGBoostModelIO.update_current_task(task)
+                if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('fastai', True):
+                    PatchFastai.update_current_task(task)
                 if auto_resource_monitoring and not is_sub_process_task_id:
                     task._resource_monitor = ResourceMonitor(
                         task, report_mem_used_per_process=not config.get(
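
For reviewers, a minimal sketch of how this binding would be exercised, assuming fastai v1 (whose fastai.basic_train.Recorder is what PatchFastai monkey-patches). Task.init and the auto_connect_frameworks flag come from trains/task.py above; the dataset and model choices are illustrative only:

# Plain fastai v1 training script; no explicit trains logging calls needed.
from trains import Task
from fastai.vision import URLs, untar_data, ImageDataBunch, cnn_learner, models, accuracy

# 'fastai' is the new auto_connect_frameworks key wired up in task.py; it
# defaults to True, so a bare Task.init() also patches the Recorder callbacks.
task = Task.init(project_name="examples", task_name="fastai example",
                 auto_connect_frameworks={'fastai': True})

data = ImageDataBunch.from_folder(untar_data(URLs.MNIST_SAMPLE))
learn = cnn_learner(data, models.resnet18, metrics=accuracy)

# The patched on_batch_end/on_epoch_end report train/valid loss, metrics and
# per-layer weight histograms; on_backward_end reports gradient statistics.
learn.fit_one_cycle(1)

Because the patch goes through PostImportHookPatching, the order of the imports above should not matter: the Recorder callbacks are patched whether fastai is imported before or after Task.init.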