mirror of
https://github.com/clearml/clearml
synced 2025-04-24 08:14:25 +00:00
parent
2fbd86415c
commit
66a7f5616c
@ -1,7 +1,9 @@
|
||||
import sys
|
||||
from typing import Any, Callable, Literal
|
||||
|
||||
import six
|
||||
import threading
|
||||
import importlib
|
||||
|
||||
from pathlib2 import Path
|
||||
|
||||
@ -108,6 +110,65 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
|
||||
del PatchPyTorchModelIO._checkpoint_filename[tid]
|
||||
return ret
|
||||
|
||||
@staticmethod
|
||||
def _patch_lightning_io_internal(lightning_name: Literal["lightning", "pytorch_lightning"]):
|
||||
|
||||
try:
|
||||
pytorch_lightning = importlib.import_module(lightning_name)
|
||||
except ImportError:
|
||||
# lightning is not installed
|
||||
# Nothing to do
|
||||
return
|
||||
if lightning_name == "lightning":
|
||||
pytorch_lightning = pytorch_lightning.pytorch
|
||||
|
||||
def patch_method(cls: type, method_name: str,
|
||||
patched_method: Callable[..., Any]) -> None:
|
||||
"""
|
||||
Patch a method of a class if it exists.
|
||||
|
||||
Otherwise, no effect.
|
||||
"""
|
||||
try:
|
||||
method = getattr(cls, method_name)
|
||||
except AttributeError:
|
||||
# the method is not defined on the given class
|
||||
pass
|
||||
else:
|
||||
setattr(cls, method_name,
|
||||
_patched_call(method, patched_method))
|
||||
|
||||
patch_method(pytorch_lightning.trainer.Trainer, "save_checkpoint",
|
||||
PatchPyTorchModelIO._save)
|
||||
|
||||
patch_method(pytorch_lightning.trainer.Trainer, "restore",
|
||||
PatchPyTorchModelIO._load_from_obj)
|
||||
|
||||
try:
|
||||
checkpoint_connector = pytorch_lightning.trainer.connectors.checkpoint_connector
|
||||
except AttributeError:
|
||||
# checkpoint_connector does not yet exist; lightning version is < 0.10.0
|
||||
# Nothing left to do
|
||||
return
|
||||
|
||||
try:
|
||||
CheckpointConnector = checkpoint_connector._CheckpointConnector
|
||||
except AttributeError:
|
||||
# CheckpointConnector has not yet been made protected
|
||||
# lighting version is < 2.0.0
|
||||
try:
|
||||
CheckpointConnector = checkpoint_connector.CheckpointConnector
|
||||
except AttributeError:
|
||||
# Unexpected future breaking change in lightning
|
||||
# No way to automatically handle
|
||||
return
|
||||
|
||||
patch_method(CheckpointConnector, "save_checkpoint",
|
||||
PatchPyTorchModelIO._save)
|
||||
|
||||
patch_method(CheckpointConnector, "restore",
|
||||
PatchPyTorchModelIO._load_from_obj)
|
||||
|
||||
@staticmethod
|
||||
def _patch_lightning_io():
|
||||
if PatchPyTorchModelIO.__patched_lightning:
|
||||
@ -118,41 +179,7 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
|
||||
|
||||
PatchPyTorchModelIO.__patched_lightning = True
|
||||
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
import lightning # noqa
|
||||
|
||||
lightning.pytorch.trainer.Trainer.save_checkpoint = _patched_call(
|
||||
lightning.pytorch.trainer.Trainer.save_checkpoint, PatchPyTorchModelIO._save
|
||||
) # noqa
|
||||
|
||||
lightning.pytorch.trainer.Trainer.restore = _patched_call(
|
||||
lightning.pytorch.trainer.Trainer.restore, PatchPyTorchModelIO._load_from_obj
|
||||
) # noqa
|
||||
except ImportError:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
import lightning # noqa
|
||||
|
||||
# noinspection PyUnresolvedReferences
|
||||
lightning.pytorch.trainer.connectors.checkpoint_connector.CheckpointConnector.save_checkpoint = _patched_call(
|
||||
lightning.pytorch.trainer.connectors.checkpoint_connector.CheckpointConnector.save_checkpoint,
|
||||
PatchPyTorchModelIO._save,
|
||||
) # noqa
|
||||
|
||||
# noinspection PyUnresolvedReferences
|
||||
lightning.pytorch.trainer.connectors.checkpoint_connector.CheckpointConnector.restore = _patched_call(
|
||||
lightning.pytorch.trainer.connectors.checkpoint_connector.CheckpointConnector.restore,
|
||||
PatchPyTorchModelIO._load_from_obj,
|
||||
) # noqa
|
||||
except ImportError:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
PatchPyTorchModelIO._patch_lightning_io_internal("lightning")
|
||||
|
||||
@staticmethod
|
||||
def _patch_pytorch_lightning_io():
|
||||
@ -164,39 +191,7 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
|
||||
|
||||
PatchPyTorchModelIO.__patched_pytorch_lightning = True
|
||||
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
import pytorch_lightning # noqa
|
||||
|
||||
pytorch_lightning.trainer.Trainer.save_checkpoint = _patched_call(
|
||||
pytorch_lightning.trainer.Trainer.save_checkpoint, PatchPyTorchModelIO._save) # noqa
|
||||
|
||||
pytorch_lightning.trainer.Trainer.restore = _patched_call(
|
||||
pytorch_lightning.trainer.Trainer.restore, PatchPyTorchModelIO._load_from_obj) # noqa
|
||||
except ImportError:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
import pytorch_lightning # noqa
|
||||
|
||||
# noinspection PyUnresolvedReferences
|
||||
pytorch_lightning.trainer.connectors.checkpoint_connector.CheckpointConnector.save_checkpoint = \
|
||||
_patched_call(
|
||||
pytorch_lightning.trainer.connectors.checkpoint_connector.CheckpointConnector.save_checkpoint,
|
||||
PatchPyTorchModelIO._save) # noqa
|
||||
|
||||
# noinspection PyUnresolvedReferences
|
||||
pytorch_lightning.trainer.connectors.checkpoint_connector.CheckpointConnector.restore = \
|
||||
_patched_call(
|
||||
pytorch_lightning.trainer.connectors.checkpoint_connector.CheckpointConnector.restore,
|
||||
PatchPyTorchModelIO._load_from_obj) # noqa
|
||||
except ImportError:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
PatchPyTorchModelIO._patch_lightning_io_internal("pytorch_lightning")
|
||||
|
||||
@staticmethod
|
||||
def _save(original_fn, obj, f, *args, **kwargs):
|
||||
@ -334,4 +329,4 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
|
||||
def __get_cached_checkpoint_filename():
    """Return the checkpoint filename recorded for the calling thread, or None."""
    thread_id = threading.current_thread().ident
    filename = PatchPyTorchModelIO._checkpoint_filename.get(thread_id)
    return filename if filename else None
|
||||
|
Loading…
Reference in New Issue
Block a user