mirror of
https://github.com/clearml/clearml
synced 2025-04-26 17:30:20 +00:00
parent
2fbd86415c
commit
66a7f5616c
@ -1,7 +1,9 @@
|
|||||||
import sys
|
import sys
|
||||||
|
from typing import Any, Callable, Literal
|
||||||
|
|
||||||
import six
|
import six
|
||||||
import threading
|
import threading
|
||||||
|
import importlib
|
||||||
|
|
||||||
from pathlib2 import Path
|
from pathlib2 import Path
|
||||||
|
|
||||||
@ -108,6 +110,65 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
|
|||||||
del PatchPyTorchModelIO._checkpoint_filename[tid]
|
del PatchPyTorchModelIO._checkpoint_filename[tid]
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _patch_lightning_io_internal(lightning_name: Literal["lightning", "pytorch_lightning"]):
|
||||||
|
|
||||||
|
try:
|
||||||
|
pytorch_lightning = importlib.import_module(lightning_name)
|
||||||
|
except ImportError:
|
||||||
|
# lightning is not installed
|
||||||
|
# Nothing to do
|
||||||
|
return
|
||||||
|
if lightning_name == "lightning":
|
||||||
|
pytorch_lightning = pytorch_lightning.pytorch
|
||||||
|
|
||||||
|
def patch_method(cls: type, method_name: str,
|
||||||
|
patched_method: Callable[..., Any]) -> None:
|
||||||
|
"""
|
||||||
|
Patch a method of a class if it exists.
|
||||||
|
|
||||||
|
Otherwise, no effect.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
method = getattr(cls, method_name)
|
||||||
|
except AttributeError:
|
||||||
|
# the method is not defined on the given class
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
setattr(cls, method_name,
|
||||||
|
_patched_call(method, patched_method))
|
||||||
|
|
||||||
|
patch_method(pytorch_lightning.trainer.Trainer, "save_checkpoint",
|
||||||
|
PatchPyTorchModelIO._save)
|
||||||
|
|
||||||
|
patch_method(pytorch_lightning.trainer.Trainer, "restore",
|
||||||
|
PatchPyTorchModelIO._load_from_obj)
|
||||||
|
|
||||||
|
try:
|
||||||
|
checkpoint_connector = pytorch_lightning.trainer.connectors.checkpoint_connector
|
||||||
|
except AttributeError:
|
||||||
|
# checkpoint_connector does not yet exist; lightning version is < 0.10.0
|
||||||
|
# Nothing left to do
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
CheckpointConnector = checkpoint_connector._CheckpointConnector
|
||||||
|
except AttributeError:
|
||||||
|
# CheckpointConnector has not yet been made protected
|
||||||
|
# lighting version is < 2.0.0
|
||||||
|
try:
|
||||||
|
CheckpointConnector = checkpoint_connector.CheckpointConnector
|
||||||
|
except AttributeError:
|
||||||
|
# Unexpected future breaking change in lightning
|
||||||
|
# No way to automatically handle
|
||||||
|
return
|
||||||
|
|
||||||
|
patch_method(CheckpointConnector, "save_checkpoint",
|
||||||
|
PatchPyTorchModelIO._save)
|
||||||
|
|
||||||
|
patch_method(CheckpointConnector, "restore",
|
||||||
|
PatchPyTorchModelIO._load_from_obj)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _patch_lightning_io():
|
def _patch_lightning_io():
|
||||||
if PatchPyTorchModelIO.__patched_lightning:
|
if PatchPyTorchModelIO.__patched_lightning:
|
||||||
@ -118,41 +179,7 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
|
|||||||
|
|
||||||
PatchPyTorchModelIO.__patched_lightning = True
|
PatchPyTorchModelIO.__patched_lightning = True
|
||||||
|
|
||||||
# noinspection PyBroadException
|
PatchPyTorchModelIO._patch_lightning_io_internal("lightning")
|
||||||
try:
|
|
||||||
import lightning # noqa
|
|
||||||
|
|
||||||
lightning.pytorch.trainer.Trainer.save_checkpoint = _patched_call(
|
|
||||||
lightning.pytorch.trainer.Trainer.save_checkpoint, PatchPyTorchModelIO._save
|
|
||||||
) # noqa
|
|
||||||
|
|
||||||
lightning.pytorch.trainer.Trainer.restore = _patched_call(
|
|
||||||
lightning.pytorch.trainer.Trainer.restore, PatchPyTorchModelIO._load_from_obj
|
|
||||||
) # noqa
|
|
||||||
except ImportError:
|
|
||||||
pass
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# noinspection PyBroadException
|
|
||||||
try:
|
|
||||||
import lightning # noqa
|
|
||||||
|
|
||||||
# noinspection PyUnresolvedReferences
|
|
||||||
lightning.pytorch.trainer.connectors.checkpoint_connector.CheckpointConnector.save_checkpoint = _patched_call(
|
|
||||||
lightning.pytorch.trainer.connectors.checkpoint_connector.CheckpointConnector.save_checkpoint,
|
|
||||||
PatchPyTorchModelIO._save,
|
|
||||||
) # noqa
|
|
||||||
|
|
||||||
# noinspection PyUnresolvedReferences
|
|
||||||
lightning.pytorch.trainer.connectors.checkpoint_connector.CheckpointConnector.restore = _patched_call(
|
|
||||||
lightning.pytorch.trainer.connectors.checkpoint_connector.CheckpointConnector.restore,
|
|
||||||
PatchPyTorchModelIO._load_from_obj,
|
|
||||||
) # noqa
|
|
||||||
except ImportError:
|
|
||||||
pass
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _patch_pytorch_lightning_io():
|
def _patch_pytorch_lightning_io():
|
||||||
@ -164,39 +191,7 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
|
|||||||
|
|
||||||
PatchPyTorchModelIO.__patched_pytorch_lightning = True
|
PatchPyTorchModelIO.__patched_pytorch_lightning = True
|
||||||
|
|
||||||
# noinspection PyBroadException
|
PatchPyTorchModelIO._patch_lightning_io_internal("pytorch_lightning")
|
||||||
try:
|
|
||||||
import pytorch_lightning # noqa
|
|
||||||
|
|
||||||
pytorch_lightning.trainer.Trainer.save_checkpoint = _patched_call(
|
|
||||||
pytorch_lightning.trainer.Trainer.save_checkpoint, PatchPyTorchModelIO._save) # noqa
|
|
||||||
|
|
||||||
pytorch_lightning.trainer.Trainer.restore = _patched_call(
|
|
||||||
pytorch_lightning.trainer.Trainer.restore, PatchPyTorchModelIO._load_from_obj) # noqa
|
|
||||||
except ImportError:
|
|
||||||
pass
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# noinspection PyBroadException
|
|
||||||
try:
|
|
||||||
import pytorch_lightning # noqa
|
|
||||||
|
|
||||||
# noinspection PyUnresolvedReferences
|
|
||||||
pytorch_lightning.trainer.connectors.checkpoint_connector.CheckpointConnector.save_checkpoint = \
|
|
||||||
_patched_call(
|
|
||||||
pytorch_lightning.trainer.connectors.checkpoint_connector.CheckpointConnector.save_checkpoint,
|
|
||||||
PatchPyTorchModelIO._save) # noqa
|
|
||||||
|
|
||||||
# noinspection PyUnresolvedReferences
|
|
||||||
pytorch_lightning.trainer.connectors.checkpoint_connector.CheckpointConnector.restore = \
|
|
||||||
_patched_call(
|
|
||||||
pytorch_lightning.trainer.connectors.checkpoint_connector.CheckpointConnector.restore,
|
|
||||||
PatchPyTorchModelIO._load_from_obj) # noqa
|
|
||||||
except ImportError:
|
|
||||||
pass
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _save(original_fn, obj, f, *args, **kwargs):
|
def _save(original_fn, obj, f, *args, **kwargs):
|
||||||
|
Loading…
Reference in New Issue
Block a user