Fix issue #1249 pytorch-lightning patches (#1254)

This commit is contained in:
Andrew Gardner 2024-05-10 07:27:11 -05:00 committed by GitHub
parent 2fbd86415c
commit 66a7f5616c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,7 +1,9 @@
import sys
from typing import Any, Callable, Literal
import six
import threading
import importlib
from pathlib2 import Path
@ -108,6 +110,65 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
del PatchPyTorchModelIO._checkpoint_filename[tid]
return ret
@staticmethod
def _patch_lightning_io_internal(lightning_name: Literal["lightning", "pytorch_lightning"]):
try:
pytorch_lightning = importlib.import_module(lightning_name)
except ImportError:
# lightning is not installed
# Nothing to do
return
if lightning_name == "lightning":
pytorch_lightning = pytorch_lightning.pytorch
def patch_method(cls: type, method_name: str,
patched_method: Callable[..., Any]) -> None:
"""
Patch a method of a class if it exists.
Otherwise, no effect.
"""
try:
method = getattr(cls, method_name)
except AttributeError:
# the method is not defined on the given class
pass
else:
setattr(cls, method_name,
_patched_call(method, patched_method))
patch_method(pytorch_lightning.trainer.Trainer, "save_checkpoint",
PatchPyTorchModelIO._save)
patch_method(pytorch_lightning.trainer.Trainer, "restore",
PatchPyTorchModelIO._load_from_obj)
try:
checkpoint_connector = pytorch_lightning.trainer.connectors.checkpoint_connector
except AttributeError:
# checkpoint_connector does not yet exist; lightning version is < 0.10.0
# Nothing left to do
return
try:
CheckpointConnector = checkpoint_connector._CheckpointConnector
except AttributeError:
# CheckpointConnector has not yet been made protected
# lighting version is < 2.0.0
try:
CheckpointConnector = checkpoint_connector.CheckpointConnector
except AttributeError:
# Unexpected future breaking change in lightning
# No way to automatically handle
return
patch_method(CheckpointConnector, "save_checkpoint",
PatchPyTorchModelIO._save)
patch_method(CheckpointConnector, "restore",
PatchPyTorchModelIO._load_from_obj)
@staticmethod
def _patch_lightning_io():
if PatchPyTorchModelIO.__patched_lightning:
@ -118,41 +179,7 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
PatchPyTorchModelIO.__patched_lightning = True
# noinspection PyBroadException
try:
import lightning # noqa
lightning.pytorch.trainer.Trainer.save_checkpoint = _patched_call(
lightning.pytorch.trainer.Trainer.save_checkpoint, PatchPyTorchModelIO._save
) # noqa
lightning.pytorch.trainer.Trainer.restore = _patched_call(
lightning.pytorch.trainer.Trainer.restore, PatchPyTorchModelIO._load_from_obj
) # noqa
except ImportError:
pass
except Exception:
pass
# noinspection PyBroadException
try:
import lightning # noqa
# noinspection PyUnresolvedReferences
lightning.pytorch.trainer.connectors.checkpoint_connector.CheckpointConnector.save_checkpoint = _patched_call(
lightning.pytorch.trainer.connectors.checkpoint_connector.CheckpointConnector.save_checkpoint,
PatchPyTorchModelIO._save,
) # noqa
# noinspection PyUnresolvedReferences
lightning.pytorch.trainer.connectors.checkpoint_connector.CheckpointConnector.restore = _patched_call(
lightning.pytorch.trainer.connectors.checkpoint_connector.CheckpointConnector.restore,
PatchPyTorchModelIO._load_from_obj,
) # noqa
except ImportError:
pass
except Exception:
pass
PatchPyTorchModelIO._patch_lightning_io_internal("lightning")
@staticmethod
def _patch_pytorch_lightning_io():
@ -164,39 +191,7 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
PatchPyTorchModelIO.__patched_pytorch_lightning = True
# noinspection PyBroadException
try:
import pytorch_lightning # noqa
pytorch_lightning.trainer.Trainer.save_checkpoint = _patched_call(
pytorch_lightning.trainer.Trainer.save_checkpoint, PatchPyTorchModelIO._save) # noqa
pytorch_lightning.trainer.Trainer.restore = _patched_call(
pytorch_lightning.trainer.Trainer.restore, PatchPyTorchModelIO._load_from_obj) # noqa
except ImportError:
pass
except Exception:
pass
# noinspection PyBroadException
try:
import pytorch_lightning # noqa
# noinspection PyUnresolvedReferences
pytorch_lightning.trainer.connectors.checkpoint_connector.CheckpointConnector.save_checkpoint = \
_patched_call(
pytorch_lightning.trainer.connectors.checkpoint_connector.CheckpointConnector.save_checkpoint,
PatchPyTorchModelIO._save) # noqa
# noinspection PyUnresolvedReferences
pytorch_lightning.trainer.connectors.checkpoint_connector.CheckpointConnector.restore = \
_patched_call(
pytorch_lightning.trainer.connectors.checkpoint_connector.CheckpointConnector.restore,
PatchPyTorchModelIO._load_from_obj) # noqa
except ImportError:
pass
except Exception:
pass
PatchPyTorchModelIO._patch_lightning_io_internal("pytorch_lightning")
@staticmethod
def _save(original_fn, obj, f, *args, **kwargs):
@ -334,4 +329,4 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
def __get_cached_checkpoint_filename():
tid = threading.current_thread().ident
checkpoint_filename = PatchPyTorchModelIO._checkpoint_filename.get(tid)
return checkpoint_filename or None
return checkpoint_filename or None