Mirror of https://github.com/clearml/clearml (synced 2025-06-26 18:16:07 +00:00)
Add fastai binding support

parent 88d88e914d
commit d642639890
trains/binding/frameworks/fastai_bind.py (new file, 134 lines)
@@ -0,0 +1,134 @@
import statistics
import sys

import numpy as np

from . import _patched_call
from .tensorflow_bind import WeightsGradientHistHelper
from ..import_bind import PostImportHookPatching
from ...debugging.log import LoggerRoot


class PatchFastai(object):
    __metrics_names = None
    __main_task = None

    @staticmethod
    def update_current_task(task, **kwargs):
        PatchFastai.__main_task = task
        PatchFastai._patch_model_callback()
        PostImportHookPatching.add_on_import(
            "fastai", PatchFastai._patch_model_callback
        )

    @staticmethod
    def _patch_model_callback():
        if "fastai" in sys.modules:
            try:
                from fastai.basic_train import Recorder

                Recorder.on_batch_end = _patched_call(
                    Recorder.on_batch_end, PatchFastai._on_batch_end
                )
                Recorder.on_backward_end = _patched_call(
                    Recorder.on_backward_end, PatchFastai._on_backward_end
                )
                Recorder.on_epoch_end = _patched_call(
                    Recorder.on_epoch_end, PatchFastai._on_epoch_end
                )
                Recorder.on_train_begin = _patched_call(
                    Recorder.on_train_begin, PatchFastai._on_train_begin
                )

            except ImportError:
                pass
            except Exception as ex:
                LoggerRoot.get_base_logger(PatchFastai).debug(str(ex))

    @staticmethod
    def _on_train_begin(original_fn, recorder, *args, **kwargs):
        original_fn(recorder, *args, **kwargs)
        PatchFastai.__metrics_names = (
            ["train_loss"] if recorder.no_val else ["train_loss", "valid_loss"]
        )
        PatchFastai.__metrics_names += recorder.metrics_names

    @staticmethod
    def _on_backward_end(original_fn, recorder, *args, **kwargs):
        def report_model_stats(series, value):
            logger.report_scalar("model_stats_gradients", series, value, iteration)

        original_fn(recorder, *args, **kwargs)
        gradients = [
            x.grad.clone().detach().cpu()
            for x in recorder.learn.model.parameters()
            if x.grad is not None
        ]
        if len(gradients) == 0:
            return
        iteration = kwargs.get("iteration")
        norms = [x.data.norm() for x in gradients]
        logger = PatchFastai.__main_task.get_logger()
        for name, val in zip(
            [
                "avg_norm",
                "median_norm",
                "max_norm",
                "min_norm",
                "num_zeros",
                "avg_gradient",
                "median_gradient",
                "max_gradient",
                "min_gradient",
            ],
            [
                sum(norms) / len(gradients),
                statistics.median(norms),
                max(norms),
                min(norms),
                sum(
                    (np.asarray(x) == 0.0).sum()
                    for x in [x.data.data.cpu().numpy() for x in gradients]
                ),
                sum(x.data.mean() for x in gradients) / len(gradients),
                statistics.median(x.data.median() for x in gradients),
                max(x.data.max() for x in gradients),
                min(x.data.min() for x in gradients),
            ],
        ):
            report_model_stats(name, val)

    @staticmethod
    def _on_epoch_end(original_fn, recorder, *args, **kwargs):
        original_fn(recorder, *args, **kwargs)
        logger = PatchFastai.__main_task.get_logger()
        iteration = kwargs.get("iteration")
        for series, value in zip(
            PatchFastai.__metrics_names,
            [kwargs.get("smooth_loss")] + kwargs.get("last_metrics", []),
        ):
            logger.report_scalar("metrics", series, value, iteration)
        PatchFastai.__main_task.flush()

    @staticmethod
    def _on_batch_end(original_fn, recorder, *args, **kwargs):
        original_fn(recorder, *args, **kwargs)
        if kwargs.get("iteration") == 0 or not kwargs.get("train"):
            return
        logger = PatchFastai.__main_task.get_logger()
        logger.report_scalar(
            "metrics", "train_loss", kwargs.get("last_loss"), kwargs.get("iteration")
        )
        gradient_hist_helper = WeightsGradientHistHelper(logger)
        iteration = kwargs.get("iteration")
        params = [
            (name, values.clone().detach().cpu())
            for (name, values) in recorder.model.named_parameters()
        ]
        for (name, values) in params:
            gradient_hist_helper.add_histogram(
                title="model_weights",
                series="model_weights/" + name,
                step=iteration,
                hist_data=values,
            )
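A note on the patching mechanism: _patched_call is an internal trains helper imported above from trains/binding/frameworks and is not part of this diff. Judging by how every _on_* hook receives original_fn as its first argument, it presumably behaves like the minimal sketch below (an assumption about the helper, not its actual implementation):

# Hedged sketch of the assumed _patched_call behavior: wrap a callable so the
# replacement receives the original as its first argument and decides when
# (and whether) to invoke it. Each PatchFastai._on_* hook above calls the
# original first, then reports to the task logger.
def _patched_call(original_fn, patched_fn):
    def _inner_patch(*args, **kwargs):
        return patched_fn(original_fn, *args, **kwargs)
    return _inner_patch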
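Similarly, PostImportHookPatching.add_on_import (from trains/binding/import_bind.py, also outside this diff) registers a callback that fires when the named module is imported, so the Recorder patch applies whether fastai is imported before or after the task is created. A self-contained sketch of that pattern, assuming a simple __import__ wrapper (the real helper may be implemented differently):

import builtins

_post_import_hooks = {}  # top-level module name -> callbacks to run on import
_original_import = builtins.__import__

def add_on_import(module_name, func):
    # Register func to run the next time module_name is imported.
    _post_import_hooks.setdefault(module_name, []).append(func)

def _import_with_hooks(name, *args, **kwargs):
    module = _original_import(name, *args, **kwargs)
    # Fire (and clear) any hooks registered for the top-level package.
    for func in _post_import_hooks.pop(name.split(".")[0], []):
        func()
    return module

builtins.__import__ = _import_with_hooks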
trains/task.py
@@ -7,6 +7,7 @@ import time
 from argparse import ArgumentParser
 from tempfile import mkstemp
 
+
 try:
     # noinspection PyCompatibility
     from collections.abc import Callable, Sequence as CollectionsSequence
@@ -30,6 +31,7 @@ from .backend_interface.util import get_single_result, exact_match_regex, make_m
 from .binding.absl_bind import PatchAbsl
 from .binding.artifacts import Artifacts, Artifact
 from .binding.environ_bind import EnvironmentBind, PatchOsFork
+from .binding.frameworks.fastai_bind import PatchFastai
 from .binding.frameworks.pytorch_bind import PatchPyTorchModelIO
 from .binding.frameworks.tensorflow_bind import TensorflowBinding
 from .binding.frameworks.xgboost_bind import PatchXGBoostModelIO
@@ -469,6 +471,8 @@ class Task(_Task):
                     PatchPyTorchModelIO.update_current_task(task)
                 if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('xgboost', True):
                     PatchXGBoostModelIO.update_current_task(task)
+                if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('fastai', True):
+                    PatchFastai.update_current_task(task)
             if auto_resource_monitoring and not is_sub_process_task_id:
                 task._resource_monitor = ResourceMonitor(
                     task, report_mem_used_per_process=not config.get(
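With the wiring above, fastai auto-logging follows the same opt-out convention as the existing pytorch and xgboost bindings: enabled by default, and switchable per framework through Task.init's auto_connect_frameworks argument. A usage sketch (project and task names are illustrative):

from trains import Task

# Default behavior: auto_connect_frameworks=True, so PatchFastai is applied
# and Recorder metrics, gradient statistics and weight histograms are
# reported to this task automatically.
task = Task.init(project_name="examples", task_name="fastai training")

# To disable only the fastai binding (other frameworks keep their default
# of True via auto_connect_frameworks.get('fastai', True) shown above):
# task = Task.init(
#     project_name="examples",
#     task_name="fastai training",
#     auto_connect_frameworks={"fastai": False},
# )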