# Mirror of https://github.com/clearml/clearml
# Synced 2025-01-31 17:17:00 +00:00
# 29 lines, 719 B, Python
import numpy as np
import keras
from clearml import Task


def get_model():
    """Build and compile a minimal Keras model for the round-trip demo.

    The model takes a 32-feature input vector and feeds it through a single
    ``Dense(1)`` layer, producing one output value per sample. It is compiled
    with the Adam optimizer and a mean-squared-error loss.

    Returns:
        keras.Model: the compiled (untrained) model.
    """
    # Create a simple model.
    inputs = keras.Input(shape=(32,))
    outputs = keras.layers.Dense(1)(inputs)
    model = keras.Model(inputs, outputs)
    model.compile(optimizer=keras.optimizers.Adam(), loss="mean_squared_error")
    return model


# Create the ClearML task before any training/saving happens (presumably so
# ClearML's framework hooks can record the run and the saved model - confirm
# against the ClearML docs).
Task.init(project_name="examples", task_name="keras_v3")

model = get_model()

# Random data shaped to match the model: 128 samples x 32 features as input,
# one target value per sample.
test_input = np.random.random((128, 32))
test_target = np.random.random((128, 1))
model.fit(test_input, test_target)

# Save the trained model to "my_model.keras" and load it back from disk.
model.save("my_model.keras")
reconstructed_model = keras.models.load_model("my_model.keras")

# The reloaded model must reproduce the original model's predictions;
# assert_allclose raises AssertionError if they diverge.
np.testing.assert_allclose(
    model.predict(test_input), reconstructed_model.predict(test_input)
)