Improve pytorch model saving better, add support for mmvc

This commit is contained in:
allegroai 2022-09-15 15:59:34 +03:00
parent 3e5d50e15d
commit 63e7cbab30
2 changed files with 62 additions and 4 deletions

View File

@ -44,6 +44,13 @@ def _patched_call(original_fn, patched_fn):
return _inner_patch
def _patched_call_no_recursion_guard(original_fn, patched_fn):
def _inner_patch(*args, **kwargs):
return patched_fn(original_fn, *args, **kwargs)
return _inner_patch
class _Empty(object):
def __init__(self):
self.trains_in_model = None
@ -159,7 +166,7 @@ class WeightsFileHandler(object):
If the callback was already added, return the existing handle.
:param handle: A callback handle returned from :meth:WeightsFileHandler.add_post_callback
:return True if callback removed, False otherwise
:return: True if callback removed, False otherwise
"""
return cls._remove_callback(handle, cls._model_post_callbacks)

View File

@ -1,10 +1,12 @@
import sys
import six
import threading
from pathlib2 import Path
from ...binding.frameworks.base_bind import PatchBaseModelIO
from ..frameworks import _patched_call, WeightsFileHandler, _Empty
from ..frameworks import _patched_call, _patched_call_no_recursion_guard, WeightsFileHandler, _Empty
from ..import_bind import PostImportHookPatching
from ...config import running_remotely
from ...model import Framework
@ -12,8 +14,11 @@ from ...model import Framework
class PatchPyTorchModelIO(PatchBaseModelIO):
_current_task = None
_checkpoint_filename = {}
__patched = None
__patched_lightning = None
__patched_mmcv = None
__default_checkpoint_filename_counter = {}
@staticmethod
def update_current_task(task, **_):
@ -22,6 +27,7 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
return
PatchPyTorchModelIO._patch_model_io()
PatchPyTorchModelIO._patch_lightning_io()
PatchPyTorchModelIO._patch_mmcv()
PostImportHookPatching.add_on_import('torch', PatchPyTorchModelIO._patch_model_io)
PostImportHookPatching.add_on_import('pytorch_lightning', PatchPyTorchModelIO._patch_lightning_io)
@ -65,6 +71,41 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
except Exception:
pass # print('Failed patching pytorch')
@staticmethod
def _patch_mmcv():
if PatchPyTorchModelIO.__patched_mmcv:
return
if "mmcv" not in sys.modules:
return
PatchPyTorchModelIO.__patched_mmcv = True
# noinspection PyBroadException
try:
from mmcv.runner import epoch_based_runner, iter_based_runner
# we don't want the recursion check here because it guards pytorch's patched save functions
# which we need in order to log the saved model/checkpoint
epoch_based_runner.save_checkpoint = _patched_call_no_recursion_guard(
epoch_based_runner.save_checkpoint, PatchPyTorchModelIO._mmcv_save_checkpoint
)
iter_based_runner.save_checkpoint = _patched_call_no_recursion_guard(
iter_based_runner.save_checkpoint, PatchPyTorchModelIO._mmcv_save_checkpoint
)
except Exception:
pass
@staticmethod
def _mmcv_save_checkpoint(original_fn, model, filename, *args, **kwargs):
# note that mmcv.runner.save_checkpoint doesn't return anything, hence the need for this
# patch function, but we return from it just in case this changes in the future
if not PatchPyTorchModelIO._current_task:
return original_fn(model, filename, *args, **kwargs)
tid = threading.current_thread().ident
PatchPyTorchModelIO._checkpoint_filename[tid] = filename
ret = original_fn(model, filename, *args, **kwargs)
del PatchPyTorchModelIO._checkpoint_filename[tid]
return ret
@staticmethod
def _patch_lightning_io():
if PatchPyTorchModelIO.__patched_lightning:
@ -144,9 +185,9 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
filename = f.name
else:
filename = None
filename = PatchPyTorchModelIO.__create_default_filename()
except Exception:
filename = None
filename = PatchPyTorchModelIO.__create_default_filename()
# give the model a descriptive name based on the file name
# noinspection PyBroadException
@ -241,3 +282,13 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
pass
return model
@staticmethod
def __create_default_filename():
tid = threading.current_thread().ident
checkpoint_filename = PatchPyTorchModelIO._checkpoint_filename.get(tid)
if checkpoint_filename:
return checkpoint_filename
counter = PatchPyTorchModelIO.__default_checkpoint_filename_counter.setdefault(tid, 0)
PatchPyTorchModelIO.__default_checkpoint_filename_counter[tid] += 1
return "default_{}_{}".format(tid, counter)