Fix Keras model config serialization in PatchKerasModelIO (#616)

Issue: #614

In case the model contains some Lambda layer the Model._updated_config()
function may return objects that are not serializable, such as function
objects (eg. `K.mean`). Its serialization to JSON then fails.

This change adds proper serialization (as done in Model.to_json()) to
two places where the model configuration is passed to OutputModel for
being serialized: _updated_config(), _update_outputmodel().

The return value of patching _updated_config() preserves the
non-serializable objects as some other code may depend on them.
This commit is contained in:
Bohumír Zámečník 2022-04-19 17:58:15 +02:00 committed by GitHub
parent 81de18dbce
commit 90d060dd7e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,4 +1,5 @@
import base64
import json
import os
import sys
import threading
@ -1695,6 +1696,14 @@ class PatchKerasModelIO(object):
return config
try:
# noinspection PyBroadException
try:
from tensorflow.python.util.serialization import get_json_type
# Model._updated_config() may contain non-serializable objects
safe_config = json.loads(json.dumps(config, default=get_json_type))
except Exception:
safe_config = config
# there is no actual file, so we create the OutputModel without one
# check if object already has InputModel
@ -1702,14 +1711,14 @@ class PatchKerasModelIO(object):
self.trains_out_model = []
# check if object already has InputModel
model_name_id = config.get('name', getattr(self, 'name', 'unknown'))
model_name_id = safe_config.get('name', getattr(self, 'name', 'unknown'))
if self.trains_out_model:
self.trains_out_model[-1].config_dict = config
self.trains_out_model[-1].config_dict = safe_config
else:
# todo: support multiple models for the same task
self.trains_out_model.append(OutputModel(
task=PatchKerasModelIO.__main_task,
config_dict=config,
config_dict=safe_config,
name=PatchKerasModelIO.__main_task.name + ' ' + model_name_id,
label_enumeration=PatchKerasModelIO.__main_task.get_labels_enumeration(),
framework=Framework.keras,
@ -1832,7 +1841,10 @@ class PatchKerasModelIO(object):
# this will already generate an output model
# noinspection PyBroadException
try:
config = self._updated_config()
from tensorflow.python.util.serialization import get_json_type
# Model._updated_config() may contain non-serializable objects
unsafe_config = self._updated_config()
config = json.loads(json.dumps(unsafe_config, default=get_json_type))
except Exception:
# we failed to convert the network to json, for some reason (most likely internal keras error)
config = {}