mirror of
https://github.com/clearml/clearml
synced 2025-03-03 10:42:00 +00:00
Add support for new lightning namespace (#1033)
This commit is contained in:
parent
42320421a2
commit
ec51f35ec7
@ -17,6 +17,7 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
|
||||
_checkpoint_filename = {}
|
||||
__patched = None
|
||||
__patched_lightning = None
|
||||
__patched_pytorch_lightning = None
|
||||
__patched_mmcv = None
|
||||
|
||||
@staticmethod
|
||||
@ -26,9 +27,11 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
|
||||
return
|
||||
PatchPyTorchModelIO._patch_model_io()
|
||||
PatchPyTorchModelIO._patch_lightning_io()
|
||||
PatchPyTorchModelIO._patch_pytorch_lightning_io()
|
||||
PatchPyTorchModelIO._patch_mmcv()
|
||||
PostImportHookPatching.add_on_import('torch', PatchPyTorchModelIO._patch_model_io)
|
||||
PostImportHookPatching.add_on_import('pytorch_lightning', PatchPyTorchModelIO._patch_lightning_io)
|
||||
PostImportHookPatching.add_on_import('lightning', PatchPyTorchModelIO._patch_lightning_io)
|
||||
PostImportHookPatching.add_on_import('pytorch_lightning', PatchPyTorchModelIO._patch_pytorch_lightning_io)
|
||||
|
||||
@staticmethod
|
||||
def _patch_model_io():
|
||||
@ -110,11 +113,57 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
|
||||
if PatchPyTorchModelIO.__patched_lightning:
|
||||
return
|
||||
|
||||
if 'pytorch_lightning' not in sys.modules:
|
||||
if 'lightning' not in sys.modules:
|
||||
return
|
||||
|
||||
PatchPyTorchModelIO.__patched_lightning = True
|
||||
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
import lightning # noqa
|
||||
|
||||
lightning.pytorch.trainer.Trainer.save_checkpoint = _patched_call(
|
||||
lightning.pytorch.trainer.Trainer.save_checkpoint, PatchPyTorchModelIO._save
|
||||
) # noqa
|
||||
|
||||
lightning.pytorch.trainer.Trainer.restore = _patched_call(
|
||||
lightning.pytorch.trainer.Trainer.restore, PatchPyTorchModelIO._load_from_obj
|
||||
) # noqa
|
||||
except ImportError:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
import lightning # noqa
|
||||
|
||||
# noinspection PyUnresolvedReferences
|
||||
lightning.pytorch.trainer.connectors.checkpoint_connector.CheckpointConnector.save_checkpoint = _patched_call(
|
||||
lightning.pytorch.trainer.connectors.checkpoint_connector.CheckpointConnector.save_checkpoint,
|
||||
PatchPyTorchModelIO._save,
|
||||
) # noqa
|
||||
|
||||
# noinspection PyUnresolvedReferences
|
||||
lightning.pytorch.trainer.connectors.checkpoint_connector.CheckpointConnector.restore = _patched_call(
|
||||
lightning.pytorch.trainer.connectors.checkpoint_connector.CheckpointConnector.restore,
|
||||
PatchPyTorchModelIO._load_from_obj,
|
||||
) # noqa
|
||||
except ImportError:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def _patch_pytorch_lightning_io():
|
||||
if PatchPyTorchModelIO.__patched_pytorch_lightning:
|
||||
return
|
||||
|
||||
if 'pytorch_lightning' not in sys.modules:
|
||||
return
|
||||
|
||||
PatchPyTorchModelIO.__patched_pytorch_lightning = True
|
||||
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
import pytorch_lightning # noqa
|
||||
|
Loading…
Reference in New Issue
Block a user