diff --git a/trains/binding/frameworks/pytorch_bind.py b/trains/binding/frameworks/pytorch_bind.py
index f783bc11..61706d78 100644
--- a/trains/binding/frameworks/pytorch_bind.py
+++ b/trains/binding/frameworks/pytorch_bind.py
@@ -13,12 +13,15 @@ from ...model import Framework
 class PatchPyTorchModelIO(PatchBaseModelIO):
     __main_task = None
     __patched = None
+    __patched_lightning = None
 
     @staticmethod
     def update_current_task(task, **_):
         PatchPyTorchModelIO.__main_task = task
         PatchPyTorchModelIO._patch_model_io()
+        PatchPyTorchModelIO._patch_lightning_io()
         PostImportHookPatching.add_on_import('torch', PatchPyTorchModelIO._patch_model_io)
+        PostImportHookPatching.add_on_import('pytorch_lightning', PatchPyTorchModelIO._patch_lightning_io)
 
     @staticmethod
     def _patch_model_io():
@@ -32,28 +35,72 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
 
         # noinspection PyBroadException
         try:
-            import torch
+            import torch  # noqa
 
             torch.save = _patched_call(torch.save, PatchPyTorchModelIO._save)
             torch.load = _patched_call(torch.load, PatchPyTorchModelIO._load)
             # no need to worry about recursive calls, _patched_call takes care of that
             if hasattr(torch, 'serialization') and hasattr(torch.serialization, '_save'):
                 torch.serialization._save = _patched_call(
-                    torch.serialization._save, PatchPyTorchModelIO._save)
+                    torch.serialization._save, PatchPyTorchModelIO._save)  # noqa
             if hasattr(torch, 'serialization') and hasattr(torch.serialization, '_load'):
                 torch.serialization._load = _patched_call(
-                    torch.serialization._load, PatchPyTorchModelIO._load)
+                    torch.serialization._load, PatchPyTorchModelIO._load)  # noqa
             if hasattr(torch, 'serialization') and hasattr(torch.serialization, '_legacy_save'):
                 torch.serialization._legacy_save = _patched_call(
-                    torch.serialization._legacy_save, PatchPyTorchModelIO._save)
+                    torch.serialization._legacy_save, PatchPyTorchModelIO._save)  # noqa
             if hasattr(torch, 'serialization') and hasattr(torch.serialization, '_legacy_load'):
                 torch.serialization._legacy_load = _patched_call(
-                    torch.serialization._legacy_load, PatchPyTorchModelIO._load)
+                    torch.serialization._legacy_load, PatchPyTorchModelIO._load)  # noqa
         except ImportError:
             pass
         except Exception:
             pass  # print('Failed patching pytorch')
 
+    @staticmethod
+    def _patch_lightning_io():
+        if PatchPyTorchModelIO.__patched_lightning:
+            return
+
+        if 'pytorch_lightning' not in sys.modules:
+            return
+
+        PatchPyTorchModelIO.__patched_lightning = True
+
+        # noinspection PyBroadException
+        try:
+            import pytorch_lightning  # noqa
+
+            pytorch_lightning.trainer.Trainer.save_checkpoint = _patched_call(
+                pytorch_lightning.trainer.Trainer.save_checkpoint, PatchPyTorchModelIO._save)  # noqa
+
+            pytorch_lightning.trainer.Trainer.restore = _patched_call(
+                pytorch_lightning.trainer.Trainer.restore, PatchPyTorchModelIO._load_from_obj)  # noqa
+        except ImportError:
+            pass
+        except Exception:
+            pass
+
+        # noinspection PyBroadException
+        try:
+            import pytorch_lightning  # noqa
+
+            # noinspection PyUnresolvedReferences
+            pytorch_lightning.trainer.connectors.checkpoint_connector.CheckpointConnector.save_checkpoint = \
+                _patched_call(
+                    pytorch_lightning.trainer.connectors.checkpoint_connector.CheckpointConnector.save_checkpoint,
+                    PatchPyTorchModelIO._save)  # noqa
+
+            # noinspection PyUnresolvedReferences
+            pytorch_lightning.trainer.connectors.checkpoint_connector.CheckpointConnector.restore = \
+                _patched_call(
+                    pytorch_lightning.trainer.connectors.checkpoint_connector.CheckpointConnector.restore,
+                    PatchPyTorchModelIO._load_from_obj)  # noqa
+        except ImportError:
+            pass
+        except Exception:
+            pass
+
     @staticmethod
     def _save(original_fn, obj, f, *args, **kwargs):
         ret = original_fn(obj, f, *args, **kwargs)
@@ -136,3 +183,44 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
                 pass
 
         return model
+
+    @staticmethod
+    def _load_from_obj(original_fn, obj, f, *args, **kwargs):
+        # if there is no main task or this is a nested call
+        if not PatchPyTorchModelIO.__main_task:
+            return original_fn(obj, f, *args, **kwargs)
+
+        # noinspection PyBroadException
+        try:
+            if isinstance(f, six.string_types):
+                filename = f
+            elif hasattr(f, 'as_posix'):
+                filename = f.as_posix()
+            elif hasattr(f, 'name'):
+                filename = f.name
+            else:
+                filename = None
+        except Exception:
+            filename = None
+
+        # register input model
+        empty = _Empty()
+        # Hack: disabled
+        if False and running_remotely():
+            filename = WeightsFileHandler.restore_weights_file(
+                empty, filename, Framework.pytorch, PatchPyTorchModelIO.__main_task)
+            model = original_fn(obj, filename or f, *args, **kwargs)
+        else:
+            # try to load model before registering, in case we fail
+            model = original_fn(obj, f, *args, **kwargs)
+            WeightsFileHandler.restore_weights_file(
+                empty, filename, Framework.pytorch, PatchPyTorchModelIO.__main_task)
+
+        if empty.trains_in_model:
+            # noinspection PyBroadException
+            try:
+                model.trains_in_model = empty.trains_in_model
+            except Exception:
+                pass
+
+        return model
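Note: the patching above leans entirely on `_patched_call`, which is imported elsewhere in this module and is not part of this diff. A minimal sketch of the wrapper pattern it is assumed to implement (illustrative only; the real helper also guards against recursive calls, as the comment in `_patch_model_io` notes):

```python
# Sketch of the assumed _patched_call wrapper pattern -- not the actual helper.
def _patched_call(original_fn, patched_fn):
    def _inner_patch(*args, **kwargs):
        # hand the original callable to the hook, so the hook decides
        # when (and whether) to invoke it
        return patched_fn(original_fn, *args, **kwargs)
    return _inner_patch
```

Because `Trainer.save_checkpoint` is patched on the class rather than on an instance, the bound `self` arrives as the `obj` argument of `PatchPyTorchModelIO._save(original_fn, obj, f, ...)`, mirroring the `torch.save(obj, f)` call shape, so the same `_save` hook serves both the torch and the Lightning entry points.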
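For reference, a hypothetical end-to-end check of the new binding. The module, project/task names, and checkpoint path below are illustrative, not part of this change:

```python
import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader, TensorDataset
from trains import Task

# Toy module just to drive Trainer.fit; any LightningModule would do.
class TinyModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)

# Task.init triggers update_current_task, which patches Trainer/CheckpointConnector.
task = Task.init(project_name='examples', task_name='lightning checkpoint capture')

data = DataLoader(TensorDataset(torch.randn(32, 4), torch.randn(32, 1)), batch_size=8)
trainer = pl.Trainer(max_epochs=1)
trainer.fit(TinyModule(), data)

# save_checkpoint is now wrapped by PatchPyTorchModelIO._save, so the
# checkpoint file should also be registered as an output model on the task.
trainer.save_checkpoint('example.ckpt')
```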