Fix TF 2.4 keras load/save model

This commit is contained in:
allegroai 2021-03-11 09:41:14 +02:00
parent dacf097ebb
commit 737ca91d2a

View File

@ -1438,9 +1438,11 @@ class PatchKerasModelIO(object):
Network if PatchKerasModelIO.__patched_tensorflow[0] != Network else None,
Sequential if PatchKerasModelIO.__patched_tensorflow[1] != Sequential else None,
keras_saving if PatchKerasModelIO.__patched_tensorflow[2] != keras_saving else None,
Functional if PatchKerasModelIO.__patched_tensorflow[3] != Functional else None, ]
Functional if PatchKerasModelIO.__patched_tensorflow[3] != Functional else None,
None,
]
else:
PatchKerasModelIO.__patched_keras = [Network, Sequential, keras_saving, Functional]
PatchKerasModelIO.__patched_keras = [Network, Sequential, keras_saving, Functional, None]
PatchKerasModelIO._patch_io_calls(*PatchKerasModelIO.__patched_keras)
if 'tensorflow' in sys.modules and not PatchKerasModelIO.__patched_tensorflow:
@ -1465,7 +1467,14 @@ class PatchKerasModelIO(object):
try:
# hack: make sure tensorflow.__init__ is called
import tensorflow # noqa
from tensorflow.python.keras import models as keras_saving # noqa
from tensorflow.python.keras import models as keras_saving_legacy # noqa
except ImportError:
keras_saving_legacy = None
try:
# hack: make sure tensorflow.__init__ is called
import tensorflow # noqa
from tensorflow.keras import models as keras_saving # noqa
except ImportError:
keras_saving = None
@ -1474,13 +1483,16 @@ class PatchKerasModelIO(object):
Network if PatchKerasModelIO.__patched_keras[0] != Network else None,
Sequential if PatchKerasModelIO.__patched_keras[1] != Sequential else None,
keras_saving if PatchKerasModelIO.__patched_keras[2] != keras_saving else None,
Functional if PatchKerasModelIO.__patched_keras[3] != Functional else None, ]
Functional if PatchKerasModelIO.__patched_keras[3] != Functional else None,
keras_saving_legacy if PatchKerasModelIO.__patched_keras[4] != keras_saving_legacy else None,
]
else:
PatchKerasModelIO.__patched_tensorflow = [Network, Sequential, keras_saving, Functional]
PatchKerasModelIO.__patched_tensorflow = [
Network, Sequential, keras_saving, Functional, keras_saving_legacy]
PatchKerasModelIO._patch_io_calls(*PatchKerasModelIO.__patched_tensorflow)
@staticmethod
def _patch_io_calls(Network, Sequential, keras_saving, Functional):
def _patch_io_calls(Network, Sequential, keras_saving, Functional, keras_saving_legacy=None):
try:
if Sequential is not None:
Sequential._updated_config = _patched_call(Sequential._updated_config,
@ -1519,6 +1531,13 @@ class PatchKerasModelIO(object):
if keras_saving is not None:
keras_saving.save_model = _patched_call(keras_saving.save_model, PatchKerasModelIO._save_model)
keras_saving.load_model = _patched_call(keras_saving.load_model, PatchKerasModelIO._load_model)
if keras_saving_legacy is not None:
keras_saving_legacy.save_model = _patched_call(
keras_saving_legacy.save_model, PatchKerasModelIO._save_model)
keras_saving_legacy.load_model = _patched_call(
keras_saving_legacy.load_model, PatchKerasModelIO._load_model)
except Exception as ex:
LoggerRoot.get_base_logger(TensorflowBinding).warning(str(ex))