From be099f42f9e589559015c498dbc13908579a3ff5 Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Tue, 29 Sep 2020 19:27:43 +0300 Subject: [PATCH] Support Keras restructuring for Network, Model and Sequential --- trains/binding/frameworks/tensorflow_bind.py | 33 +++++++++++++++++--- 1 file changed, 28 insertions(+), 5 deletions(-) diff --git a/trains/binding/frameworks/tensorflow_bind.py b/trains/binding/frameworks/tensorflow_bind.py index ebb92732..805937f0 100644 --- a/trains/binding/frameworks/tensorflow_bind.py +++ b/trains/binding/frameworks/tensorflow_bind.py @@ -1303,6 +1303,10 @@ class PatchKerasModelIO(object): from keras.engine.network import Network except ImportError: Network = None + try: + from keras.engine.functional import Functional + except ImportError: + Functional = None try: from keras.engine.sequential import Sequential except ImportError: @@ -1316,9 +1320,10 @@ class PatchKerasModelIO(object): PatchKerasModelIO.__patched_keras = [ Network if PatchKerasModelIO.__patched_tensorflow[0] != Network else None, Sequential if PatchKerasModelIO.__patched_tensorflow[1] != Sequential else None, - keras_saving if PatchKerasModelIO.__patched_tensorflow[2] != keras_saving else None, ] + keras_saving if PatchKerasModelIO.__patched_tensorflow[2] != keras_saving else None, + Functional if PatchKerasModelIO.__patched_tensorflow[3] != Functional else None, ] else: - PatchKerasModelIO.__patched_keras = [Network, Sequential, keras_saving] + PatchKerasModelIO.__patched_keras = [Network, Sequential, keras_saving, Functional] PatchKerasModelIO._patch_io_calls(*PatchKerasModelIO.__patched_keras) if 'tensorflow' in sys.modules and not PatchKerasModelIO.__patched_tensorflow: @@ -1328,6 +1333,12 @@ class PatchKerasModelIO(object): from tensorflow.python.keras.engine.network import Network except ImportError: Network = None + try: + # hack: make sure tensorflow.__init__ is called + import tensorflow # noqa: F401 + from tensorflow.python.keras.engine.functional import Functional + except ImportError: + Functional = None try: # hack: make sure tensorflow.__init__ is called import tensorflow # noqa: F811 @@ -1345,13 +1356,14 @@ class PatchKerasModelIO(object): PatchKerasModelIO.__patched_tensorflow = [ Network if PatchKerasModelIO.__patched_keras[0] != Network else None, Sequential if PatchKerasModelIO.__patched_keras[1] != Sequential else None, - keras_saving if PatchKerasModelIO.__patched_keras[2] != keras_saving else None, ] + keras_saving if PatchKerasModelIO.__patched_keras[2] != keras_saving else None, + Functional if PatchKerasModelIO.__patched_keras[3] != Functional else None,] else: - PatchKerasModelIO.__patched_tensorflow = [Network, Sequential, keras_saving] + PatchKerasModelIO.__patched_tensorflow = [Network, Sequential, keras_saving, Functional] PatchKerasModelIO._patch_io_calls(*PatchKerasModelIO.__patched_tensorflow) @staticmethod - def _patch_io_calls(Network, Sequential, keras_saving): + def _patch_io_calls(Network, Sequential, keras_saving, Functional): try: if Sequential is not None: Sequential._updated_config = _patched_call(Sequential._updated_config, @@ -1374,6 +1386,17 @@ class PatchKerasModelIO(object): Network.save = _patched_call(Network.save, PatchKerasModelIO._save) Network.save_weights = _patched_call(Network.save_weights, PatchKerasModelIO._save_weights) Network.load_weights = _patched_call(Network.load_weights, PatchKerasModelIO._load_weights) + elif Functional is not None: + Functional._updated_config = _patched_call(Functional._updated_config, PatchKerasModelIO._updated_config) + if hasattr(Sequential.from_config, '__func__'): + # noinspection PyUnresolvedReferences + Functional.from_config = classmethod(_patched_call(Functional.from_config.__func__, + PatchKerasModelIO._from_config)) + else: + Functional.from_config = _patched_call(Functional.from_config, PatchKerasModelIO._from_config) + Functional.save = _patched_call(Functional.save, PatchKerasModelIO._save) + Functional.save_weights = _patched_call(Functional.save_weights, PatchKerasModelIO._save_weights) + Functional.load_weights = _patched_call(Functional.load_weights, PatchKerasModelIO._load_weights) if keras_saving is not None: keras_saving.save_model = _patched_call(keras_saving.save_model, PatchKerasModelIO._save_model)