clearml/examples/frameworks/keras/keras_v3.py

29 lines
719 B
Python

import numpy as np
import keras
from clearml import Task
def get_model():
    """Build and compile a minimal Keras model for the example.

    The network takes a 32-element input vector and passes it through a
    single one-unit ``Dense`` layer. It is compiled with the Adam optimizer
    and a mean-squared-error loss.

    Returns:
        The compiled ``keras.Model`` instance.
    """
    # Functional-API graph: one input tensor wired to one dense output.
    features = keras.Input(shape=(32,))
    prediction = keras.layers.Dense(1)(features)
    net = keras.Model(inputs=features, outputs=prediction)

    adam = keras.optimizers.Adam()
    net.compile(optimizer=adam, loss="mean_squared_error")
    return net
# Create the ClearML task for this run (project "examples", task "keras_v3").
# NOTE(review): presumably this has to run before any Keras call so that
# ClearML can hook into training and model saving - confirm in ClearML docs.
Task.init(project_name="examples", task_name="keras_v3")

# Build the compiled model defined by get_model().
model = get_model()

# Random training data: a (128, 32) input array and a (128, 1) target array.
test_input = np.random.random((128, 32))
test_target = np.random.random((128, 1))

# Train with Keras' default fit() settings (no extra arguments are passed).
model.fit(test_input, test_target)

# Write the model to "my_model.keras" in the current working directory,
# then load it back from that same file.
model.save("my_model.keras")
reconstructed_model = keras.models.load_model("my_model.keras")

# Sanity check: the reloaded model must produce the same predictions as the
# in-memory one (within assert_allclose's default tolerance); raises
# AssertionError otherwise.
np.testing.assert_allclose(
    model.predict(test_input), reconstructed_model.predict(test_input)
)