diff --git a/trains/utilities/frameworks.py b/trains/utilities/frameworks.py index bb4724d3..3da630bc 100644 --- a/trains/utilities/frameworks.py +++ b/trains/utilities/frameworks.py @@ -1047,7 +1047,8 @@ class PatchTensorFlowEager(object): class PatchKerasModelIO(object): __main_task = None - __patched = None + __patched_keras = None + __patched_tensorflow = None @staticmethod def update_current_task(task, **kwargs): @@ -1058,7 +1059,7 @@ class PatchKerasModelIO(object): @staticmethod def _patch_model_checkpoint(): - if 'keras' in sys.modules: + if 'keras' in sys.modules and not PatchKerasModelIO.__patched_keras: try: from keras.engine.network import Network except ImportError: @@ -1071,8 +1072,17 @@ class PatchKerasModelIO(object): from keras import models as keras_saving except ImportError: keras_saving = None - PatchKerasModelIO._patch_io_calls(Network, Sequential, keras_saving) - if 'tensorflow' in sys.modules: + # check that we are not patching anything twice + if PatchKerasModelIO.__patched_tensorflow: + PatchKerasModelIO.__patched_keras = [ + Network if PatchKerasModelIO.__patched_tensorflow[0] != Network else None, + Sequential if PatchKerasModelIO.__patched_tensorflow[1] != Sequential else None, + keras_saving if PatchKerasModelIO.__patched_tensorflow[2] != keras_saving else None,] + else: + PatchKerasModelIO.__patched_keras = [Network, Sequential, keras_saving] + PatchKerasModelIO._patch_io_calls(*PatchKerasModelIO.__patched_keras) + + if 'tensorflow' in sys.modules and not PatchKerasModelIO.__patched_tensorflow: try: # hack: make sure tensorflow.__init__ is called import tensorflow @@ -1091,29 +1101,34 @@ class PatchKerasModelIO(object): from tensorflow.python.keras import models as keras_saving except ImportError: keras_saving = None - PatchKerasModelIO._patch_io_calls(Network, Sequential, keras_saving) + + if PatchKerasModelIO.__patched_keras: + PatchKerasModelIO.__patched_tensorflow = [ + Network if PatchKerasModelIO.__patched_keras[0] != Network else None, + Sequential if PatchKerasModelIO.__patched_keras[1] != Sequential else None, + keras_saving if PatchKerasModelIO.__patched_keras[2] != keras_saving else None,] + else: + PatchKerasModelIO.__patched_tensorflow = [Network, Sequential, keras_saving] + PatchKerasModelIO._patch_io_calls(*PatchKerasModelIO.__patched_tensorflow) @staticmethod def _patch_io_calls(Network, Sequential, keras_saving): try: - # only patch once - if not PatchKerasModelIO.__patched: - PatchKerasModelIO.__patched = True - if Sequential is not None: - Sequential._updated_config = _patched_call(Sequential._updated_config, - PatchKerasModelIO._updated_config) - Sequential.from_config = _patched_call(Sequential.from_config, PatchKerasModelIO._from_config) + if Sequential is not None: + Sequential._updated_config = _patched_call(Sequential._updated_config, + PatchKerasModelIO._updated_config) + Sequential.from_config = _patched_call(Sequential.from_config, PatchKerasModelIO._from_config) - if Network is not None: - Network._updated_config = _patched_call(Network._updated_config, PatchKerasModelIO._updated_config) - Network.from_config = _patched_call(Network.from_config, PatchKerasModelIO._from_config) - Network.save = _patched_call(Network.save, PatchKerasModelIO._save) - Network.save_weights = _patched_call(Network.save_weights, PatchKerasModelIO._save_weights) - Network.load_weights = _patched_call(Network.load_weights, PatchKerasModelIO._load_weights) + if Network is not None: + Network._updated_config = _patched_call(Network._updated_config, PatchKerasModelIO._updated_config) + Network.from_config = _patched_call(Network.from_config, PatchKerasModelIO._from_config) + Network.save = _patched_call(Network.save, PatchKerasModelIO._save) + Network.save_weights = _patched_call(Network.save_weights, PatchKerasModelIO._save_weights) + Network.load_weights = _patched_call(Network.load_weights, PatchKerasModelIO._load_weights) - if keras_saving is not None: - keras_saving.save_model = _patched_call(keras_saving.save_model, PatchKerasModelIO._save_model) - keras_saving.load_model = _patched_call(keras_saving.load_model, PatchKerasModelIO._load_model) + if keras_saving is not None: + keras_saving.save_model = _patched_call(keras_saving.save_model, PatchKerasModelIO._save_model) + keras_saving.load_model = _patched_call(keras_saving.load_model, PatchKerasModelIO._load_model) except Exception as ex: getLogger(TrainsFrameworkAdapter).warning(str(ex))