mirror of
https://github.com/clearml/clearml
synced 2025-04-03 20:41:07 +00:00
Improve pytorch model saving better, add support for mmvc
This commit is contained in:
parent
3e5d50e15d
commit
63e7cbab30
@ -44,6 +44,13 @@ def _patched_call(original_fn, patched_fn):
|
||||
return _inner_patch
|
||||
|
||||
|
||||
def _patched_call_no_recursion_guard(original_fn, patched_fn):
|
||||
def _inner_patch(*args, **kwargs):
|
||||
return patched_fn(original_fn, *args, **kwargs)
|
||||
|
||||
return _inner_patch
|
||||
|
||||
|
||||
class _Empty(object):
|
||||
def __init__(self):
|
||||
self.trains_in_model = None
|
||||
@ -159,7 +166,7 @@ class WeightsFileHandler(object):
|
||||
If the callback was already added, return the existing handle.
|
||||
|
||||
:param handle: A callback handle returned from :meth:WeightsFileHandler.add_post_callback
|
||||
:return True if callback removed, False otherwise
|
||||
:return: True if callback removed, False otherwise
|
||||
"""
|
||||
return cls._remove_callback(handle, cls._model_post_callbacks)
|
||||
|
||||
|
@ -1,10 +1,12 @@
|
||||
import sys
|
||||
|
||||
import six
|
||||
import threading
|
||||
|
||||
from pathlib2 import Path
|
||||
|
||||
from ...binding.frameworks.base_bind import PatchBaseModelIO
|
||||
from ..frameworks import _patched_call, WeightsFileHandler, _Empty
|
||||
from ..frameworks import _patched_call, _patched_call_no_recursion_guard, WeightsFileHandler, _Empty
|
||||
from ..import_bind import PostImportHookPatching
|
||||
from ...config import running_remotely
|
||||
from ...model import Framework
|
||||
@ -12,8 +14,11 @@ from ...model import Framework
|
||||
|
||||
class PatchPyTorchModelIO(PatchBaseModelIO):
|
||||
_current_task = None
|
||||
_checkpoint_filename = {}
|
||||
__patched = None
|
||||
__patched_lightning = None
|
||||
__patched_mmcv = None
|
||||
__default_checkpoint_filename_counter = {}
|
||||
|
||||
@staticmethod
|
||||
def update_current_task(task, **_):
|
||||
@ -22,6 +27,7 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
|
||||
return
|
||||
PatchPyTorchModelIO._patch_model_io()
|
||||
PatchPyTorchModelIO._patch_lightning_io()
|
||||
PatchPyTorchModelIO._patch_mmcv()
|
||||
PostImportHookPatching.add_on_import('torch', PatchPyTorchModelIO._patch_model_io)
|
||||
PostImportHookPatching.add_on_import('pytorch_lightning', PatchPyTorchModelIO._patch_lightning_io)
|
||||
|
||||
@ -65,6 +71,41 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
|
||||
except Exception:
|
||||
pass # print('Failed patching pytorch')
|
||||
|
||||
@staticmethod
|
||||
def _patch_mmcv():
|
||||
if PatchPyTorchModelIO.__patched_mmcv:
|
||||
return
|
||||
if "mmcv" not in sys.modules:
|
||||
return
|
||||
PatchPyTorchModelIO.__patched_mmcv = True
|
||||
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
from mmcv.runner import epoch_based_runner, iter_based_runner
|
||||
|
||||
# we don't want the recursion check here because it guards pytorch's patched save functions
|
||||
# which we need in order to log the saved model/checkpoint
|
||||
epoch_based_runner.save_checkpoint = _patched_call_no_recursion_guard(
|
||||
epoch_based_runner.save_checkpoint, PatchPyTorchModelIO._mmcv_save_checkpoint
|
||||
)
|
||||
iter_based_runner.save_checkpoint = _patched_call_no_recursion_guard(
|
||||
iter_based_runner.save_checkpoint, PatchPyTorchModelIO._mmcv_save_checkpoint
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def _mmcv_save_checkpoint(original_fn, model, filename, *args, **kwargs):
|
||||
# note that mmcv.runner.save_checkpoint doesn't return anything, hence the need for this
|
||||
# patch function, but we return from it just in case this changes in the future
|
||||
if not PatchPyTorchModelIO._current_task:
|
||||
return original_fn(model, filename, *args, **kwargs)
|
||||
tid = threading.current_thread().ident
|
||||
PatchPyTorchModelIO._checkpoint_filename[tid] = filename
|
||||
ret = original_fn(model, filename, *args, **kwargs)
|
||||
del PatchPyTorchModelIO._checkpoint_filename[tid]
|
||||
return ret
|
||||
|
||||
@staticmethod
|
||||
def _patch_lightning_io():
|
||||
if PatchPyTorchModelIO.__patched_lightning:
|
||||
@ -144,9 +185,9 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
|
||||
|
||||
filename = f.name
|
||||
else:
|
||||
filename = None
|
||||
filename = PatchPyTorchModelIO.__create_default_filename()
|
||||
except Exception:
|
||||
filename = None
|
||||
filename = PatchPyTorchModelIO.__create_default_filename()
|
||||
|
||||
# give the model a descriptive name based on the file name
|
||||
# noinspection PyBroadException
|
||||
@ -241,3 +282,13 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
|
||||
pass
|
||||
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
def __create_default_filename():
|
||||
tid = threading.current_thread().ident
|
||||
checkpoint_filename = PatchPyTorchModelIO._checkpoint_filename.get(tid)
|
||||
if checkpoint_filename:
|
||||
return checkpoint_filename
|
||||
counter = PatchPyTorchModelIO.__default_checkpoint_filename_counter.setdefault(tid, 0)
|
||||
PatchPyTorchModelIO.__default_checkpoint_filename_counter[tid] += 1
|
||||
return "default_{}_{}".format(tid, counter)
|
||||
|
Loading…
Reference in New Issue
Block a user