mirror of
https://github.com/clearml/clearml
synced 2025-03-03 18:52:12 +00:00
Fix TF 2.4 keras load/save model
This commit is contained in:
parent
dacf097ebb
commit
737ca91d2a
@ -1438,9 +1438,11 @@ class PatchKerasModelIO(object):
|
|||||||
Network if PatchKerasModelIO.__patched_tensorflow[0] != Network else None,
|
Network if PatchKerasModelIO.__patched_tensorflow[0] != Network else None,
|
||||||
Sequential if PatchKerasModelIO.__patched_tensorflow[1] != Sequential else None,
|
Sequential if PatchKerasModelIO.__patched_tensorflow[1] != Sequential else None,
|
||||||
keras_saving if PatchKerasModelIO.__patched_tensorflow[2] != keras_saving else None,
|
keras_saving if PatchKerasModelIO.__patched_tensorflow[2] != keras_saving else None,
|
||||||
Functional if PatchKerasModelIO.__patched_tensorflow[3] != Functional else None, ]
|
Functional if PatchKerasModelIO.__patched_tensorflow[3] != Functional else None,
|
||||||
|
None,
|
||||||
|
]
|
||||||
else:
|
else:
|
||||||
PatchKerasModelIO.__patched_keras = [Network, Sequential, keras_saving, Functional]
|
PatchKerasModelIO.__patched_keras = [Network, Sequential, keras_saving, Functional, None]
|
||||||
PatchKerasModelIO._patch_io_calls(*PatchKerasModelIO.__patched_keras)
|
PatchKerasModelIO._patch_io_calls(*PatchKerasModelIO.__patched_keras)
|
||||||
|
|
||||||
if 'tensorflow' in sys.modules and not PatchKerasModelIO.__patched_tensorflow:
|
if 'tensorflow' in sys.modules and not PatchKerasModelIO.__patched_tensorflow:
|
||||||
@ -1465,7 +1467,14 @@ class PatchKerasModelIO(object):
|
|||||||
try:
|
try:
|
||||||
# hack: make sure tensorflow.__init__ is called
|
# hack: make sure tensorflow.__init__ is called
|
||||||
import tensorflow # noqa
|
import tensorflow # noqa
|
||||||
from tensorflow.python.keras import models as keras_saving # noqa
|
from tensorflow.python.keras import models as keras_saving_legacy # noqa
|
||||||
|
except ImportError:
|
||||||
|
keras_saving_legacy = None
|
||||||
|
|
||||||
|
try:
|
||||||
|
# hack: make sure tensorflow.__init__ is called
|
||||||
|
import tensorflow # noqa
|
||||||
|
from tensorflow.keras import models as keras_saving # noqa
|
||||||
except ImportError:
|
except ImportError:
|
||||||
keras_saving = None
|
keras_saving = None
|
||||||
|
|
||||||
@ -1474,13 +1483,16 @@ class PatchKerasModelIO(object):
|
|||||||
Network if PatchKerasModelIO.__patched_keras[0] != Network else None,
|
Network if PatchKerasModelIO.__patched_keras[0] != Network else None,
|
||||||
Sequential if PatchKerasModelIO.__patched_keras[1] != Sequential else None,
|
Sequential if PatchKerasModelIO.__patched_keras[1] != Sequential else None,
|
||||||
keras_saving if PatchKerasModelIO.__patched_keras[2] != keras_saving else None,
|
keras_saving if PatchKerasModelIO.__patched_keras[2] != keras_saving else None,
|
||||||
Functional if PatchKerasModelIO.__patched_keras[3] != Functional else None, ]
|
Functional if PatchKerasModelIO.__patched_keras[3] != Functional else None,
|
||||||
|
keras_saving_legacy if PatchKerasModelIO.__patched_keras[4] != keras_saving_legacy else None,
|
||||||
|
]
|
||||||
else:
|
else:
|
||||||
PatchKerasModelIO.__patched_tensorflow = [Network, Sequential, keras_saving, Functional]
|
PatchKerasModelIO.__patched_tensorflow = [
|
||||||
|
Network, Sequential, keras_saving, Functional, keras_saving_legacy]
|
||||||
PatchKerasModelIO._patch_io_calls(*PatchKerasModelIO.__patched_tensorflow)
|
PatchKerasModelIO._patch_io_calls(*PatchKerasModelIO.__patched_tensorflow)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _patch_io_calls(Network, Sequential, keras_saving, Functional):
|
def _patch_io_calls(Network, Sequential, keras_saving, Functional, keras_saving_legacy=None):
|
||||||
try:
|
try:
|
||||||
if Sequential is not None:
|
if Sequential is not None:
|
||||||
Sequential._updated_config = _patched_call(Sequential._updated_config,
|
Sequential._updated_config = _patched_call(Sequential._updated_config,
|
||||||
@ -1519,6 +1531,13 @@ class PatchKerasModelIO(object):
|
|||||||
if keras_saving is not None:
|
if keras_saving is not None:
|
||||||
keras_saving.save_model = _patched_call(keras_saving.save_model, PatchKerasModelIO._save_model)
|
keras_saving.save_model = _patched_call(keras_saving.save_model, PatchKerasModelIO._save_model)
|
||||||
keras_saving.load_model = _patched_call(keras_saving.load_model, PatchKerasModelIO._load_model)
|
keras_saving.load_model = _patched_call(keras_saving.load_model, PatchKerasModelIO._load_model)
|
||||||
|
|
||||||
|
if keras_saving_legacy is not None:
|
||||||
|
keras_saving_legacy.save_model = _patched_call(
|
||||||
|
keras_saving_legacy.save_model, PatchKerasModelIO._save_model)
|
||||||
|
keras_saving_legacy.load_model = _patched_call(
|
||||||
|
keras_saving_legacy.load_model, PatchKerasModelIO._load_model)
|
||||||
|
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
LoggerRoot.get_base_logger(TensorflowBinding).warning(str(ex))
|
LoggerRoot.get_base_logger(TensorflowBinding).warning(str(ex))
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user