mirror of
https://github.com/clearml/clearml
synced 2025-03-04 02:57:24 +00:00
Fix incorrect module used
This commit is contained in:
parent
1820423b02
commit
701582d5ff
@ -2252,24 +2252,24 @@ class PatchTensorflow2ModelIO(object):
|
|||||||
return
|
return
|
||||||
|
|
||||||
PatchTensorflow2ModelIO.__patched = True
|
PatchTensorflow2ModelIO.__patched = True
|
||||||
|
|
||||||
# noinspection PyBroadException
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
# hack: make sure tensorflow.__init__ is called
|
# hack: make sure tensorflow.__init__ is called
|
||||||
import tensorflow # noqa
|
import tensorflow # noqa
|
||||||
try:
|
from tensorflow.python.training.tracking import util # noqa
|
||||||
from tensorflow.python.checkpoint.checkpoint import TrackableSaver
|
|
||||||
except ImportError:
|
|
||||||
from tensorflow.python.training.tracking.util import TrackableSaver # noqa
|
|
||||||
# noinspection PyBroadException
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
util.TrackableSaver.save = _patched_call(util.TrackableSaver.save,
|
util.TrackableSaver.save = _patched_call(util.TrackableSaver.save, PatchTensorflow2ModelIO._save)
|
||||||
PatchTensorflow2ModelIO._save)
|
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# noinspection PyBroadException
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
util.TrackableSaver.restore = _patched_call(util.TrackableSaver.restore,
|
util.TrackableSaver.restore = _patched_call(
|
||||||
PatchTensorflow2ModelIO._restore)
|
util.TrackableSaver.restore, PatchTensorflow2ModelIO._restore
|
||||||
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
except ImportError:
|
except ImportError:
|
||||||
@ -2277,6 +2277,32 @@ class PatchTensorflow2ModelIO(object):
|
|||||||
except Exception:
|
except Exception:
|
||||||
LoggerRoot.get_base_logger(TensorflowBinding).debug('Failed patching tensorflow v2')
|
LoggerRoot.get_base_logger(TensorflowBinding).debug('Failed patching tensorflow v2')
|
||||||
|
|
||||||
|
# noinspection PyBroadException
|
||||||
|
try:
|
||||||
|
# hack: make sure tensorflow.__init__ is called
|
||||||
|
import tensorflow # noqa
|
||||||
|
from tensorflow.python.checkpoint import checkpoint
|
||||||
|
|
||||||
|
# noinspection PyBroadException
|
||||||
|
try:
|
||||||
|
checkpoint.TrackableSaver.save = _patched_call(
|
||||||
|
checkpoint.TrackableSaver.save, PatchTensorflow2ModelIO._save
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# noinspection PyBroadException
|
||||||
|
try:
|
||||||
|
checkpoint.TrackableSaver.restore = _patched_call(
|
||||||
|
checkpoint.TrackableSaver.restore, PatchTensorflow2ModelIO._restore
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
except Exception:
|
||||||
|
LoggerRoot.get_base_logger(TensorflowBinding).debug('Failed patching tensorflow v2.11')
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _save(original_fn, self, file_prefix, *args, **kwargs):
|
def _save(original_fn, self, file_prefix, *args, **kwargs):
|
||||||
model = original_fn(self, file_prefix, *args, **kwargs)
|
model = original_fn(self, file_prefix, *args, **kwargs)
|
||||||
|
Loading…
Reference in New Issue
Block a user