mirror of
https://github.com/clearml/clearml
synced 2025-04-08 06:34:37 +00:00
Improve pytorch ignite integration
This commit is contained in:
parent
a5b1ed0330
commit
aa61fa3f06
@ -32,10 +32,23 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
|
|||||||
|
|
||||||
# noinspection PyBroadException
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
# hack: make sure tensorflow.__init__ is called
|
|
||||||
import torch
|
import torch
|
||||||
torch.save = _patched_call(torch.save, PatchPyTorchModelIO._save)
|
torch.save = _patched_call(torch.save, PatchPyTorchModelIO._save)
|
||||||
torch.load = _patched_call(torch.load, PatchPyTorchModelIO._load)
|
torch.load = _patched_call(torch.load, PatchPyTorchModelIO._load)
|
||||||
|
|
||||||
|
# no need to worry about recursive calls, _patched_call takes care of that
|
||||||
|
if hasattr(torch, 'serialization') and hasattr(torch.serialization, '_save'):
|
||||||
|
torch.serialization._save = _patched_call(
|
||||||
|
torch.serialization._save, PatchPyTorchModelIO._save)
|
||||||
|
if hasattr(torch, 'serialization') and hasattr(torch.serialization, '_load'):
|
||||||
|
torch.serialization._load = _patched_call(
|
||||||
|
torch.serialization._load, PatchPyTorchModelIO._load)
|
||||||
|
if hasattr(torch, 'serialization') and hasattr(torch.serialization, '_legacy_save'):
|
||||||
|
torch.serialization._legacy_save = _patched_call(
|
||||||
|
torch.serialization._legacy_save, PatchPyTorchModelIO._save)
|
||||||
|
if hasattr(torch, 'serialization') and hasattr(torch.serialization, '_legacy_load'):
|
||||||
|
torch.serialization._legacy_load = _patched_call(
|
||||||
|
torch.serialization._legacy_load, PatchPyTorchModelIO._load)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
except Exception:
|
except Exception:
|
||||||
@ -44,6 +57,7 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def _save(original_fn, obj, f, *args, **kwargs):
|
def _save(original_fn, obj, f, *args, **kwargs):
|
||||||
ret = original_fn(obj, f, *args, **kwargs)
|
ret = original_fn(obj, f, *args, **kwargs)
|
||||||
|
# if there is no main task or this is a nested call
|
||||||
if not PatchPyTorchModelIO.__main_task:
|
if not PatchPyTorchModelIO.__main_task:
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
@ -73,17 +87,18 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
|
|||||||
# give the model a descriptive name based on the file name
|
# give the model a descriptive name based on the file name
|
||||||
# noinspection PyBroadException
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
model_name = Path(filename).stem
|
model_name = Path(filename).stem if filename is not None else None
|
||||||
except Exception:
|
except Exception:
|
||||||
model_name = None
|
model_name = None
|
||||||
|
|
||||||
WeightsFileHandler.create_output_model(obj, filename, Framework.pytorch, PatchPyTorchModelIO.__main_task,
|
WeightsFileHandler.create_output_model(
|
||||||
singlefile=True, model_name=model_name)
|
obj, filename, Framework.pytorch, PatchPyTorchModelIO.__main_task, singlefile=True, model_name=model_name)
|
||||||
|
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _load(original_fn, f, *args, **kwargs):
|
def _load(original_fn, f, *args, **kwargs):
|
||||||
|
# if there is no main task or this is a nested call
|
||||||
if not PatchPyTorchModelIO.__main_task:
|
if not PatchPyTorchModelIO.__main_task:
|
||||||
return original_fn(f, *args, **kwargs)
|
return original_fn(f, *args, **kwargs)
|
||||||
|
|
||||||
@ -104,14 +119,14 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
|
|||||||
empty = _Empty()
|
empty = _Empty()
|
||||||
# Hack: disabled
|
# Hack: disabled
|
||||||
if False and running_remotely():
|
if False and running_remotely():
|
||||||
filename = WeightsFileHandler.restore_weights_file(empty, filename, Framework.pytorch,
|
filename = WeightsFileHandler.restore_weights_file(
|
||||||
PatchPyTorchModelIO.__main_task)
|
empty, filename, Framework.pytorch, PatchPyTorchModelIO.__main_task)
|
||||||
model = original_fn(filename or f, *args, **kwargs)
|
model = original_fn(filename or f, *args, **kwargs)
|
||||||
else:
|
else:
|
||||||
# try to load model before registering, in case we fail
|
# try to load model before registering, in case we fail
|
||||||
model = original_fn(f, *args, **kwargs)
|
model = original_fn(f, *args, **kwargs)
|
||||||
WeightsFileHandler.restore_weights_file(empty, filename, Framework.pytorch,
|
WeightsFileHandler.restore_weights_file(
|
||||||
PatchPyTorchModelIO.__main_task)
|
empty, filename, Framework.pytorch, PatchPyTorchModelIO.__main_task)
|
||||||
|
|
||||||
if empty.trains_in_model:
|
if empty.trains_in_model:
|
||||||
# noinspection PyBroadException
|
# noinspection PyBroadException
|
||||||
@ -119,4 +134,5 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
|
|||||||
model.trains_in_model = empty.trains_in_model
|
model.trains_in_model = empty.trains_in_model
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
Loading…
Reference in New Issue
Block a user