Fix incorrect module used

This commit is contained in:
allegroai 2023-09-16 20:53:47 +03:00
parent 1820423b02
commit 701582d5ff

View File

@ -2252,24 +2252,24 @@ class PatchTensorflow2ModelIO(object):
return
PatchTensorflow2ModelIO.__patched = True
# noinspection PyBroadException
try:
# hack: make sure tensorflow.__init__ is called
import tensorflow # noqa
try:
from tensorflow.python.checkpoint.checkpoint import TrackableSaver
except ImportError:
from tensorflow.python.training.tracking.util import TrackableSaver # noqa
from tensorflow.python.training.tracking import util # noqa
# noinspection PyBroadException
try:
util.TrackableSaver.save = _patched_call(util.TrackableSaver.save,
PatchTensorflow2ModelIO._save)
util.TrackableSaver.save = _patched_call(util.TrackableSaver.save, PatchTensorflow2ModelIO._save)
except Exception:
pass
# noinspection PyBroadException
try:
util.TrackableSaver.restore = _patched_call(util.TrackableSaver.restore,
PatchTensorflow2ModelIO._restore)
util.TrackableSaver.restore = _patched_call(
util.TrackableSaver.restore, PatchTensorflow2ModelIO._restore
)
except Exception:
pass
except ImportError:
@ -2277,6 +2277,32 @@ class PatchTensorflow2ModelIO(object):
except Exception:
LoggerRoot.get_base_logger(TensorflowBinding).debug('Failed patching tensorflow v2')
# noinspection PyBroadException
try:
# hack: make sure tensorflow.__init__ is called
import tensorflow # noqa
from tensorflow.python.checkpoint import checkpoint
# noinspection PyBroadException
try:
checkpoint.TrackableSaver.save = _patched_call(
checkpoint.TrackableSaver.save, PatchTensorflow2ModelIO._save
)
except Exception:
pass
# noinspection PyBroadException
try:
checkpoint.TrackableSaver.restore = _patched_call(
checkpoint.TrackableSaver.restore, PatchTensorflow2ModelIO._restore
)
except Exception:
pass
except ImportError:
pass
except Exception:
LoggerRoot.get_base_logger(TensorflowBinding).debug('Failed patching tensorflow v2.11')
@staticmethod
def _save(original_fn, self, file_prefix, *args, **kwargs):
model = original_fn(self, file_prefix, *args, **kwargs)