mirror of
https://github.com/clearml/clearml
synced 2025-03-06 03:59:09 +00:00
Add PyTorch Lightning save/restore model binding (issue #212)
This commit is contained in:
parent
1a39973cb9
commit
64e10b2f62
@ -13,12 +13,15 @@ from ...model import Framework
|
|||||||
class PatchPyTorchModelIO(PatchBaseModelIO):
|
class PatchPyTorchModelIO(PatchBaseModelIO):
|
||||||
__main_task = None
|
__main_task = None
|
||||||
__patched = None
|
__patched = None
|
||||||
|
__patched_lightning = None
|
||||||
|
|
||||||
@staticmethod
def update_current_task(task, **_):
    """
    Record ``task`` as the main task and install the PyTorch / PyTorch Lightning IO patches.

    Both patch routines are invoked right away and are additionally registered as
    post-import hooks for the 'torch' and 'pytorch_lightning' modules.

    :param task: Object stored as the class-level main task.
    :param _: Any extra keyword arguments are accepted and ignored.
    """
    patcher = PatchPyTorchModelIO
    patcher.__main_task = task

    # attempt the patching immediately
    patcher._patch_model_io()
    patcher._patch_lightning_io()

    # and hook the same routines to the import of the relevant modules
    import_hooks = (
        ('torch', patcher._patch_model_io),
        ('pytorch_lightning', patcher._patch_lightning_io),
    )
    for module_name, hook in import_hooks:
        PostImportHookPatching.add_on_import(module_name, hook)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _patch_model_io():
|
def _patch_model_io():
|
||||||
@ -32,28 +35,72 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
|
|||||||
|
|
||||||
# noinspection PyBroadException
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
import torch
|
import torch # noqa
|
||||||
torch.save = _patched_call(torch.save, PatchPyTorchModelIO._save)
|
torch.save = _patched_call(torch.save, PatchPyTorchModelIO._save)
|
||||||
torch.load = _patched_call(torch.load, PatchPyTorchModelIO._load)
|
torch.load = _patched_call(torch.load, PatchPyTorchModelIO._load)
|
||||||
|
|
||||||
# no need to worry about recursive calls, _patched_call takes care of that
|
# no need to worry about recursive calls, _patched_call takes care of that
|
||||||
if hasattr(torch, 'serialization') and hasattr(torch.serialization, '_save'):
|
if hasattr(torch, 'serialization') and hasattr(torch.serialization, '_save'):
|
||||||
torch.serialization._save = _patched_call(
|
torch.serialization._save = _patched_call(
|
||||||
torch.serialization._save, PatchPyTorchModelIO._save)
|
torch.serialization._save, PatchPyTorchModelIO._save) # noqa
|
||||||
if hasattr(torch, 'serialization') and hasattr(torch.serialization, '_load'):
|
if hasattr(torch, 'serialization') and hasattr(torch.serialization, '_load'):
|
||||||
torch.serialization._load = _patched_call(
|
torch.serialization._load = _patched_call(
|
||||||
torch.serialization._load, PatchPyTorchModelIO._load)
|
torch.serialization._load, PatchPyTorchModelIO._load) # noqa
|
||||||
if hasattr(torch, 'serialization') and hasattr(torch.serialization, '_legacy_save'):
|
if hasattr(torch, 'serialization') and hasattr(torch.serialization, '_legacy_save'):
|
||||||
torch.serialization._legacy_save = _patched_call(
|
torch.serialization._legacy_save = _patched_call(
|
||||||
torch.serialization._legacy_save, PatchPyTorchModelIO._save)
|
torch.serialization._legacy_save, PatchPyTorchModelIO._save) # noqa
|
||||||
if hasattr(torch, 'serialization') and hasattr(torch.serialization, '_legacy_load'):
|
if hasattr(torch, 'serialization') and hasattr(torch.serialization, '_legacy_load'):
|
||||||
torch.serialization._legacy_load = _patched_call(
|
torch.serialization._legacy_load = _patched_call(
|
||||||
torch.serialization._legacy_load, PatchPyTorchModelIO._load)
|
torch.serialization._legacy_load, PatchPyTorchModelIO._load) # noqa
|
||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
except Exception:
|
except Exception:
|
||||||
pass # print('Failed patching pytorch')
|
pass # print('Failed patching pytorch')
|
||||||
|
|
||||||
|
@staticmethod
def _patch_lightning_io():
    """
    Wrap the PyTorch Lightning checkpoint save/restore entry points.

    ``save_checkpoint`` is wrapped with ``PatchPyTorchModelIO._save`` and ``restore``
    with ``PatchPyTorchModelIO._load_from_obj``, on both ``Trainer`` and
    ``CheckpointConnector``. Does nothing if the patch was already applied or if
    'pytorch_lightning' is not present in ``sys.modules``.

    Patching is best-effort: any exception raised while patching one of the two
    classes is swallowed, and the remaining assignments for that class are skipped.
    """
    if PatchPyTorchModelIO.__patched_lightning or 'pytorch_lightning' not in sys.modules:
        return

    # flag is raised before patching, so a failed attempt is not retried
    PatchPyTorchModelIO.__patched_lightning = True

    save_hook = PatchPyTorchModelIO._save
    restore_hook = PatchPyTorchModelIO._load_from_obj

    # Trainer level save/restore
    # noinspection PyBroadException
    try:
        import pytorch_lightning  # noqa

        trainer_cls = pytorch_lightning.trainer.Trainer
        trainer_cls.save_checkpoint = _patched_call(trainer_cls.save_checkpoint, save_hook)  # noqa
        trainer_cls.restore = _patched_call(trainer_cls.restore, restore_hook)  # noqa
    except Exception:
        pass

    # CheckpointConnector level save/restore
    # NOTE(review): presumably only some pytorch_lightning releases expose this class - confirm
    # noinspection PyBroadException
    try:
        import pytorch_lightning  # noqa

        # noinspection PyUnresolvedReferences
        connector_cls = pytorch_lightning.trainer.connectors.checkpoint_connector.CheckpointConnector
        connector_cls.save_checkpoint = _patched_call(connector_cls.save_checkpoint, save_hook)  # noqa
        connector_cls.restore = _patched_call(connector_cls.restore, restore_hook)  # noqa
    except Exception:
        pass
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _save(original_fn, obj, f, *args, **kwargs):
|
def _save(original_fn, obj, f, *args, **kwargs):
|
||||||
ret = original_fn(obj, f, *args, **kwargs)
|
ret = original_fn(obj, f, *args, **kwargs)
|
||||||
@ -136,3 +183,44 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
@staticmethod
def _load_from_obj(original_fn, obj, f, *args, **kwargs):
    """
    Load hook for bound restore methods: call the original, then register the weights file.

    :param original_fn: The wrapped (original) callable.
    :param obj: First positional argument, forwarded untouched to ``original_fn``
        (presumably the instance the patched method is bound to - verify against caller).
    :param f: Checkpoint source. A string is used as the filename directly; otherwise
        ``f.as_posix()`` or ``f.name`` is used when available.
    :param args: Extra positional arguments forwarded to ``original_fn``.
    :param kwargs: Extra keyword arguments forwarded to ``original_fn``.
    :return: Whatever ``original_fn`` returns.
    """
    main_task = PatchPyTorchModelIO.__main_task
    # without a main task there is nothing to register, just call through
    if not main_task:
        return original_fn(obj, f, *args, **kwargs)

    # best effort: work out a file name from whatever we were handed
    filename = None
    # noinspection PyBroadException
    try:
        if isinstance(f, six.string_types):
            filename = f
        elif hasattr(f, 'as_posix'):
            filename = f.as_posix()
        elif hasattr(f, 'name'):
            filename = f.name
    except Exception:
        filename = None

    # register input model
    empty = _Empty()
    # Hack: disabled
    restore_before_load = False and running_remotely()
    if restore_before_load:
        filename = WeightsFileHandler.restore_weights_file(
            empty, filename, Framework.pytorch, main_task)
        model = original_fn(obj, filename or f, *args, **kwargs)
    else:
        # load first and register afterwards, so a failed load registers nothing
        model = original_fn(obj, f, *args, **kwargs)
        WeightsFileHandler.restore_weights_file(
            empty, filename, Framework.pytorch, main_task)

    registered = empty.trains_in_model
    if registered:
        # the returned object may not accept new attributes, ignore if so
        # noinspection PyBroadException
        try:
            model.trains_in_model = registered
        except Exception:
            pass

    return model
|
||||||
|
Loading…
Reference in New Issue
Block a user