mirror of
				https://github.com/clearml/clearml
				synced 2025-06-26 18:16:07 +00:00 
			
		
		
		
	Improve PyTorch model saving, add support for mmcv
This commit is contained in:
		
							parent
							
								
									3e5d50e15d
								
							
						
					
					
						commit
						63e7cbab30
					
				| @ -44,6 +44,13 @@ def _patched_call(original_fn, patched_fn): | ||||
|     return _inner_patch | ||||
| 
 | ||||
| 
 | ||||
| def _patched_call_no_recursion_guard(original_fn, patched_fn): | ||||
|     def _inner_patch(*args, **kwargs): | ||||
|         return patched_fn(original_fn, *args, **kwargs) | ||||
| 
 | ||||
|     return _inner_patch | ||||
| 
 | ||||
| 
 | ||||
| class _Empty(object): | ||||
|     def __init__(self): | ||||
|         self.trains_in_model = None | ||||
| @ -159,7 +166,7 @@ class WeightsFileHandler(object): | ||||
|         If the callback was already added, return the existing handle. | ||||
| 
 | ||||
|         :param handle: A callback handle returned from :meth:WeightsFileHandler.add_post_callback | ||||
|         :return True if callback removed, False otherwise | ||||
|         :return: True if callback removed, False otherwise | ||||
|         """ | ||||
|         return cls._remove_callback(handle, cls._model_post_callbacks) | ||||
| 
 | ||||
|  | ||||
| @ -1,10 +1,12 @@ | ||||
| import sys | ||||
| 
 | ||||
| import six | ||||
| import threading | ||||
| 
 | ||||
| from pathlib2 import Path | ||||
| 
 | ||||
| from ...binding.frameworks.base_bind import PatchBaseModelIO | ||||
| from ..frameworks import _patched_call, WeightsFileHandler, _Empty | ||||
| from ..frameworks import _patched_call, _patched_call_no_recursion_guard, WeightsFileHandler, _Empty | ||||
| from ..import_bind import PostImportHookPatching | ||||
| from ...config import running_remotely | ||||
| from ...model import Framework | ||||
| @ -12,8 +14,11 @@ from ...model import Framework | ||||
| 
 | ||||
| class PatchPyTorchModelIO(PatchBaseModelIO): | ||||
|     _current_task = None | ||||
|     _checkpoint_filename = {} | ||||
|     __patched = None | ||||
|     __patched_lightning = None | ||||
|     __patched_mmcv = None | ||||
|     __default_checkpoint_filename_counter = {} | ||||
| 
 | ||||
|     @staticmethod | ||||
|     def update_current_task(task, **_): | ||||
| @ -22,6 +27,7 @@ class PatchPyTorchModelIO(PatchBaseModelIO): | ||||
|             return | ||||
|         PatchPyTorchModelIO._patch_model_io() | ||||
|         PatchPyTorchModelIO._patch_lightning_io() | ||||
|         PatchPyTorchModelIO._patch_mmcv() | ||||
|         PostImportHookPatching.add_on_import('torch', PatchPyTorchModelIO._patch_model_io) | ||||
|         PostImportHookPatching.add_on_import('pytorch_lightning', PatchPyTorchModelIO._patch_lightning_io) | ||||
| 
 | ||||
| @ -65,6 +71,41 @@ class PatchPyTorchModelIO(PatchBaseModelIO): | ||||
|         except Exception: | ||||
|             pass  # print('Failed patching pytorch') | ||||
| 
 | ||||
|     @staticmethod | ||||
|     def _patch_mmcv(): | ||||
|         if PatchPyTorchModelIO.__patched_mmcv: | ||||
|             return | ||||
|         if "mmcv" not in sys.modules: | ||||
|             return | ||||
|         PatchPyTorchModelIO.__patched_mmcv = True | ||||
| 
 | ||||
|         # noinspection PyBroadException | ||||
|         try: | ||||
|             from mmcv.runner import epoch_based_runner, iter_based_runner | ||||
| 
 | ||||
|             # we don't want the recursion check here because it guards pytorch's patched save functions | ||||
|             # which we need in order to log the saved model/checkpoint | ||||
|             epoch_based_runner.save_checkpoint = _patched_call_no_recursion_guard( | ||||
|                 epoch_based_runner.save_checkpoint, PatchPyTorchModelIO._mmcv_save_checkpoint | ||||
|             ) | ||||
|             iter_based_runner.save_checkpoint = _patched_call_no_recursion_guard( | ||||
|                 iter_based_runner.save_checkpoint, PatchPyTorchModelIO._mmcv_save_checkpoint | ||||
|             ) | ||||
|         except Exception: | ||||
|             pass | ||||
| 
 | ||||
|     @staticmethod | ||||
|     def _mmcv_save_checkpoint(original_fn, model, filename, *args, **kwargs): | ||||
|         # note that mmcv.runner.save_checkpoint doesn't return anything, hence the need for this | ||||
|         # patch function, but we return from it just in case this changes in the future | ||||
|         if not PatchPyTorchModelIO._current_task: | ||||
|             return original_fn(model, filename, *args, **kwargs) | ||||
|         tid = threading.current_thread().ident | ||||
|         PatchPyTorchModelIO._checkpoint_filename[tid] = filename | ||||
|         ret = original_fn(model, filename, *args, **kwargs) | ||||
|         del PatchPyTorchModelIO._checkpoint_filename[tid] | ||||
|         return ret | ||||
| 
 | ||||
|     @staticmethod | ||||
|     def _patch_lightning_io(): | ||||
|         if PatchPyTorchModelIO.__patched_lightning: | ||||
| @ -144,9 +185,9 @@ class PatchPyTorchModelIO(PatchBaseModelIO): | ||||
| 
 | ||||
|                 filename = f.name | ||||
|             else: | ||||
|                 filename = None | ||||
|                 filename = PatchPyTorchModelIO.__create_default_filename() | ||||
|         except Exception: | ||||
|             filename = None | ||||
|             filename = PatchPyTorchModelIO.__create_default_filename() | ||||
| 
 | ||||
|         # give the model a descriptive name based on the file name | ||||
|         # noinspection PyBroadException | ||||
| @ -241,3 +282,13 @@ class PatchPyTorchModelIO(PatchBaseModelIO): | ||||
|                 pass | ||||
| 
 | ||||
|         return model | ||||
| 
 | ||||
|     @staticmethod | ||||
|     def __create_default_filename(): | ||||
|         tid = threading.current_thread().ident | ||||
|         checkpoint_filename = PatchPyTorchModelIO._checkpoint_filename.get(tid) | ||||
|         if checkpoint_filename: | ||||
|             return checkpoint_filename | ||||
|         counter = PatchPyTorchModelIO.__default_checkpoint_filename_counter.setdefault(tid, 0) | ||||
|         PatchPyTorchModelIO.__default_checkpoint_filename_counter[tid] += 1 | ||||
|         return "default_{}_{}".format(tid, counter) | ||||
|  | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user
	 allegroai
						allegroai