From 66a7f5616ce83d9bd2d05ed6cdfbd2a6fcddffbc Mon Sep 17 00:00:00 2001
From: Andrew Gardner
Date: Fri, 10 May 2024 07:27:11 -0500
Subject: [PATCH] Fix issue #1249 pytorch-lightning patches (#1254)

---
 clearml/binding/frameworks/pytorch_bind.py | 133 ++++++++++-----------
 1 file changed, 64 insertions(+), 69 deletions(-)

diff --git a/clearml/binding/frameworks/pytorch_bind.py b/clearml/binding/frameworks/pytorch_bind.py
index ff4d561d..fbe1de42 100644
--- a/clearml/binding/frameworks/pytorch_bind.py
+++ b/clearml/binding/frameworks/pytorch_bind.py
@@ -1,7 +1,9 @@
 import sys
+from typing import Any, Callable, Literal
 
 import six
 import threading
+import importlib
 
 from pathlib2 import Path
 
@@ -108,6 +110,65 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
             del PatchPyTorchModelIO._checkpoint_filename[tid]
         return ret
 
+    @staticmethod
+    def _patch_lightning_io_internal(lightning_name: Literal["lightning", "pytorch_lightning"]):
+
+        try:
+            pytorch_lightning = importlib.import_module(lightning_name)
+        except ImportError:
+            # lightning is not installed
+            # Nothing to do
+            return
+        if lightning_name == "lightning":
+            pytorch_lightning = pytorch_lightning.pytorch
+
+        def patch_method(cls: type, method_name: str,
+                         patched_method: Callable[..., Any]) -> None:
+            """
+            Patch a method of a class if it exists.
+
+            Otherwise, no effect.
+            """
+            try:
+                method = getattr(cls, method_name)
+            except AttributeError:
+                # the method is not defined on the given class
+                pass
+            else:
+                setattr(cls, method_name,
+                        _patched_call(method, patched_method))
+
+        patch_method(pytorch_lightning.trainer.Trainer, "save_checkpoint",
+                     PatchPyTorchModelIO._save)
+
+        patch_method(pytorch_lightning.trainer.Trainer, "restore",
+                     PatchPyTorchModelIO._load_from_obj)
+
+        try:
+            checkpoint_connector = pytorch_lightning.trainer.connectors.checkpoint_connector
+        except AttributeError:
+            # checkpoint_connector does not yet exist; lightning version is < 0.10.0
+            # Nothing left to do
+            return
+
+        try:
+            CheckpointConnector = checkpoint_connector._CheckpointConnector
+        except AttributeError:
+            # CheckpointConnector has not yet been made protected
+            # lightning version is < 2.0.0
+            try:
+                CheckpointConnector = checkpoint_connector.CheckpointConnector
+            except AttributeError:
+                # Unexpected future breaking change in lightning
+                # No way to automatically handle
+                return
+
+        patch_method(CheckpointConnector, "save_checkpoint",
+                     PatchPyTorchModelIO._save)
+
+        patch_method(CheckpointConnector, "restore",
+                     PatchPyTorchModelIO._load_from_obj)
+
     @staticmethod
     def _patch_lightning_io():
         if PatchPyTorchModelIO.__patched_lightning:
@@ -118,41 +179,7 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
 
         PatchPyTorchModelIO.__patched_lightning = True
 
-        # noinspection PyBroadException
-        try:
-            import lightning  # noqa
-
-            lightning.pytorch.trainer.Trainer.save_checkpoint = _patched_call(
-                lightning.pytorch.trainer.Trainer.save_checkpoint, PatchPyTorchModelIO._save
-            )  # noqa
-
-            lightning.pytorch.trainer.Trainer.restore = _patched_call(
-                lightning.pytorch.trainer.Trainer.restore, PatchPyTorchModelIO._load_from_obj
-            )  # noqa
-        except ImportError:
-            pass
-        except Exception:
-            pass
-
-        # noinspection PyBroadException
-        try:
-            import lightning  # noqa
-
-            # noinspection PyUnresolvedReferences
-            lightning.pytorch.trainer.connectors.checkpoint_connector.CheckpointConnector.save_checkpoint = _patched_call(
-                lightning.pytorch.trainer.connectors.checkpoint_connector.CheckpointConnector.save_checkpoint,
-                PatchPyTorchModelIO._save,
-            )  # noqa
-
-            # noinspection PyUnresolvedReferences
-            lightning.pytorch.trainer.connectors.checkpoint_connector.CheckpointConnector.restore = _patched_call(
-                lightning.pytorch.trainer.connectors.checkpoint_connector.CheckpointConnector.restore,
-                PatchPyTorchModelIO._load_from_obj,
-            )  # noqa
-        except ImportError:
-            pass
-        except Exception:
-            pass
+        PatchPyTorchModelIO._patch_lightning_io_internal("lightning")
 
     @staticmethod
     def _patch_pytorch_lightning_io():
@@ -164,39 +191,7 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
 
         PatchPyTorchModelIO.__patched_pytorch_lightning = True
 
-        # noinspection PyBroadException
-        try:
-            import pytorch_lightning  # noqa
-
-            pytorch_lightning.trainer.Trainer.save_checkpoint = _patched_call(
-                pytorch_lightning.trainer.Trainer.save_checkpoint, PatchPyTorchModelIO._save)  # noqa
-
-            pytorch_lightning.trainer.Trainer.restore = _patched_call(
-                pytorch_lightning.trainer.Trainer.restore, PatchPyTorchModelIO._load_from_obj)  # noqa
-        except ImportError:
-            pass
-        except Exception:
-            pass
-
-        # noinspection PyBroadException
-        try:
-            import pytorch_lightning  # noqa
-
-            # noinspection PyUnresolvedReferences
-            pytorch_lightning.trainer.connectors.checkpoint_connector.CheckpointConnector.save_checkpoint = \
-                _patched_call(
-                    pytorch_lightning.trainer.connectors.checkpoint_connector.CheckpointConnector.save_checkpoint,
-                    PatchPyTorchModelIO._save)  # noqa
-
-            # noinspection PyUnresolvedReferences
-            pytorch_lightning.trainer.connectors.checkpoint_connector.CheckpointConnector.restore = \
-                _patched_call(
-                    pytorch_lightning.trainer.connectors.checkpoint_connector.CheckpointConnector.restore,
-                    PatchPyTorchModelIO._load_from_obj)  # noqa
-        except ImportError:
-            pass
-        except Exception:
-            pass
+        PatchPyTorchModelIO._patch_lightning_io_internal("pytorch_lightning")
 
     @staticmethod
     def _save(original_fn, obj, f, *args, **kwargs):
@@ -334,4 +329,4 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
     def __get_cached_checkpoint_filename():
         tid = threading.current_thread().ident
         checkpoint_filename = PatchPyTorchModelIO._checkpoint_filename.get(tid)
-        return checkpoint_filename or None
\ No newline at end of file
+        return checkpoint_filename or None
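Note on the pattern the new helper introduces: _patch_lightning_io_internal resolves the requested package by name with importlib.import_module and then wraps each checkpoint method only if the attribute actually exists, so a missing or renamed method degrades to a no-op instead of raising. The sketch below is a minimal, self-contained illustration of that wrap-if-present idea. It is not ClearML code: wrap_call, patch_method_if_present, Trainer, and log_save are invented stand-ins, and wrap_call only mimics what clearml's _patched_call appears to do given the _save(original_fn, obj, f, ...) signature visible in the diff.

# Minimal sketch of the "wrap a method only if it exists" pattern used by
# _patch_lightning_io_internal. All names below are illustrative stand-ins.
from typing import Any, Callable


def wrap_call(original: Callable[..., Any], hook: Callable[..., Any]) -> Callable[..., Any]:
    # Stand-in for clearml's _patched_call (assumed behavior): the hook receives
    # the original function first and is responsible for calling it.
    def wrapper(*args, **kwargs):
        return hook(original, *args, **kwargs)
    return wrapper


def patch_method_if_present(cls: type, method_name: str, hook: Callable[..., Any]) -> None:
    # Same shape as the nested patch_method helper in the diff: silently skip
    # classes that do not define the method in this framework version.
    try:
        method = getattr(cls, method_name)
    except AttributeError:
        return
    setattr(cls, method_name, wrap_call(method, hook))


class Trainer:  # toy stand-in, not lightning's Trainer
    def save_checkpoint(self, path):
        return "saved to {}".format(path)


def log_save(original_fn, obj, path, *args, **kwargs):
    # Example hook: observe the call, then delegate to the original method.
    print("about to save checkpoint: {}".format(path))
    return original_fn(obj, path, *args, **kwargs)


patch_method_if_present(Trainer, "save_checkpoint", log_save)
patch_method_if_present(Trainer, "restore", log_save)  # no-op: Trainer has no restore

print(Trainer().save_checkpoint("model.ckpt"))

Calling patch_method_if_present for a method the class does not define ("restore" here) simply returns, which is the same graceful degradation the diff relies on across lightning versions.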
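The other half of the change is tolerating layout differences between lightning releases: the checkpoint connector module may not exist at all (pre-0.10.0), and its class became the protected _CheckpointConnector in 2.0. The sketch below shows that resolution order in isolation, with SimpleNamespace objects standing in for real lightning modules; it collapses the diff's nested try/except into a getattr loop, which yields the same preference order. Purely illustrative, not part of ClearML or Lightning.

# Sketch of the connector-name fallback from the diff, using SimpleNamespace
# objects in place of real lightning modules (illustrative only).
from types import SimpleNamespace
from typing import Optional


def resolve_checkpoint_connector(checkpoint_connector_module) -> Optional[type]:
    # Prefer the >=2.0 protected name, fall back to the older public name,
    # and return None when neither is present (unknown future layout).
    for name in ("_CheckpointConnector", "CheckpointConnector"):
        connector = getattr(checkpoint_connector_module, name, None)
        if connector is not None:
            return connector
    return None


new_layout = SimpleNamespace(_CheckpointConnector=type("_CheckpointConnector", (), {}))
old_layout = SimpleNamespace(CheckpointConnector=type("CheckpointConnector", (), {}))
unknown_layout = SimpleNamespace()

print(resolve_checkpoint_connector(new_layout).__name__)   # _CheckpointConnector
print(resolve_checkpoint_connector(old_layout).__name__)   # CheckpointConnector
print(resolve_checkpoint_connector(unknown_layout))        # None

Returning None for an unknown layout mirrors the final return in the diff: rather than guessing at a future rename, the patcher leaves the connector untouched, and only the Trainer-level methods remain wrapped.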