mirror of
https://github.com/clearml/clearml
synced 2025-02-12 07:35:08 +00:00
Fix Keras model config serialization in PatchKerasModelIO (#616)
Issue: #614 In case the model contains some Lambda layer the Model._updated_config() function may return objects that are not serializable, such as function objects (eg. `K.mean`). Its serialization to JSON then fails. This change adds proper serialization (as done in Model.to_json()) to two places where the model configuration is passed to OutputModel for being serialized: _updated_config(), _update_outputmodel(). The return value of patching _updated_config() preserves the non-serializable objects as some other code may depend on them.
This commit is contained in:
parent
81de18dbce
commit
90d060dd7e
@ -1,4 +1,5 @@
|
|||||||
import base64
|
import base64
|
||||||
|
import json
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import threading
|
import threading
|
||||||
@ -1695,6 +1696,14 @@ class PatchKerasModelIO(object):
|
|||||||
return config
|
return config
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
# noinspection PyBroadException
|
||||||
|
try:
|
||||||
|
from tensorflow.python.util.serialization import get_json_type
|
||||||
|
# Model._updated_config() may contain non-serializable objects
|
||||||
|
safe_config = json.loads(json.dumps(config, default=get_json_type))
|
||||||
|
except Exception:
|
||||||
|
safe_config = config
|
||||||
|
|
||||||
# there is no actual file, so we create the OutputModel without one
|
# there is no actual file, so we create the OutputModel without one
|
||||||
|
|
||||||
# check if object already has InputModel
|
# check if object already has InputModel
|
||||||
@ -1702,14 +1711,14 @@ class PatchKerasModelIO(object):
|
|||||||
self.trains_out_model = []
|
self.trains_out_model = []
|
||||||
|
|
||||||
# check if object already has InputModel
|
# check if object already has InputModel
|
||||||
model_name_id = config.get('name', getattr(self, 'name', 'unknown'))
|
model_name_id = safe_config.get('name', getattr(self, 'name', 'unknown'))
|
||||||
if self.trains_out_model:
|
if self.trains_out_model:
|
||||||
self.trains_out_model[-1].config_dict = config
|
self.trains_out_model[-1].config_dict = safe_config
|
||||||
else:
|
else:
|
||||||
# todo: support multiple models for the same task
|
# todo: support multiple models for the same task
|
||||||
self.trains_out_model.append(OutputModel(
|
self.trains_out_model.append(OutputModel(
|
||||||
task=PatchKerasModelIO.__main_task,
|
task=PatchKerasModelIO.__main_task,
|
||||||
config_dict=config,
|
config_dict=safe_config,
|
||||||
name=PatchKerasModelIO.__main_task.name + ' ' + model_name_id,
|
name=PatchKerasModelIO.__main_task.name + ' ' + model_name_id,
|
||||||
label_enumeration=PatchKerasModelIO.__main_task.get_labels_enumeration(),
|
label_enumeration=PatchKerasModelIO.__main_task.get_labels_enumeration(),
|
||||||
framework=Framework.keras,
|
framework=Framework.keras,
|
||||||
@ -1832,7 +1841,10 @@ class PatchKerasModelIO(object):
|
|||||||
# this will already generate an output model
|
# this will already generate an output model
|
||||||
# noinspection PyBroadException
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
config = self._updated_config()
|
from tensorflow.python.util.serialization import get_json_type
|
||||||
|
# Model._updated_config() may contain non-serializable objects
|
||||||
|
unsafe_config = self._updated_config()
|
||||||
|
config = json.loads(json.dumps(unsafe_config, default=get_json_type))
|
||||||
except Exception:
|
except Exception:
|
||||||
# we failed to convert the network to json, for some reason (most likely internal keras error)
|
# we failed to convert the network to json, for some reason (most likely internal keras error)
|
||||||
config = {}
|
config = {}
|
||||||
|
Loading…
Reference in New Issue
Block a user