Fix model reporting in tensorflow 2.13 does not work properly (#1112)

This commit is contained in:
allegroai 2023-09-08 22:14:53 +03:00
parent cd61efe6df
commit c922c40d13
2 changed files with 60 additions and 4 deletions

View File

@ -1589,6 +1589,11 @@ class PatchKerasModelIO(object):
from keras import models as keras_saving # noqa from keras import models as keras_saving # noqa
except ImportError: except ImportError:
keras_saving = None keras_saving = None
try:
from keras.src.saving import saving_api as keras_saving_v3
except ImportError:
keras_saving_v3 = None
# check that we are not patching anything twice # check that we are not patching anything twice
if PatchKerasModelIO.__patched_tensorflow: if PatchKerasModelIO.__patched_tensorflow:
PatchKerasModelIO.__patched_keras = [ PatchKerasModelIO.__patched_keras = [
@ -1598,9 +1603,10 @@ class PatchKerasModelIO(object):
Functional if PatchKerasModelIO.__patched_tensorflow[3] != Functional else None, Functional if PatchKerasModelIO.__patched_tensorflow[3] != Functional else None,
None, None,
None, None,
keras_saving_v3
] ]
else: else:
PatchKerasModelIO.__patched_keras = [Network, Sequential, keras_saving, Functional, None, None] PatchKerasModelIO.__patched_keras = [Network, Sequential, keras_saving, Functional, None, None, keras_saving_v3]
PatchKerasModelIO._patch_io_calls(*PatchKerasModelIO.__patched_keras) PatchKerasModelIO._patch_io_calls(*PatchKerasModelIO.__patched_keras)
if 'tensorflow' in sys.modules and not PatchKerasModelIO.__patched_tensorflow: if 'tensorflow' in sys.modules and not PatchKerasModelIO.__patched_tensorflow:
@ -1643,6 +1649,8 @@ class PatchKerasModelIO(object):
except ImportError: except ImportError:
keras_hdf5 = None keras_hdf5 = None
keras_saving_v3 = None
if PatchKerasModelIO.__patched_keras: if PatchKerasModelIO.__patched_keras:
PatchKerasModelIO.__patched_tensorflow = [ PatchKerasModelIO.__patched_tensorflow = [
Network if PatchKerasModelIO.__patched_keras[0] != Network else None, Network if PatchKerasModelIO.__patched_keras[0] != Network else None,
@ -1651,14 +1659,23 @@ class PatchKerasModelIO(object):
Functional if PatchKerasModelIO.__patched_keras[3] != Functional else None, Functional if PatchKerasModelIO.__patched_keras[3] != Functional else None,
keras_saving_legacy if PatchKerasModelIO.__patched_keras[4] != keras_saving_legacy else None, keras_saving_legacy if PatchKerasModelIO.__patched_keras[4] != keras_saving_legacy else None,
keras_hdf5 if PatchKerasModelIO.__patched_keras[5] != keras_hdf5 else None, keras_hdf5 if PatchKerasModelIO.__patched_keras[5] != keras_hdf5 else None,
keras_saving_v3 if PatchKerasModelIO.__patched_keras[6] != keras_saving_v3 else None,
] ]
else: else:
PatchKerasModelIO.__patched_tensorflow = [ PatchKerasModelIO.__patched_tensorflow = [
Network, Sequential, keras_saving, Functional, keras_saving_legacy, keras_hdf5] Network, Sequential, keras_saving, Functional, keras_saving_legacy, keras_hdf5, keras_saving_v3]
PatchKerasModelIO._patch_io_calls(*PatchKerasModelIO.__patched_tensorflow) PatchKerasModelIO._patch_io_calls(*PatchKerasModelIO.__patched_tensorflow)
@staticmethod @staticmethod
def _patch_io_calls(Network, Sequential, keras_saving, Functional, keras_saving_legacy=None, keras_hdf5=None): def _patch_io_calls(
Network,
Sequential,
keras_saving,
Functional,
keras_saving_legacy=None,
keras_hdf5=None,
keras_saving_v3=None
):
try: try:
if Sequential is not None: if Sequential is not None:
Sequential._updated_config = _patched_call(Sequential._updated_config, Sequential._updated_config = _patched_call(Sequential._updated_config,
@ -1718,6 +1735,9 @@ class PatchKerasModelIO(object):
keras_hdf5.save_model_to_hdf5 = _patched_call( keras_hdf5.save_model_to_hdf5 = _patched_call(
keras_hdf5.save_model_to_hdf5, PatchKerasModelIO._save_model) keras_hdf5.save_model_to_hdf5, PatchKerasModelIO._save_model)
if keras_saving_v3 is not None:
keras_saving_v3.save_model = _patched_call(keras_saving_v3.save_model, PatchKerasModelIO._save_model)
except Exception as ex: except Exception as ex:
LoggerRoot.get_base_logger(TensorflowBinding).warning(str(ex)) LoggerRoot.get_base_logger(TensorflowBinding).warning(str(ex))
@ -2058,6 +2078,11 @@ class PatchTensorflowModelIO(object):
Checkpoint.write = _patched_call(Checkpoint.write, PatchTensorflowModelIO._ckpt_write) Checkpoint.write = _patched_call(Checkpoint.write, PatchTensorflowModelIO._ckpt_write)
except Exception: except Exception:
pass pass
# noinspection PyBroadException
try:
Checkpoint._write = _patched_call(Checkpoint._write, PatchTensorflowModelIO._ckpt_write)
except Exception:
pass
except ImportError: except ImportError:
pass pass
except Exception: except Exception:
@ -2231,7 +2256,10 @@ class PatchTensorflow2ModelIO(object):
try: try:
# hack: make sure tensorflow.__init__ is called # hack: make sure tensorflow.__init__ is called
import tensorflow # noqa import tensorflow # noqa
from tensorflow.python.training.tracking import util # noqa try:
from tensorflow.python.checkpoint.checkpoint import TrackableSaver
except ImportError:
from tensorflow.python.training.tracking.util import TrackableSaver # noqa
# noinspection PyBroadException # noinspection PyBroadException
try: try:
util.TrackableSaver.save = _patched_call(util.TrackableSaver.save, util.TrackableSaver.save = _patched_call(util.TrackableSaver.save,

View File

@ -0,0 +1,28 @@
import numpy as np
import keras
from clearml import Task
def get_model():
# Create a simple model.
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = keras.Model(inputs, outputs)
model.compile(optimizer=keras.optimizers.Adam(), loss="mean_squared_error")
return model
Task.init(project_name="examples", task_name="keras_v3")
model = get_model()
test_input = np.random.random((128, 32))
test_target = np.random.random((128, 1))
model.fit(test_input, test_target)
model.save("my_model.keras")
reconstructed_model = keras.models.load_model("my_model.keras")
np.testing.assert_allclose(
model.predict(test_input), reconstructed_model.predict(test_input)
)