Improve PyTorch Ignite integration

allegroai 2020-06-13 22:10:59 +03:00
parent a5b1ed0330
commit aa61fa3f06


@@ -32,10 +32,23 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
         # noinspection PyBroadException
         try:
             # hack: make sure torch.__init__ is called
             import torch
             torch.save = _patched_call(torch.save, PatchPyTorchModelIO._save)
             torch.load = _patched_call(torch.load, PatchPyTorchModelIO._load)
+            # no need to worry about recursive calls, _patched_call takes care of that
+            if hasattr(torch, 'serialization') and hasattr(torch.serialization, '_save'):
+                torch.serialization._save = _patched_call(
+                    torch.serialization._save, PatchPyTorchModelIO._save)
+            if hasattr(torch, 'serialization') and hasattr(torch.serialization, '_load'):
+                torch.serialization._load = _patched_call(
+                    torch.serialization._load, PatchPyTorchModelIO._load)
+            if hasattr(torch, 'serialization') and hasattr(torch.serialization, '_legacy_save'):
+                torch.serialization._legacy_save = _patched_call(
+                    torch.serialization._legacy_save, PatchPyTorchModelIO._save)
+            if hasattr(torch, 'serialization') and hasattr(torch.serialization, '_legacy_load'):
+                torch.serialization._legacy_load = _patched_call(
+                    torch.serialization._legacy_load, PatchPyTorchModelIO._load)
         except ImportError:
             pass
         except Exception:
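The hunk above hooks the save/load callbacks not only into the public torch.save and torch.load, but also into torch.serialization._save/_load and their _legacy_* counterparts, each behind a hasattr guard so the patch stays harmless on torch releases that lack those internals. The _patched_call helper itself is not part of this diff; the sketch below only illustrates the kind of wrapper it is assumed to be (the name, the re-entrancy guard, and the callback signature are assumptions for illustration, not the repository's actual implementation):

# Hypothetical sketch of a _patched_call-style wrapper; not the repository's code.
import functools
import threading

_reentry = threading.local()

def patched_call_sketch(original_fn, callback):
    @functools.wraps(original_fn)
    def wrapper(*args, **kwargs):
        # guard against recursive calls, e.g. torch.save delegating to the
        # (also patched) torch.serialization._save
        if getattr(_reentry, 'active', False):
            return original_fn(*args, **kwargs)
        _reentry.active = True
        try:
            # the callback receives the original function plus the call arguments,
            # matching the _save(original_fn, obj, f, ...) signature in this diff
            return callback(original_fn, *args, **kwargs)
        finally:
            _reentry.active = False
    return wrapper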
@@ -44,6 +57,7 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
     @staticmethod
     def _save(original_fn, obj, f, *args, **kwargs):
         ret = original_fn(obj, f, *args, **kwargs)
+        # if there is no main task or this is a nested call
         if not PatchPyTorchModelIO.__main_task:
             return ret
@@ -73,17 +87,18 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
         # give the model a descriptive name based on the file name
         # noinspection PyBroadException
         try:
-            model_name = Path(filename).stem
+            model_name = Path(filename).stem if filename is not None else None
         except Exception:
             model_name = None
-        WeightsFileHandler.create_output_model(obj, filename, Framework.pytorch, PatchPyTorchModelIO.__main_task,
-                                               singlefile=True, model_name=model_name)
+        WeightsFileHandler.create_output_model(
+            obj, filename, Framework.pytorch, PatchPyTorchModelIO.__main_task, singlefile=True, model_name=model_name)
         return ret
     @staticmethod
     def _load(original_fn, f, *args, **kwargs):
+        # if there is no main task or this is a nested call
         if not PatchPyTorchModelIO.__main_task:
             return original_fn(f, *args, **kwargs)
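Two things change in the hunk above: the model-name derivation now tolerates a missing filename, and _load gains the same early-return guard as _save. The None guard matters because torch.save and torch.load also accept file-like objects such as an io.BytesIO buffer, where no path is available to derive a name from. A standalone illustration of that guard (the helper name below is made up for this example and is not the project's code):

# Standalone illustration of the filename guard; mirrors the
# Path(filename).stem logic above but is not the project's code.
from io import BytesIO
from pathlib import Path

def model_name_from_target(f):
    # a path yields a usable name; a file-like object usually does not
    filename = f if isinstance(f, str) else getattr(f, 'name', None)
    try:
        return Path(filename).stem if filename is not None else None
    except Exception:
        return None

print(model_name_from_target('checkpoints/resnet50_epoch3.pt'))  # resnet50_epoch3
print(model_name_from_target(BytesIO()))                         # None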
@@ -104,14 +119,14 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
         empty = _Empty()
         # Hack: disabled
         if False and running_remotely():
-            filename = WeightsFileHandler.restore_weights_file(empty, filename, Framework.pytorch,
-                                                               PatchPyTorchModelIO.__main_task)
+            filename = WeightsFileHandler.restore_weights_file(
+                empty, filename, Framework.pytorch, PatchPyTorchModelIO.__main_task)
             model = original_fn(filename or f, *args, **kwargs)
         else:
             # try to load model before registering, in case we fail
             model = original_fn(f, *args, **kwargs)
-            WeightsFileHandler.restore_weights_file(empty, filename, Framework.pytorch,
-                                                    PatchPyTorchModelIO.__main_task)
+            WeightsFileHandler.restore_weights_file(
+                empty, filename, Framework.pytorch, PatchPyTorchModelIO.__main_task)
         if empty.trains_in_model:
             # noinspection PyBroadException
@@ -119,4 +134,5 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
                 model.trains_in_model = empty.trains_in_model
             except Exception:
                 pass
         return model
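In _load the remote-execution branch stays disabled (if False and running_remotely()), so the checkpoint is loaded first and only then registered through WeightsFileHandler.restore_weights_file, which receives an empty sentinel object; if the handler attaches a trains_in_model reference to that sentinel, the reference is copied onto the loaded object inside a broad try/except, since some loaded objects (a plain state_dict dict, for instance) cannot take attributes. A minimal sketch of that sentinel pattern, assuming the handler simply sets the attribute on whatever object it is given (the stand-in class and function below are hypothetical, not the real WeightsFileHandler API):

# Sketch of the sentinel pattern; the handler stand-in is hypothetical.
class _EmptySketch(object):
    def __init__(self):
        self.trains_in_model = None

def restore_weights_file_stub(target, filename):
    # stand-in for WeightsFileHandler.restore_weights_file: pretend the
    # framework resolved an input model for this checkpoint and attached it
    target.trains_in_model = 'input-model-for:{}'.format(filename)

empty = _EmptySketch()
model = {'state_dict': {}}  # e.g. a plain dict returned by torch.load
restore_weights_file_stub(empty, 'model.pt')
if empty.trains_in_model:
    try:
        # a plain dict cannot take attributes, hence the broad except
        model.trains_in_model = empty.trains_in_model
    except Exception:
        pass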