mirror of
https://github.com/clearml/clearml
synced 2025-02-07 05:18:50 +00:00
Fix Keras model config serialization in PatchKerasModelIO (#616)
Issue: #614 In case the model contains some Lambda layer the Model._updated_config() function may return objects that are not serializable, such as function objects (eg. `K.mean`). Its serialization to JSON then fails. This change adds proper serialization (as done in Model.to_json()) to two places where the model configuration is passed to OutputModel for being serialized: _updated_config(), _update_outputmodel(). The return value of patching _updated_config() preserves the non-serializable objects as some other code may depend on them.
This commit is contained in:
parent
81de18dbce
commit
90d060dd7e
@ -1,4 +1,5 @@
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
@ -1695,6 +1696,14 @@ class PatchKerasModelIO(object):
|
||||
return config
|
||||
|
||||
try:
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
from tensorflow.python.util.serialization import get_json_type
|
||||
# Model._updated_config() may contain non-serializable objects
|
||||
safe_config = json.loads(json.dumps(config, default=get_json_type))
|
||||
except Exception:
|
||||
safe_config = config
|
||||
|
||||
# there is no actual file, so we create the OutputModel without one
|
||||
|
||||
# check if object already has InputModel
|
||||
@ -1702,14 +1711,14 @@ class PatchKerasModelIO(object):
|
||||
self.trains_out_model = []
|
||||
|
||||
# check if object already has InputModel
|
||||
model_name_id = config.get('name', getattr(self, 'name', 'unknown'))
|
||||
model_name_id = safe_config.get('name', getattr(self, 'name', 'unknown'))
|
||||
if self.trains_out_model:
|
||||
self.trains_out_model[-1].config_dict = config
|
||||
self.trains_out_model[-1].config_dict = safe_config
|
||||
else:
|
||||
# todo: support multiple models for the same task
|
||||
self.trains_out_model.append(OutputModel(
|
||||
task=PatchKerasModelIO.__main_task,
|
||||
config_dict=config,
|
||||
config_dict=safe_config,
|
||||
name=PatchKerasModelIO.__main_task.name + ' ' + model_name_id,
|
||||
label_enumeration=PatchKerasModelIO.__main_task.get_labels_enumeration(),
|
||||
framework=Framework.keras,
|
||||
@ -1832,7 +1841,10 @@ class PatchKerasModelIO(object):
|
||||
# this will already generate an output model
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
config = self._updated_config()
|
||||
from tensorflow.python.util.serialization import get_json_type
|
||||
# Model._updated_config() may contain non-serializable objects
|
||||
unsafe_config = self._updated_config()
|
||||
config = json.loads(json.dumps(unsafe_config, default=get_json_type))
|
||||
except Exception:
|
||||
# we failed to convert the network to json, for some reason (most likely internal keras error)
|
||||
config = {}
|
||||
|
Loading…
Reference in New Issue
Block a user