Add support for new lightning namespace (#1033)

This commit is contained in:
Andreas Weinmann 2023-06-11 04:05:39 -07:00 committed by GitHub
parent 42320421a2
commit ec51f35ec7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -17,6 +17,7 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
_checkpoint_filename = {} _checkpoint_filename = {}
__patched = None __patched = None
__patched_lightning = None __patched_lightning = None
__patched_pytorch_lightning = None
__patched_mmcv = None __patched_mmcv = None
@staticmethod @staticmethod
@ -26,9 +27,11 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
return return
PatchPyTorchModelIO._patch_model_io() PatchPyTorchModelIO._patch_model_io()
PatchPyTorchModelIO._patch_lightning_io() PatchPyTorchModelIO._patch_lightning_io()
PatchPyTorchModelIO._patch_pytorch_lightning_io()
PatchPyTorchModelIO._patch_mmcv() PatchPyTorchModelIO._patch_mmcv()
PostImportHookPatching.add_on_import('torch', PatchPyTorchModelIO._patch_model_io) PostImportHookPatching.add_on_import('torch', PatchPyTorchModelIO._patch_model_io)
PostImportHookPatching.add_on_import('pytorch_lightning', PatchPyTorchModelIO._patch_lightning_io) PostImportHookPatching.add_on_import('lightning', PatchPyTorchModelIO._patch_lightning_io)
PostImportHookPatching.add_on_import('pytorch_lightning', PatchPyTorchModelIO._patch_pytorch_lightning_io)
@staticmethod @staticmethod
def _patch_model_io(): def _patch_model_io():
@ -110,11 +113,57 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
if PatchPyTorchModelIO.__patched_lightning: if PatchPyTorchModelIO.__patched_lightning:
return return
if 'pytorch_lightning' not in sys.modules: if 'lightning' not in sys.modules:
return return
PatchPyTorchModelIO.__patched_lightning = True PatchPyTorchModelIO.__patched_lightning = True
# noinspection PyBroadException
try:
import lightning # noqa
lightning.pytorch.trainer.Trainer.save_checkpoint = _patched_call(
lightning.pytorch.trainer.Trainer.save_checkpoint, PatchPyTorchModelIO._save
) # noqa
lightning.pytorch.trainer.Trainer.restore = _patched_call(
lightning.pytorch.trainer.Trainer.restore, PatchPyTorchModelIO._load_from_obj
) # noqa
except ImportError:
pass
except Exception:
pass
# noinspection PyBroadException
try:
import lightning # noqa
# noinspection PyUnresolvedReferences
lightning.pytorch.trainer.connectors.checkpoint_connector.CheckpointConnector.save_checkpoint = _patched_call(
lightning.pytorch.trainer.connectors.checkpoint_connector.CheckpointConnector.save_checkpoint,
PatchPyTorchModelIO._save,
) # noqa
# noinspection PyUnresolvedReferences
lightning.pytorch.trainer.connectors.checkpoint_connector.CheckpointConnector.restore = _patched_call(
lightning.pytorch.trainer.connectors.checkpoint_connector.CheckpointConnector.restore,
PatchPyTorchModelIO._load_from_obj,
) # noqa
except ImportError:
pass
except Exception:
pass
@staticmethod
def _patch_pytorch_lightning_io():
if PatchPyTorchModelIO.__patched_pytorch_lightning:
return
if 'pytorch_lightning' not in sys.modules:
return
PatchPyTorchModelIO.__patched_pytorch_lightning = True
# noinspection PyBroadException # noinspection PyBroadException
try: try:
import pytorch_lightning # noqa import pytorch_lightning # noqa