diff --git a/trains/binding/frameworks/pytorch_bind.py b/trains/binding/frameworks/pytorch_bind.py index da178967..f783bc11 100644 --- a/trains/binding/frameworks/pytorch_bind.py +++ b/trains/binding/frameworks/pytorch_bind.py @@ -32,10 +32,23 @@ class PatchPyTorchModelIO(PatchBaseModelIO): # noinspection PyBroadException try: - # hack: make sure tensorflow.__init__ is called import torch torch.save = _patched_call(torch.save, PatchPyTorchModelIO._save) torch.load = _patched_call(torch.load, PatchPyTorchModelIO._load) + + # no need to worry about recursive calls, _patched_call takes care of that + if hasattr(torch, 'serialization') and hasattr(torch.serialization, '_save'): + torch.serialization._save = _patched_call( + torch.serialization._save, PatchPyTorchModelIO._save) + if hasattr(torch, 'serialization') and hasattr(torch.serialization, '_load'): + torch.serialization._load = _patched_call( + torch.serialization._load, PatchPyTorchModelIO._load) + if hasattr(torch, 'serialization') and hasattr(torch.serialization, '_legacy_save'): + torch.serialization._legacy_save = _patched_call( + torch.serialization._legacy_save, PatchPyTorchModelIO._save) + if hasattr(torch, 'serialization') and hasattr(torch.serialization, '_legacy_load'): + torch.serialization._legacy_load = _patched_call( + torch.serialization._legacy_load, PatchPyTorchModelIO._load) except ImportError: pass except Exception: @@ -44,6 +57,7 @@ class PatchPyTorchModelIO(PatchBaseModelIO): @staticmethod def _save(original_fn, obj, f, *args, **kwargs): ret = original_fn(obj, f, *args, **kwargs) + # if there is no main task or this is a nested call if not PatchPyTorchModelIO.__main_task: return ret @@ -73,17 +87,18 @@ class PatchPyTorchModelIO(PatchBaseModelIO): # give the model a descriptive name based on the file name # noinspection PyBroadException try: - model_name = Path(filename).stem + model_name = Path(filename).stem if filename is not None else None except Exception: model_name = None - WeightsFileHandler.create_output_model(obj, filename, Framework.pytorch, PatchPyTorchModelIO.__main_task, - singlefile=True, model_name=model_name) + WeightsFileHandler.create_output_model( + obj, filename, Framework.pytorch, PatchPyTorchModelIO.__main_task, singlefile=True, model_name=model_name) return ret @staticmethod def _load(original_fn, f, *args, **kwargs): + # if there is no main task or this is a nested call if not PatchPyTorchModelIO.__main_task: return original_fn(f, *args, **kwargs) @@ -104,14 +119,14 @@ class PatchPyTorchModelIO(PatchBaseModelIO): empty = _Empty() # Hack: disabled if False and running_remotely(): - filename = WeightsFileHandler.restore_weights_file(empty, filename, Framework.pytorch, - PatchPyTorchModelIO.__main_task) + filename = WeightsFileHandler.restore_weights_file( + empty, filename, Framework.pytorch, PatchPyTorchModelIO.__main_task) model = original_fn(filename or f, *args, **kwargs) else: # try to load model before registering, in case we fail model = original_fn(f, *args, **kwargs) - WeightsFileHandler.restore_weights_file(empty, filename, Framework.pytorch, - PatchPyTorchModelIO.__main_task) + WeightsFileHandler.restore_weights_file( + empty, filename, Framework.pytorch, PatchPyTorchModelIO.__main_task) if empty.trains_in_model: # noinspection PyBroadException @@ -119,4 +134,5 @@ class PatchPyTorchModelIO(PatchBaseModelIO): model.trains_in_model = empty.trains_in_model except Exception: pass + return model