mirror of
https://github.com/clearml/clearml
synced 2025-03-03 18:52:12 +00:00
Fix incorrect module used
This commit is contained in:
parent
1820423b02
commit
701582d5ff
@ -2252,24 +2252,24 @@ class PatchTensorflow2ModelIO(object):
|
||||
return
|
||||
|
||||
PatchTensorflow2ModelIO.__patched = True
|
||||
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
# hack: make sure tensorflow.__init__ is called
|
||||
import tensorflow # noqa
|
||||
try:
|
||||
from tensorflow.python.checkpoint.checkpoint import TrackableSaver
|
||||
except ImportError:
|
||||
from tensorflow.python.training.tracking.util import TrackableSaver # noqa
|
||||
from tensorflow.python.training.tracking import util # noqa
|
||||
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
util.TrackableSaver.save = _patched_call(util.TrackableSaver.save,
|
||||
PatchTensorflow2ModelIO._save)
|
||||
util.TrackableSaver.save = _patched_call(util.TrackableSaver.save, PatchTensorflow2ModelIO._save)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
util.TrackableSaver.restore = _patched_call(util.TrackableSaver.restore,
|
||||
PatchTensorflow2ModelIO._restore)
|
||||
util.TrackableSaver.restore = _patched_call(
|
||||
util.TrackableSaver.restore, PatchTensorflow2ModelIO._restore
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
except ImportError:
|
||||
@ -2277,6 +2277,32 @@ class PatchTensorflow2ModelIO(object):
|
||||
except Exception:
|
||||
LoggerRoot.get_base_logger(TensorflowBinding).debug('Failed patching tensorflow v2')
|
||||
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
# hack: make sure tensorflow.__init__ is called
|
||||
import tensorflow # noqa
|
||||
from tensorflow.python.checkpoint import checkpoint
|
||||
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
checkpoint.TrackableSaver.save = _patched_call(
|
||||
checkpoint.TrackableSaver.save, PatchTensorflow2ModelIO._save
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
checkpoint.TrackableSaver.restore = _patched_call(
|
||||
checkpoint.TrackableSaver.restore, PatchTensorflow2ModelIO._restore
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
except ImportError:
|
||||
pass
|
||||
except Exception:
|
||||
LoggerRoot.get_base_logger(TensorflowBinding).debug('Failed patching tensorflow v2.11')
|
||||
|
||||
@staticmethod
|
||||
def _save(original_fn, self, file_prefix, *args, **kwargs):
|
||||
model = original_fn(self, file_prefix, *args, **kwargs)
|
||||
|
Loading…
Reference in New Issue
Block a user