Fix Pytorch ScriptModule autobind

This commit is contained in:
allegroai 2022-04-13 14:15:00 +03:00
parent 681d75a309
commit 6695c94fdd

View File

@ -38,6 +38,12 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
import torch # noqa import torch # noqa
torch.save = _patched_call(torch.save, PatchPyTorchModelIO._save) torch.save = _patched_call(torch.save, PatchPyTorchModelIO._save)
torch.load = _patched_call(torch.load, PatchPyTorchModelIO._load) torch.load = _patched_call(torch.load, PatchPyTorchModelIO._load)
# noinspection PyBroadException
try:
# noinspection PyProtectedMember
torch.jit._script.RecursiveScriptModule.save = _patched_call(torch.jit._script.RecursiveScriptModule.save, PatchPyTorchModelIO._save)
except BaseException:
pass
# no need to worry about recursive calls, _patched_call takes care of that # no need to worry about recursive calls, _patched_call takes care of that
if hasattr(torch, 'serialization') and hasattr(torch.serialization, '_save'): if hasattr(torch, 'serialization') and hasattr(torch.serialization, '_save'):