Fix Keras h5 model storage

This commit is contained in:
allegroai 2021-03-18 09:49:53 +02:00
parent 20f10f8fbb
commit 56825f9e7a
2 changed files with 47 additions and 41 deletions

View File

@ -292,8 +292,16 @@ class WeightsFileHandler(object):
return filepath
@staticmethod
def create_output_model(model, saved_path, framework, task, singlefile=False, model_name=None):
# type: (Optional[Any], Optional[str], Optional[str], Optional[Task], bool, Optional[str]) -> str
def create_output_model(
model, # type: Optional[Any]
saved_path, # type: Optional[str]
framework, # type: Optional[str]
task, # type: Optional[Task]
singlefile=False, # type: bool
model_name=None, # type: Optional[str]
config_obj=None # type: Optional[Union[str, dict]]
):
# type: (...) -> str
if task is None:
return saved_path
@ -384,7 +392,8 @@ class WeightsFileHandler(object):
trains_out_model = OutputModel(
task=task,
# config_dict=config,
config_dict=config_obj if isinstance(config_obj, dict) else None,
config_text=config_obj if isinstance(config_obj, str) else None,
name=(task.name + ' - ' + model_name) if model_name else None,
label_enumeration=task.get_labels_enumeration(),
framework=framework,

View File

@ -1440,9 +1440,10 @@ class PatchKerasModelIO(object):
keras_saving if PatchKerasModelIO.__patched_tensorflow[2] != keras_saving else None,
Functional if PatchKerasModelIO.__patched_tensorflow[3] != Functional else None,
None,
None,
]
else:
PatchKerasModelIO.__patched_keras = [Network, Sequential, keras_saving, Functional, None]
PatchKerasModelIO.__patched_keras = [Network, Sequential, keras_saving, Functional, None, None]
PatchKerasModelIO._patch_io_calls(*PatchKerasModelIO.__patched_keras)
if 'tensorflow' in sys.modules and not PatchKerasModelIO.__patched_tensorflow:
@ -1478,6 +1479,13 @@ class PatchKerasModelIO(object):
except ImportError:
keras_saving = None
try:
# hack: make sure tensorflow.__init__ is called
import tensorflow # noqa
from tensorflow.python.keras.saving import hdf5_format as keras_hdf5 # noqa
except ImportError:
keras_hdf5 = None
if PatchKerasModelIO.__patched_keras:
PatchKerasModelIO.__patched_tensorflow = [
Network if PatchKerasModelIO.__patched_keras[0] != Network else None,
@ -1485,14 +1493,15 @@ class PatchKerasModelIO(object):
keras_saving if PatchKerasModelIO.__patched_keras[2] != keras_saving else None,
Functional if PatchKerasModelIO.__patched_keras[3] != Functional else None,
keras_saving_legacy if PatchKerasModelIO.__patched_keras[4] != keras_saving_legacy else None,
keras_hdf5 if PatchKerasModelIO.__patched_keras[5] != keras_hdf5 else None,
]
else:
PatchKerasModelIO.__patched_tensorflow = [
Network, Sequential, keras_saving, Functional, keras_saving_legacy]
Network, Sequential, keras_saving, Functional, keras_saving_legacy, keras_hdf5]
PatchKerasModelIO._patch_io_calls(*PatchKerasModelIO.__patched_tensorflow)
@staticmethod
def _patch_io_calls(Network, Sequential, keras_saving, Functional, keras_saving_legacy=None):
def _patch_io_calls(Network, Sequential, keras_saving, Functional, keras_saving_legacy=None, keras_hdf5=None):
try:
if Sequential is not None:
Sequential._updated_config = _patched_call(Sequential._updated_config,
@ -1538,6 +1547,14 @@ class PatchKerasModelIO(object):
keras_saving_legacy.load_model = _patched_call(
keras_saving_legacy.load_model, PatchKerasModelIO._load_model)
if keras_hdf5 is not None:
keras_hdf5.save_weights_to_hdf5_group = _patched_call(
keras_hdf5.save_weights_to_hdf5_group, PatchKerasModelIO._save_weights)
keras_hdf5.load_weights_from_hdf5_group = _patched_call(
keras_hdf5.load_weights_from_hdf5_group, PatchKerasModelIO._load_weights)
keras_hdf5.load_weights_from_hdf5_group_by_name = _patched_call(
keras_hdf5.load_weights_from_hdf5_group_by_name, PatchKerasModelIO._load_weights)
except Exception as ex:
LoggerRoot.get_base_logger(TensorflowBinding).warning(str(ex))
@ -1549,6 +1566,8 @@ class PatchKerasModelIO(object):
return config
try:
# there is no actual file, so we create the OutputModel without one
# check if object already has InputModel
if not hasattr(self, 'trains_out_model'):
self.trains_out_model = []
@ -1628,7 +1647,11 @@ class PatchKerasModelIO(object):
return original_fn(self, *args, **kwargs)
# get filepath
filepath = kwargs['filepath'] if 'filepath' in kwargs else args[0]
if self and getattr(self, 'filename', None):
filepath = getattr(self, 'filename', None)
else:
filepath = kwargs['filepath'] if 'filepath' in kwargs else args[0]
# Hack: disabled
if False and running_remotely():
# register/load model weights
@ -1672,7 +1695,10 @@ class PatchKerasModelIO(object):
try:
# get filepath
filepath = kwargs['filepath'] if 'filepath' in kwargs else args[0]
if self and getattr(self, 'filename', None):
filepath = getattr(self, 'filename', None)
else:
filepath = kwargs['filepath'] if 'filepath' in kwargs else args[0]
# this will already generate an output model
# noinspection PyBroadException
@ -1682,40 +1708,11 @@ class PatchKerasModelIO(object):
# we failed to convert the network to json, for some reason (most likely internal keras error)
config = {}
# check if object already has InputModel
if not hasattr(self, 'trains_out_model'):
self.trains_out_model = []
if filepath:
WeightsFileHandler.create_output_model(
self, filepath, Framework.keras, PatchKerasModelIO.__main_task,
config_obj=config or None, singlefile=True)
# check if object already has InputModel, and whether it has the same filename
# (notice we use Path on the url for conformity)
matched = None
if self.trains_out_model:
# find the right model
# noinspection PyProtectedMember
matched = [m for m in self.trains_out_model if m._get_last_uploaded_filename() == Path(filepath).name]
if matched:
self.trains_out_model.remove(matched[0])
self.trains_out_model.append(matched[0])
self.trains_out_model[-1].config_dict = config
if not matched:
model_name_id = getattr(self, 'name', 'unknown')
# todo: support multiple models for the same task
self.trains_out_model.append(OutputModel(
task=PatchKerasModelIO.__main_task,
config_dict=config,
name=PatchKerasModelIO.__main_task.name + ' ' + model_name_id,
label_enumeration=PatchKerasModelIO.__main_task.get_labels_enumeration(),
framework=Framework.keras,
))
# check if we have output storage
if self.trains_out_model[-1].upload_storage_uri:
self.trains_out_model[-1].update_weights(weights_filename=filepath, auto_delete_file=False)
else:
self.trains_out_model[-1].update_weights(weights_filename=None, register_uri=filepath)
# if anyone asks, we were here
# noinspection PyProtectedMember
self.trains_out_model[-1]._processed = True
except Exception as ex:
LoggerRoot.get_base_logger(TensorflowBinding).warning(str(ex))