mirror of
https://github.com/clearml/clearml
synced 2025-02-12 15:45:25 +00:00
Support Keras restructuring for Network, Model and Sequential
This commit is contained in:
parent
5094ede309
commit
be099f42f9
@ -1303,6 +1303,10 @@ class PatchKerasModelIO(object):
|
|||||||
from keras.engine.network import Network
|
from keras.engine.network import Network
|
||||||
except ImportError:
|
except ImportError:
|
||||||
Network = None
|
Network = None
|
||||||
|
try:
|
||||||
|
from keras.engine.functional import Functional
|
||||||
|
except ImportError:
|
||||||
|
Functional = None
|
||||||
try:
|
try:
|
||||||
from keras.engine.sequential import Sequential
|
from keras.engine.sequential import Sequential
|
||||||
except ImportError:
|
except ImportError:
|
||||||
@ -1316,9 +1320,10 @@ class PatchKerasModelIO(object):
|
|||||||
PatchKerasModelIO.__patched_keras = [
|
PatchKerasModelIO.__patched_keras = [
|
||||||
Network if PatchKerasModelIO.__patched_tensorflow[0] != Network else None,
|
Network if PatchKerasModelIO.__patched_tensorflow[0] != Network else None,
|
||||||
Sequential if PatchKerasModelIO.__patched_tensorflow[1] != Sequential else None,
|
Sequential if PatchKerasModelIO.__patched_tensorflow[1] != Sequential else None,
|
||||||
keras_saving if PatchKerasModelIO.__patched_tensorflow[2] != keras_saving else None, ]
|
keras_saving if PatchKerasModelIO.__patched_tensorflow[2] != keras_saving else None,
|
||||||
|
Functional if PatchKerasModelIO.__patched_tensorflow[3] != Functional else None, ]
|
||||||
else:
|
else:
|
||||||
PatchKerasModelIO.__patched_keras = [Network, Sequential, keras_saving]
|
PatchKerasModelIO.__patched_keras = [Network, Sequential, keras_saving, Functional]
|
||||||
PatchKerasModelIO._patch_io_calls(*PatchKerasModelIO.__patched_keras)
|
PatchKerasModelIO._patch_io_calls(*PatchKerasModelIO.__patched_keras)
|
||||||
|
|
||||||
if 'tensorflow' in sys.modules and not PatchKerasModelIO.__patched_tensorflow:
|
if 'tensorflow' in sys.modules and not PatchKerasModelIO.__patched_tensorflow:
|
||||||
@ -1328,6 +1333,12 @@ class PatchKerasModelIO(object):
|
|||||||
from tensorflow.python.keras.engine.network import Network
|
from tensorflow.python.keras.engine.network import Network
|
||||||
except ImportError:
|
except ImportError:
|
||||||
Network = None
|
Network = None
|
||||||
|
try:
|
||||||
|
# hack: make sure tensorflow.__init__ is called
|
||||||
|
import tensorflow # noqa: F401
|
||||||
|
from tensorflow.python.keras.engine.functional import Functional
|
||||||
|
except ImportError:
|
||||||
|
Functional = None
|
||||||
try:
|
try:
|
||||||
# hack: make sure tensorflow.__init__ is called
|
# hack: make sure tensorflow.__init__ is called
|
||||||
import tensorflow # noqa: F811
|
import tensorflow # noqa: F811
|
||||||
@ -1345,13 +1356,14 @@ class PatchKerasModelIO(object):
|
|||||||
PatchKerasModelIO.__patched_tensorflow = [
|
PatchKerasModelIO.__patched_tensorflow = [
|
||||||
Network if PatchKerasModelIO.__patched_keras[0] != Network else None,
|
Network if PatchKerasModelIO.__patched_keras[0] != Network else None,
|
||||||
Sequential if PatchKerasModelIO.__patched_keras[1] != Sequential else None,
|
Sequential if PatchKerasModelIO.__patched_keras[1] != Sequential else None,
|
||||||
keras_saving if PatchKerasModelIO.__patched_keras[2] != keras_saving else None, ]
|
keras_saving if PatchKerasModelIO.__patched_keras[2] != keras_saving else None,
|
||||||
|
Functional if PatchKerasModelIO.__patched_keras[3] != Functional else None,]
|
||||||
else:
|
else:
|
||||||
PatchKerasModelIO.__patched_tensorflow = [Network, Sequential, keras_saving]
|
PatchKerasModelIO.__patched_tensorflow = [Network, Sequential, keras_saving, Functional]
|
||||||
PatchKerasModelIO._patch_io_calls(*PatchKerasModelIO.__patched_tensorflow)
|
PatchKerasModelIO._patch_io_calls(*PatchKerasModelIO.__patched_tensorflow)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _patch_io_calls(Network, Sequential, keras_saving):
|
def _patch_io_calls(Network, Sequential, keras_saving, Functional):
|
||||||
try:
|
try:
|
||||||
if Sequential is not None:
|
if Sequential is not None:
|
||||||
Sequential._updated_config = _patched_call(Sequential._updated_config,
|
Sequential._updated_config = _patched_call(Sequential._updated_config,
|
||||||
@ -1374,6 +1386,17 @@ class PatchKerasModelIO(object):
|
|||||||
Network.save = _patched_call(Network.save, PatchKerasModelIO._save)
|
Network.save = _patched_call(Network.save, PatchKerasModelIO._save)
|
||||||
Network.save_weights = _patched_call(Network.save_weights, PatchKerasModelIO._save_weights)
|
Network.save_weights = _patched_call(Network.save_weights, PatchKerasModelIO._save_weights)
|
||||||
Network.load_weights = _patched_call(Network.load_weights, PatchKerasModelIO._load_weights)
|
Network.load_weights = _patched_call(Network.load_weights, PatchKerasModelIO._load_weights)
|
||||||
|
elif Functional is not None:
|
||||||
|
Functional._updated_config = _patched_call(Functional._updated_config, PatchKerasModelIO._updated_config)
|
||||||
|
if hasattr(Sequential.from_config, '__func__'):
|
||||||
|
# noinspection PyUnresolvedReferences
|
||||||
|
Functional.from_config = classmethod(_patched_call(Functional.from_config.__func__,
|
||||||
|
PatchKerasModelIO._from_config))
|
||||||
|
else:
|
||||||
|
Functional.from_config = _patched_call(Functional.from_config, PatchKerasModelIO._from_config)
|
||||||
|
Functional.save = _patched_call(Functional.save, PatchKerasModelIO._save)
|
||||||
|
Functional.save_weights = _patched_call(Functional.save_weights, PatchKerasModelIO._save_weights)
|
||||||
|
Functional.load_weights = _patched_call(Functional.load_weights, PatchKerasModelIO._load_weights)
|
||||||
|
|
||||||
if keras_saving is not None:
|
if keras_saving is not None:
|
||||||
keras_saving.save_model = _patched_call(keras_saving.save_model, PatchKerasModelIO._save_model)
|
keras_saving.save_model = _patched_call(keras_saving.save_model, PatchKerasModelIO._save_model)
|
||||||
|
Loading…
Reference in New Issue
Block a user