diff --git a/clearml/binding/frameworks/pytorch_bind.py b/clearml/binding/frameworks/pytorch_bind.py index 5e4007cb..1a7c92e4 100644 --- a/clearml/binding/frameworks/pytorch_bind.py +++ b/clearml/binding/frameworks/pytorch_bind.py @@ -38,6 +38,12 @@ class PatchPyTorchModelIO(PatchBaseModelIO): import torch # noqa torch.save = _patched_call(torch.save, PatchPyTorchModelIO._save) torch.load = _patched_call(torch.load, PatchPyTorchModelIO._load) + # noinspection PyBroadException + try: + # noinspection PyProtectedMember + torch.jit._script.RecursiveScriptModule.save = _patched_call(torch.jit._script.RecursiveScriptModule.save, PatchPyTorchModelIO._save) + except BaseException: + pass # no need to worry about recursive calls, _patched_call takes care of that if hasattr(torch, 'serialization') and hasattr(torch.serialization, '_save'):