mirror of
https://github.com/clearml/clearml
synced 2025-04-06 05:35:32 +00:00
Extend support for Keras/Tensorflow mix binding
This commit is contained in:
parent
3112769ad9
commit
e452709492
@ -1047,7 +1047,8 @@ class PatchTensorFlowEager(object):
|
||||
|
||||
class PatchKerasModelIO(object):
|
||||
__main_task = None
|
||||
__patched = None
|
||||
__patched_keras = None
|
||||
__patched_tensorflow = None
|
||||
|
||||
@staticmethod
|
||||
def update_current_task(task, **kwargs):
|
||||
@ -1058,7 +1059,7 @@ class PatchKerasModelIO(object):
|
||||
|
||||
@staticmethod
|
||||
def _patch_model_checkpoint():
|
||||
if 'keras' in sys.modules:
|
||||
if 'keras' in sys.modules and not PatchKerasModelIO.__patched_keras:
|
||||
try:
|
||||
from keras.engine.network import Network
|
||||
except ImportError:
|
||||
@ -1071,8 +1072,17 @@ class PatchKerasModelIO(object):
|
||||
from keras import models as keras_saving
|
||||
except ImportError:
|
||||
keras_saving = None
|
||||
PatchKerasModelIO._patch_io_calls(Network, Sequential, keras_saving)
|
||||
if 'tensorflow' in sys.modules:
|
||||
# check that we are not patching anything twice
|
||||
if PatchKerasModelIO.__patched_tensorflow:
|
||||
PatchKerasModelIO.__patched_keras = [
|
||||
Network if PatchKerasModelIO.__patched_tensorflow[0] != Network else None,
|
||||
Sequential if PatchKerasModelIO.__patched_tensorflow[1] != Sequential else None,
|
||||
keras_saving if PatchKerasModelIO.__patched_tensorflow[2] != keras_saving else None,]
|
||||
else:
|
||||
PatchKerasModelIO.__patched_keras = [Network, Sequential, keras_saving]
|
||||
PatchKerasModelIO._patch_io_calls(*PatchKerasModelIO.__patched_keras)
|
||||
|
||||
if 'tensorflow' in sys.modules and not PatchKerasModelIO.__patched_tensorflow:
|
||||
try:
|
||||
# hack: make sure tensorflow.__init__ is called
|
||||
import tensorflow
|
||||
@ -1091,14 +1101,19 @@ class PatchKerasModelIO(object):
|
||||
from tensorflow.python.keras import models as keras_saving
|
||||
except ImportError:
|
||||
keras_saving = None
|
||||
PatchKerasModelIO._patch_io_calls(Network, Sequential, keras_saving)
|
||||
|
||||
if PatchKerasModelIO.__patched_keras:
|
||||
PatchKerasModelIO.__patched_tensorflow = [
|
||||
Network if PatchKerasModelIO.__patched_keras[0] != Network else None,
|
||||
Sequential if PatchKerasModelIO.__patched_keras[1] != Sequential else None,
|
||||
keras_saving if PatchKerasModelIO.__patched_keras[2] != keras_saving else None,]
|
||||
else:
|
||||
PatchKerasModelIO.__patched_tensorflow = [Network, Sequential, keras_saving]
|
||||
PatchKerasModelIO._patch_io_calls(*PatchKerasModelIO.__patched_tensorflow)
|
||||
|
||||
@staticmethod
|
||||
def _patch_io_calls(Network, Sequential, keras_saving):
|
||||
try:
|
||||
# only patch once
|
||||
if not PatchKerasModelIO.__patched:
|
||||
PatchKerasModelIO.__patched = True
|
||||
if Sequential is not None:
|
||||
Sequential._updated_config = _patched_call(Sequential._updated_config,
|
||||
PatchKerasModelIO._updated_config)
|
||||
|
Loading…
Reference in New Issue
Block a user