mirror of
https://github.com/clearml/clearml
synced 2025-04-23 07:45:24 +00:00
Fix model reporting in tensorflow 2.13 does not work properly (#1112)
This commit is contained in:
parent
cd61efe6df
commit
c922c40d13
@ -1589,6 +1589,11 @@ class PatchKerasModelIO(object):
|
|||||||
from keras import models as keras_saving # noqa
|
from keras import models as keras_saving # noqa
|
||||||
except ImportError:
|
except ImportError:
|
||||||
keras_saving = None
|
keras_saving = None
|
||||||
|
try:
|
||||||
|
from keras.src.saving import saving_api as keras_saving_v3
|
||||||
|
except ImportError:
|
||||||
|
keras_saving_v3 = None
|
||||||
|
|
||||||
# check that we are not patching anything twice
|
# check that we are not patching anything twice
|
||||||
if PatchKerasModelIO.__patched_tensorflow:
|
if PatchKerasModelIO.__patched_tensorflow:
|
||||||
PatchKerasModelIO.__patched_keras = [
|
PatchKerasModelIO.__patched_keras = [
|
||||||
@ -1598,9 +1603,10 @@ class PatchKerasModelIO(object):
|
|||||||
Functional if PatchKerasModelIO.__patched_tensorflow[3] != Functional else None,
|
Functional if PatchKerasModelIO.__patched_tensorflow[3] != Functional else None,
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
|
keras_saving_v3
|
||||||
]
|
]
|
||||||
else:
|
else:
|
||||||
PatchKerasModelIO.__patched_keras = [Network, Sequential, keras_saving, Functional, None, None]
|
PatchKerasModelIO.__patched_keras = [Network, Sequential, keras_saving, Functional, None, None, keras_saving_v3]
|
||||||
PatchKerasModelIO._patch_io_calls(*PatchKerasModelIO.__patched_keras)
|
PatchKerasModelIO._patch_io_calls(*PatchKerasModelIO.__patched_keras)
|
||||||
|
|
||||||
if 'tensorflow' in sys.modules and not PatchKerasModelIO.__patched_tensorflow:
|
if 'tensorflow' in sys.modules and not PatchKerasModelIO.__patched_tensorflow:
|
||||||
@ -1643,6 +1649,8 @@ class PatchKerasModelIO(object):
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
keras_hdf5 = None
|
keras_hdf5 = None
|
||||||
|
|
||||||
|
keras_saving_v3 = None
|
||||||
|
|
||||||
if PatchKerasModelIO.__patched_keras:
|
if PatchKerasModelIO.__patched_keras:
|
||||||
PatchKerasModelIO.__patched_tensorflow = [
|
PatchKerasModelIO.__patched_tensorflow = [
|
||||||
Network if PatchKerasModelIO.__patched_keras[0] != Network else None,
|
Network if PatchKerasModelIO.__patched_keras[0] != Network else None,
|
||||||
@ -1651,14 +1659,23 @@ class PatchKerasModelIO(object):
|
|||||||
Functional if PatchKerasModelIO.__patched_keras[3] != Functional else None,
|
Functional if PatchKerasModelIO.__patched_keras[3] != Functional else None,
|
||||||
keras_saving_legacy if PatchKerasModelIO.__patched_keras[4] != keras_saving_legacy else None,
|
keras_saving_legacy if PatchKerasModelIO.__patched_keras[4] != keras_saving_legacy else None,
|
||||||
keras_hdf5 if PatchKerasModelIO.__patched_keras[5] != keras_hdf5 else None,
|
keras_hdf5 if PatchKerasModelIO.__patched_keras[5] != keras_hdf5 else None,
|
||||||
|
keras_saving_v3 if PatchKerasModelIO.__patched_keras[6] != keras_saving_v3 else None,
|
||||||
]
|
]
|
||||||
else:
|
else:
|
||||||
PatchKerasModelIO.__patched_tensorflow = [
|
PatchKerasModelIO.__patched_tensorflow = [
|
||||||
Network, Sequential, keras_saving, Functional, keras_saving_legacy, keras_hdf5]
|
Network, Sequential, keras_saving, Functional, keras_saving_legacy, keras_hdf5, keras_saving_v3]
|
||||||
PatchKerasModelIO._patch_io_calls(*PatchKerasModelIO.__patched_tensorflow)
|
PatchKerasModelIO._patch_io_calls(*PatchKerasModelIO.__patched_tensorflow)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _patch_io_calls(Network, Sequential, keras_saving, Functional, keras_saving_legacy=None, keras_hdf5=None):
|
def _patch_io_calls(
|
||||||
|
Network,
|
||||||
|
Sequential,
|
||||||
|
keras_saving,
|
||||||
|
Functional,
|
||||||
|
keras_saving_legacy=None,
|
||||||
|
keras_hdf5=None,
|
||||||
|
keras_saving_v3=None
|
||||||
|
):
|
||||||
try:
|
try:
|
||||||
if Sequential is not None:
|
if Sequential is not None:
|
||||||
Sequential._updated_config = _patched_call(Sequential._updated_config,
|
Sequential._updated_config = _patched_call(Sequential._updated_config,
|
||||||
@ -1718,6 +1735,9 @@ class PatchKerasModelIO(object):
|
|||||||
keras_hdf5.save_model_to_hdf5 = _patched_call(
|
keras_hdf5.save_model_to_hdf5 = _patched_call(
|
||||||
keras_hdf5.save_model_to_hdf5, PatchKerasModelIO._save_model)
|
keras_hdf5.save_model_to_hdf5, PatchKerasModelIO._save_model)
|
||||||
|
|
||||||
|
if keras_saving_v3 is not None:
|
||||||
|
keras_saving_v3.save_model = _patched_call(keras_saving_v3.save_model, PatchKerasModelIO._save_model)
|
||||||
|
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
LoggerRoot.get_base_logger(TensorflowBinding).warning(str(ex))
|
LoggerRoot.get_base_logger(TensorflowBinding).warning(str(ex))
|
||||||
|
|
||||||
@ -2058,6 +2078,11 @@ class PatchTensorflowModelIO(object):
|
|||||||
Checkpoint.write = _patched_call(Checkpoint.write, PatchTensorflowModelIO._ckpt_write)
|
Checkpoint.write = _patched_call(Checkpoint.write, PatchTensorflowModelIO._ckpt_write)
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
# noinspection PyBroadException
|
||||||
|
try:
|
||||||
|
Checkpoint._write = _patched_call(Checkpoint._write, PatchTensorflowModelIO._ckpt_write)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
except Exception:
|
except Exception:
|
||||||
@ -2231,7 +2256,10 @@ class PatchTensorflow2ModelIO(object):
|
|||||||
try:
|
try:
|
||||||
# hack: make sure tensorflow.__init__ is called
|
# hack: make sure tensorflow.__init__ is called
|
||||||
import tensorflow # noqa
|
import tensorflow # noqa
|
||||||
from tensorflow.python.training.tracking import util # noqa
|
try:
|
||||||
|
from tensorflow.python.checkpoint.checkpoint import TrackableSaver
|
||||||
|
except ImportError:
|
||||||
|
from tensorflow.python.training.tracking.util import TrackableSaver # noqa
|
||||||
# noinspection PyBroadException
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
util.TrackableSaver.save = _patched_call(util.TrackableSaver.save,
|
util.TrackableSaver.save = _patched_call(util.TrackableSaver.save,
|
||||||
|
28
examples/frameworks/keras/keras_v3.py
Normal file
28
examples/frameworks/keras/keras_v3.py
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
import numpy as np
|
||||||
|
import keras
|
||||||
|
from clearml import Task
|
||||||
|
|
||||||
|
|
||||||
|
def get_model():
|
||||||
|
# Create a simple model.
|
||||||
|
inputs = keras.Input(shape=(32,))
|
||||||
|
outputs = keras.layers.Dense(1)(inputs)
|
||||||
|
model = keras.Model(inputs, outputs)
|
||||||
|
model.compile(optimizer=keras.optimizers.Adam(), loss="mean_squared_error")
|
||||||
|
return model
|
||||||
|
|
||||||
|
Task.init(project_name="examples", task_name="keras_v3")
|
||||||
|
|
||||||
|
model = get_model()
|
||||||
|
|
||||||
|
test_input = np.random.random((128, 32))
|
||||||
|
test_target = np.random.random((128, 1))
|
||||||
|
model.fit(test_input, test_target)
|
||||||
|
|
||||||
|
model.save("my_model.keras")
|
||||||
|
|
||||||
|
reconstructed_model = keras.models.load_model("my_model.keras")
|
||||||
|
|
||||||
|
np.testing.assert_allclose(
|
||||||
|
model.predict(test_input), reconstructed_model.predict(test_input)
|
||||||
|
)
|
Loading…
Reference in New Issue
Block a user