Extend support for Keras/Tensorflow mix binding

This commit is contained in:
allegroai 2019-06-16 02:28:27 +03:00
parent 3112769ad9
commit e452709492

View File

@ -1047,7 +1047,8 @@ class PatchTensorFlowEager(object):
class PatchKerasModelIO(object):
__main_task = None
__patched = None
__patched_keras = None
__patched_tensorflow = None
@staticmethod
def update_current_task(task, **kwargs):
@ -1058,7 +1059,7 @@ class PatchKerasModelIO(object):
@staticmethod
def _patch_model_checkpoint():
if 'keras' in sys.modules:
if 'keras' in sys.modules and not PatchKerasModelIO.__patched_keras:
try:
from keras.engine.network import Network
except ImportError:
@ -1071,8 +1072,17 @@ class PatchKerasModelIO(object):
from keras import models as keras_saving
except ImportError:
keras_saving = None
PatchKerasModelIO._patch_io_calls(Network, Sequential, keras_saving)
if 'tensorflow' in sys.modules:
# check that we are not patching anything twice
if PatchKerasModelIO.__patched_tensorflow:
PatchKerasModelIO.__patched_keras = [
Network if PatchKerasModelIO.__patched_tensorflow[0] != Network else None,
Sequential if PatchKerasModelIO.__patched_tensorflow[1] != Sequential else None,
keras_saving if PatchKerasModelIO.__patched_tensorflow[2] != keras_saving else None,]
else:
PatchKerasModelIO.__patched_keras = [Network, Sequential, keras_saving]
PatchKerasModelIO._patch_io_calls(*PatchKerasModelIO.__patched_keras)
if 'tensorflow' in sys.modules and not PatchKerasModelIO.__patched_tensorflow:
try:
# hack: make sure tensorflow.__init__ is called
import tensorflow
@ -1091,29 +1101,34 @@ class PatchKerasModelIO(object):
from tensorflow.python.keras import models as keras_saving
except ImportError:
keras_saving = None
PatchKerasModelIO._patch_io_calls(Network, Sequential, keras_saving)
if PatchKerasModelIO.__patched_keras:
PatchKerasModelIO.__patched_tensorflow = [
Network if PatchKerasModelIO.__patched_keras[0] != Network else None,
Sequential if PatchKerasModelIO.__patched_keras[1] != Sequential else None,
keras_saving if PatchKerasModelIO.__patched_keras[2] != keras_saving else None,]
else:
PatchKerasModelIO.__patched_tensorflow = [Network, Sequential, keras_saving]
PatchKerasModelIO._patch_io_calls(*PatchKerasModelIO.__patched_tensorflow)
@staticmethod
def _patch_io_calls(Network, Sequential, keras_saving):
try:
# only patch once
if not PatchKerasModelIO.__patched:
PatchKerasModelIO.__patched = True
if Sequential is not None:
Sequential._updated_config = _patched_call(Sequential._updated_config,
PatchKerasModelIO._updated_config)
Sequential.from_config = _patched_call(Sequential.from_config, PatchKerasModelIO._from_config)
if Sequential is not None:
Sequential._updated_config = _patched_call(Sequential._updated_config,
PatchKerasModelIO._updated_config)
Sequential.from_config = _patched_call(Sequential.from_config, PatchKerasModelIO._from_config)
if Network is not None:
Network._updated_config = _patched_call(Network._updated_config, PatchKerasModelIO._updated_config)
Network.from_config = _patched_call(Network.from_config, PatchKerasModelIO._from_config)
Network.save = _patched_call(Network.save, PatchKerasModelIO._save)
Network.save_weights = _patched_call(Network.save_weights, PatchKerasModelIO._save_weights)
Network.load_weights = _patched_call(Network.load_weights, PatchKerasModelIO._load_weights)
if Network is not None:
Network._updated_config = _patched_call(Network._updated_config, PatchKerasModelIO._updated_config)
Network.from_config = _patched_call(Network.from_config, PatchKerasModelIO._from_config)
Network.save = _patched_call(Network.save, PatchKerasModelIO._save)
Network.save_weights = _patched_call(Network.save_weights, PatchKerasModelIO._save_weights)
Network.load_weights = _patched_call(Network.load_weights, PatchKerasModelIO._load_weights)
if keras_saving is not None:
keras_saving.save_model = _patched_call(keras_saving.save_model, PatchKerasModelIO._save_model)
keras_saving.load_model = _patched_call(keras_saving.load_model, PatchKerasModelIO._load_model)
if keras_saving is not None:
keras_saving.save_model = _patched_call(keras_saving.save_model, PatchKerasModelIO._save_model)
keras_saving.load_model = _patched_call(keras_saving.load_model, PatchKerasModelIO._load_model)
except Exception as ex:
getLogger(TrainsFrameworkAdapter).warning(str(ex))