mirror of https://github.com/clearml/clearml (synced 2025-06-26 18:16:07 +00:00)

Refactor fastai bind

commit 093477cb35, parent 00ccadf591
@@ -1,4 +1,3 @@
-import statistics
 import sys
 
 import numpy as np
@@ -10,16 +9,14 @@ from ...debugging.log import LoggerRoot
 
 
 class PatchFastai(object):
-    __metrics_names = None
+    __metrics_names = None  # TODO: STORE ON OBJECT OR IN LOOKUP BASED ON OBJECT ID
     __main_task = None
 
     @staticmethod
-    def update_current_task(task, **kwargs):
+    def update_current_task(task, **_):
         PatchFastai.__main_task = task
         PatchFastai._patch_model_callback()
-        PostImportHookPatching.add_on_import(
-            "fastai", PatchFastai._patch_model_callback
-        )
+        PostImportHookPatching.add_on_import("fastai", PatchFastai._patch_model_callback)
 
     @staticmethod
     def _patch_model_callback():
@@ -27,19 +24,10 @@ class PatchFastai(object):
             try:
                 from fastai.basic_train import Recorder
 
-                Recorder.on_batch_end = _patched_call(
-                    Recorder.on_batch_end, PatchFastai._on_batch_end
-                )
-                Recorder.on_backward_end = _patched_call(
-                    Recorder.on_backward_end, PatchFastai._on_backward_end
-                )
-                Recorder.on_epoch_end = _patched_call(
-                    Recorder.on_epoch_end, PatchFastai._on_epoch_end
-                )
-                Recorder.on_train_begin = _patched_call(
-                    Recorder.on_train_begin, PatchFastai._on_train_begin
-                )
-
+                Recorder.on_batch_end = _patched_call(Recorder.on_batch_end, PatchFastai._on_batch_end)
+                Recorder.on_backward_end = _patched_call(Recorder.on_backward_end, PatchFastai._on_backward_end)
+                Recorder.on_epoch_end = _patched_call(Recorder.on_epoch_end, PatchFastai._on_epoch_end)
+                Recorder.on_train_begin = _patched_call(Recorder.on_train_begin, PatchFastai._on_train_begin)
             except ImportError:
                 pass
             except Exception as ex:
@@ -48,87 +36,106 @@ class PatchFastai(object):
     @staticmethod
     def _on_train_begin(original_fn, recorder, *args, **kwargs):
         original_fn(recorder, *args, **kwargs)
-        PatchFastai.__metrics_names = (
-            ["train_loss"] if recorder.no_val else ["train_loss", "valid_loss"]
-        )
-        PatchFastai.__metrics_names += recorder.metrics_names
+        if not PatchFastai.__main_task:
+            return
+        # noinspection PyBroadException
+        try:
+            PatchFastai.__metrics_names = ["train_loss"] if recorder.no_val else ["train_loss", "valid_loss"]
+            PatchFastai.__metrics_names += recorder.metrics_names
+        except Exception as ex:
+            pass
 
     @staticmethod
     def _on_backward_end(original_fn, recorder, *args, **kwargs):
-        def report_model_stats(series, value):
-            logger.report_scalar("model_stats_gradients", series, value, iteration)
+        def count_zeros(gradient):
+            n = gradient.data.data.cpu().numpy()
+            return n.size - n.count_nonzero()
 
         original_fn(recorder, *args, **kwargs)
-        gradients = [
-            x.grad.clone().detach().cpu()
-            for x in recorder.learn.model.parameters()
-            if x.grad is not None
-        ]
-        if len(gradients) == 0:
+
+        if not PatchFastai.__main_task:
             return
-        iteration = kwargs.get("iteration")
-        norms = [x.data.norm() for x in gradients]
-        logger = PatchFastai.__main_task.get_logger()
-        for name, val in zip(
-            [
-                "avg_norm",
-                "median_norm",
-                "max_norm",
-                "min_norm",
-                "num_zeros",
-                "avg_gradient",
-                "median_gradient",
-                "max_gradient",
-                "min_gradient",
-            ],
-            [
-                sum(norms) / len(gradients),
-                statistics.median(norms),
-                max(norms),
-                min(norms),
-                sum(
-                    (np.asarray(x) == 0.0).sum()
-                    for x in [x.data.data.cpu().numpy() for x in gradients]
-                ),
-                sum(x.data.mean() for x in gradients) / len(gradients),
-                statistics.median(x.data.median() for x in gradients),
-                max(x.data.max() for x in gradients),
-                min(x.data.min() for x in gradients),
-            ],
-        ):
-            report_model_stats(name, val)
+
+        # noinspection PyBroadException
+        try:
+            gradients = [
+                x.grad.clone().detach().cpu() for x in recorder.learn.model.parameters() if x.grad is not None
+            ]
+            if len(gradients) == 0:
+                return
+
+            # TODO: Check computation!
+            gradient_stats = np.array([
+                (x.data.norm(), count_zeros(x), x.data.mean(), x.data.median(), x.data.max(), x.data.min())
+                for x in gradients])
+            stats_report = dict(
+                avg_norm=np.mean(gradient_stats[:, 0]),
+                median_norm=np.median(gradient_stats[:, 0]),
+                max_norm=np.max(gradient_stats[:, 0]),
+                min_norm=np.min(gradient_stats[:, 0]),
+                num_zeros=gradient_stats[:, 1].sum(),
+                avg_gradient=gradient_stats[:, 2].mean(),
+                median_gradient=gradient_stats[:, 3].median(),
+                max_gradient=gradient_stats[:, 4].max(),
+                min_gradient=gradient_stats[:, 5].min(),
+            )
+
+            logger = PatchFastai.__main_task.get_logger()
+            iteration = kwargs.get("iteration", 0)
+            for name, val in stats_report.items():
+                logger.report_scalar(title="model_stats_gradients", series=name, value=val, iteration=iteration)
+        except Exception as ex:
+            pass
 
     @staticmethod
     def _on_epoch_end(original_fn, recorder, *args, **kwargs):
         original_fn(recorder, *args, **kwargs)
-        logger = PatchFastai.__main_task.get_logger()
-        iteration = kwargs.get("iteration")
-        for series, value in zip(
-            PatchFastai.__metrics_names,
-            [kwargs.get("smooth_loss")] + kwargs.get("last_metrics", []),
-        ):
-            logger.report_scalar("metrics", series, value, iteration)
-        PatchFastai.__main_task.flush()
+        if not PatchFastai.__main_task:
+            return
+
+        # noinspection PyBroadException
+        try:
+            logger = PatchFastai.__main_task.get_logger()
+            iteration = kwargs.get("iteration")
+            for series, value in zip(
+                    PatchFastai.__metrics_names,
+                    [kwargs.get("smooth_loss")] + kwargs.get("last_metrics", []),
+            ):
+                logger.report_scalar(title="metrics", series=series, value=value, iteration=iteration)
+            PatchFastai.__main_task.flush()
+        except Exception:
+            pass
 
     @staticmethod
     def _on_batch_end(original_fn, recorder, *args, **kwargs):
         original_fn(recorder, *args, **kwargs)
-        if kwargs.get("iteration") == 0 or not kwargs.get("train"):
+        if not PatchFastai.__main_task:
             return
-        logger = PatchFastai.__main_task.get_logger()
-        logger.report_scalar(
-            "metrics", "train_loss", kwargs.get("last_loss"), kwargs.get("iteration")
-        )
-        gradient_hist_helper = WeightsGradientHistHelper(logger)
-        iteration = kwargs.get("iteration")
-        params = [
-            (name, values.clone().detach().cpu())
-            for (name, values) in recorder.model.named_parameters()
-        ]
-        for (name, values) in params:
-            gradient_hist_helper.add_histogram(
-                title="model_weights",
-                series="model_weights/" + name,
-                step=iteration,
-                hist_data=values,
-            )
+
+        # noinspection PyBroadException
+        try:
+            if kwargs.get("iteration") == 0 or not kwargs.get("train"):
+                return
+
+            logger = PatchFastai.__main_task.get_logger()
+            logger.report_scalar(
+                title="metrics",
+                series="train_loss",
+                value=kwargs.get("last_loss", 0),
+                iteration=kwargs.get("iteration", 0)
+            )
+            gradient_hist_helper = WeightsGradientHistHelper(logger)
+            iteration = kwargs.get("iteration")
+            params = [
+                (name, values.clone().detach().cpu())
+                for (name, values) in recorder.model.named_parameters()
+            ]
+            for (name, values) in params:
+                gradient_hist_helper.add_histogram(
+                    title="model_weights",
+                    series="model_weights/" + name,
+                    step=iteration,
+                    hist_data=values,
+                )
+        except Exception:
+            pass