mirror of
https://github.com/clearml/clearml
synced 2025-04-14 20:42:55 +00:00
Improve pytorch model saving better, add support for mmvc
This commit is contained in:
parent
3e5d50e15d
commit
63e7cbab30
@ -44,6 +44,13 @@ def _patched_call(original_fn, patched_fn):
|
|||||||
return _inner_patch
|
return _inner_patch
|
||||||
|
|
||||||
|
|
||||||
|
def _patched_call_no_recursion_guard(original_fn, patched_fn):
|
||||||
|
def _inner_patch(*args, **kwargs):
|
||||||
|
return patched_fn(original_fn, *args, **kwargs)
|
||||||
|
|
||||||
|
return _inner_patch
|
||||||
|
|
||||||
|
|
||||||
class _Empty(object):
|
class _Empty(object):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.trains_in_model = None
|
self.trains_in_model = None
|
||||||
@ -159,7 +166,7 @@ class WeightsFileHandler(object):
|
|||||||
If the callback was already added, return the existing handle.
|
If the callback was already added, return the existing handle.
|
||||||
|
|
||||||
:param handle: A callback handle returned from :meth:WeightsFileHandler.add_post_callback
|
:param handle: A callback handle returned from :meth:WeightsFileHandler.add_post_callback
|
||||||
:return True if callback removed, False otherwise
|
:return: True if callback removed, False otherwise
|
||||||
"""
|
"""
|
||||||
return cls._remove_callback(handle, cls._model_post_callbacks)
|
return cls._remove_callback(handle, cls._model_post_callbacks)
|
||||||
|
|
||||||
|
@ -1,10 +1,12 @@
|
|||||||
import sys
|
import sys
|
||||||
|
|
||||||
import six
|
import six
|
||||||
|
import threading
|
||||||
|
|
||||||
from pathlib2 import Path
|
from pathlib2 import Path
|
||||||
|
|
||||||
from ...binding.frameworks.base_bind import PatchBaseModelIO
|
from ...binding.frameworks.base_bind import PatchBaseModelIO
|
||||||
from ..frameworks import _patched_call, WeightsFileHandler, _Empty
|
from ..frameworks import _patched_call, _patched_call_no_recursion_guard, WeightsFileHandler, _Empty
|
||||||
from ..import_bind import PostImportHookPatching
|
from ..import_bind import PostImportHookPatching
|
||||||
from ...config import running_remotely
|
from ...config import running_remotely
|
||||||
from ...model import Framework
|
from ...model import Framework
|
||||||
@ -12,8 +14,11 @@ from ...model import Framework
|
|||||||
|
|
||||||
class PatchPyTorchModelIO(PatchBaseModelIO):
|
class PatchPyTorchModelIO(PatchBaseModelIO):
|
||||||
_current_task = None
|
_current_task = None
|
||||||
|
_checkpoint_filename = {}
|
||||||
__patched = None
|
__patched = None
|
||||||
__patched_lightning = None
|
__patched_lightning = None
|
||||||
|
__patched_mmcv = None
|
||||||
|
__default_checkpoint_filename_counter = {}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def update_current_task(task, **_):
|
def update_current_task(task, **_):
|
||||||
@ -22,6 +27,7 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
|
|||||||
return
|
return
|
||||||
PatchPyTorchModelIO._patch_model_io()
|
PatchPyTorchModelIO._patch_model_io()
|
||||||
PatchPyTorchModelIO._patch_lightning_io()
|
PatchPyTorchModelIO._patch_lightning_io()
|
||||||
|
PatchPyTorchModelIO._patch_mmcv()
|
||||||
PostImportHookPatching.add_on_import('torch', PatchPyTorchModelIO._patch_model_io)
|
PostImportHookPatching.add_on_import('torch', PatchPyTorchModelIO._patch_model_io)
|
||||||
PostImportHookPatching.add_on_import('pytorch_lightning', PatchPyTorchModelIO._patch_lightning_io)
|
PostImportHookPatching.add_on_import('pytorch_lightning', PatchPyTorchModelIO._patch_lightning_io)
|
||||||
|
|
||||||
@ -65,6 +71,41 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
|
|||||||
except Exception:
|
except Exception:
|
||||||
pass # print('Failed patching pytorch')
|
pass # print('Failed patching pytorch')
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _patch_mmcv():
|
||||||
|
if PatchPyTorchModelIO.__patched_mmcv:
|
||||||
|
return
|
||||||
|
if "mmcv" not in sys.modules:
|
||||||
|
return
|
||||||
|
PatchPyTorchModelIO.__patched_mmcv = True
|
||||||
|
|
||||||
|
# noinspection PyBroadException
|
||||||
|
try:
|
||||||
|
from mmcv.runner import epoch_based_runner, iter_based_runner
|
||||||
|
|
||||||
|
# we don't want the recursion check here because it guards pytorch's patched save functions
|
||||||
|
# which we need in order to log the saved model/checkpoint
|
||||||
|
epoch_based_runner.save_checkpoint = _patched_call_no_recursion_guard(
|
||||||
|
epoch_based_runner.save_checkpoint, PatchPyTorchModelIO._mmcv_save_checkpoint
|
||||||
|
)
|
||||||
|
iter_based_runner.save_checkpoint = _patched_call_no_recursion_guard(
|
||||||
|
iter_based_runner.save_checkpoint, PatchPyTorchModelIO._mmcv_save_checkpoint
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _mmcv_save_checkpoint(original_fn, model, filename, *args, **kwargs):
|
||||||
|
# note that mmcv.runner.save_checkpoint doesn't return anything, hence the need for this
|
||||||
|
# patch function, but we return from it just in case this changes in the future
|
||||||
|
if not PatchPyTorchModelIO._current_task:
|
||||||
|
return original_fn(model, filename, *args, **kwargs)
|
||||||
|
tid = threading.current_thread().ident
|
||||||
|
PatchPyTorchModelIO._checkpoint_filename[tid] = filename
|
||||||
|
ret = original_fn(model, filename, *args, **kwargs)
|
||||||
|
del PatchPyTorchModelIO._checkpoint_filename[tid]
|
||||||
|
return ret
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _patch_lightning_io():
|
def _patch_lightning_io():
|
||||||
if PatchPyTorchModelIO.__patched_lightning:
|
if PatchPyTorchModelIO.__patched_lightning:
|
||||||
@ -144,9 +185,9 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
|
|||||||
|
|
||||||
filename = f.name
|
filename = f.name
|
||||||
else:
|
else:
|
||||||
filename = None
|
filename = PatchPyTorchModelIO.__create_default_filename()
|
||||||
except Exception:
|
except Exception:
|
||||||
filename = None
|
filename = PatchPyTorchModelIO.__create_default_filename()
|
||||||
|
|
||||||
# give the model a descriptive name based on the file name
|
# give the model a descriptive name based on the file name
|
||||||
# noinspection PyBroadException
|
# noinspection PyBroadException
|
||||||
@ -241,3 +282,13 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def __create_default_filename():
|
||||||
|
tid = threading.current_thread().ident
|
||||||
|
checkpoint_filename = PatchPyTorchModelIO._checkpoint_filename.get(tid)
|
||||||
|
if checkpoint_filename:
|
||||||
|
return checkpoint_filename
|
||||||
|
counter = PatchPyTorchModelIO.__default_checkpoint_filename_counter.setdefault(tid, 0)
|
||||||
|
PatchPyTorchModelIO.__default_checkpoint_filename_counter[tid] += 1
|
||||||
|
return "default_{}_{}".format(tid, counter)
|
||||||
|
Loading…
Reference in New Issue
Block a user