diff --git a/clearml/binding/frameworks/tensorflow_bind.py b/clearml/binding/frameworks/tensorflow_bind.py index fd113d4c..f7ae2822 100644 --- a/clearml/binding/frameworks/tensorflow_bind.py +++ b/clearml/binding/frameworks/tensorflow_bind.py @@ -1438,9 +1438,11 @@ class PatchKerasModelIO(object): Network if PatchKerasModelIO.__patched_tensorflow[0] != Network else None, Sequential if PatchKerasModelIO.__patched_tensorflow[1] != Sequential else None, keras_saving if PatchKerasModelIO.__patched_tensorflow[2] != keras_saving else None, - Functional if PatchKerasModelIO.__patched_tensorflow[3] != Functional else None, ] + Functional if PatchKerasModelIO.__patched_tensorflow[3] != Functional else None, + None, + ] else: - PatchKerasModelIO.__patched_keras = [Network, Sequential, keras_saving, Functional] + PatchKerasModelIO.__patched_keras = [Network, Sequential, keras_saving, Functional, None] PatchKerasModelIO._patch_io_calls(*PatchKerasModelIO.__patched_keras) if 'tensorflow' in sys.modules and not PatchKerasModelIO.__patched_tensorflow: @@ -1465,7 +1467,14 @@ class PatchKerasModelIO(object): try: # hack: make sure tensorflow.__init__ is called import tensorflow # noqa - from tensorflow.python.keras import models as keras_saving # noqa + from tensorflow.python.keras import models as keras_saving_legacy # noqa + except ImportError: + keras_saving_legacy = None + + try: + # hack: make sure tensorflow.__init__ is called + import tensorflow # noqa + from tensorflow.keras import models as keras_saving # noqa except ImportError: keras_saving = None @@ -1474,13 +1483,16 @@ class PatchKerasModelIO(object): Network if PatchKerasModelIO.__patched_keras[0] != Network else None, Sequential if PatchKerasModelIO.__patched_keras[1] != Sequential else None, keras_saving if PatchKerasModelIO.__patched_keras[2] != keras_saving else None, - Functional if PatchKerasModelIO.__patched_keras[3] != Functional else None, ] + Functional if PatchKerasModelIO.__patched_keras[3] != Functional else None, + keras_saving_legacy if PatchKerasModelIO.__patched_keras[4] != keras_saving_legacy else None, + ] else: - PatchKerasModelIO.__patched_tensorflow = [Network, Sequential, keras_saving, Functional] + PatchKerasModelIO.__patched_tensorflow = [ + Network, Sequential, keras_saving, Functional, keras_saving_legacy] PatchKerasModelIO._patch_io_calls(*PatchKerasModelIO.__patched_tensorflow) @staticmethod - def _patch_io_calls(Network, Sequential, keras_saving, Functional): + def _patch_io_calls(Network, Sequential, keras_saving, Functional, keras_saving_legacy=None): try: if Sequential is not None: Sequential._updated_config = _patched_call(Sequential._updated_config, @@ -1519,6 +1531,13 @@ class PatchKerasModelIO(object): if keras_saving is not None: keras_saving.save_model = _patched_call(keras_saving.save_model, PatchKerasModelIO._save_model) keras_saving.load_model = _patched_call(keras_saving.load_model, PatchKerasModelIO._load_model) + + if keras_saving_legacy is not None: + keras_saving_legacy.save_model = _patched_call( + keras_saving_legacy.save_model, PatchKerasModelIO._save_model) + keras_saving_legacy.load_model = _patched_call( + keras_saving_legacy.load_model, PatchKerasModelIO._load_model) + except Exception as ex: LoggerRoot.get_base_logger(TensorflowBinding).warning(str(ex))