Fix Keras h5 model storage

This commit is contained in:
allegroai 2021-03-18 09:49:53 +02:00
parent 20f10f8fbb
commit 56825f9e7a
2 changed files with 47 additions and 41 deletions

View File

@ -292,8 +292,16 @@ class WeightsFileHandler(object):
return filepath return filepath
@staticmethod @staticmethod
def create_output_model(model, saved_path, framework, task, singlefile=False, model_name=None): def create_output_model(
# type: (Optional[Any], Optional[str], Optional[str], Optional[Task], bool, Optional[str]) -> str model, # type: Optional[Any]
saved_path, # type: Optional[str]
framework, # type: Optional[str]
task, # type: Optional[Task]
singlefile=False, # type: bool
model_name=None, # type: Optional[str]
config_obj=None # type: Optional[Union[str, dict]]
):
# type: (...) -> str
if task is None: if task is None:
return saved_path return saved_path
@ -384,7 +392,8 @@ class WeightsFileHandler(object):
trains_out_model = OutputModel( trains_out_model = OutputModel(
task=task, task=task,
# config_dict=config, config_dict=config_obj if isinstance(config_obj, dict) else None,
config_text=config_obj if isinstance(config_obj, str) else None,
name=(task.name + ' - ' + model_name) if model_name else None, name=(task.name + ' - ' + model_name) if model_name else None,
label_enumeration=task.get_labels_enumeration(), label_enumeration=task.get_labels_enumeration(),
framework=framework, framework=framework,

View File

@ -1440,9 +1440,10 @@ class PatchKerasModelIO(object):
keras_saving if PatchKerasModelIO.__patched_tensorflow[2] != keras_saving else None, keras_saving if PatchKerasModelIO.__patched_tensorflow[2] != keras_saving else None,
Functional if PatchKerasModelIO.__patched_tensorflow[3] != Functional else None, Functional if PatchKerasModelIO.__patched_tensorflow[3] != Functional else None,
None, None,
None,
] ]
else: else:
PatchKerasModelIO.__patched_keras = [Network, Sequential, keras_saving, Functional, None] PatchKerasModelIO.__patched_keras = [Network, Sequential, keras_saving, Functional, None, None]
PatchKerasModelIO._patch_io_calls(*PatchKerasModelIO.__patched_keras) PatchKerasModelIO._patch_io_calls(*PatchKerasModelIO.__patched_keras)
if 'tensorflow' in sys.modules and not PatchKerasModelIO.__patched_tensorflow: if 'tensorflow' in sys.modules and not PatchKerasModelIO.__patched_tensorflow:
@ -1478,6 +1479,13 @@ class PatchKerasModelIO(object):
except ImportError: except ImportError:
keras_saving = None keras_saving = None
try:
# hack: make sure tensorflow.__init__ is called
import tensorflow # noqa
from tensorflow.python.keras.saving import hdf5_format as keras_hdf5 # noqa
except ImportError:
keras_hdf5 = None
if PatchKerasModelIO.__patched_keras: if PatchKerasModelIO.__patched_keras:
PatchKerasModelIO.__patched_tensorflow = [ PatchKerasModelIO.__patched_tensorflow = [
Network if PatchKerasModelIO.__patched_keras[0] != Network else None, Network if PatchKerasModelIO.__patched_keras[0] != Network else None,
@ -1485,14 +1493,15 @@ class PatchKerasModelIO(object):
keras_saving if PatchKerasModelIO.__patched_keras[2] != keras_saving else None, keras_saving if PatchKerasModelIO.__patched_keras[2] != keras_saving else None,
Functional if PatchKerasModelIO.__patched_keras[3] != Functional else None, Functional if PatchKerasModelIO.__patched_keras[3] != Functional else None,
keras_saving_legacy if PatchKerasModelIO.__patched_keras[4] != keras_saving_legacy else None, keras_saving_legacy if PatchKerasModelIO.__patched_keras[4] != keras_saving_legacy else None,
keras_hdf5 if PatchKerasModelIO.__patched_keras[5] != keras_hdf5 else None,
] ]
else: else:
PatchKerasModelIO.__patched_tensorflow = [ PatchKerasModelIO.__patched_tensorflow = [
Network, Sequential, keras_saving, Functional, keras_saving_legacy] Network, Sequential, keras_saving, Functional, keras_saving_legacy, keras_hdf5]
PatchKerasModelIO._patch_io_calls(*PatchKerasModelIO.__patched_tensorflow) PatchKerasModelIO._patch_io_calls(*PatchKerasModelIO.__patched_tensorflow)
@staticmethod @staticmethod
def _patch_io_calls(Network, Sequential, keras_saving, Functional, keras_saving_legacy=None): def _patch_io_calls(Network, Sequential, keras_saving, Functional, keras_saving_legacy=None, keras_hdf5=None):
try: try:
if Sequential is not None: if Sequential is not None:
Sequential._updated_config = _patched_call(Sequential._updated_config, Sequential._updated_config = _patched_call(Sequential._updated_config,
@ -1538,6 +1547,14 @@ class PatchKerasModelIO(object):
keras_saving_legacy.load_model = _patched_call( keras_saving_legacy.load_model = _patched_call(
keras_saving_legacy.load_model, PatchKerasModelIO._load_model) keras_saving_legacy.load_model, PatchKerasModelIO._load_model)
if keras_hdf5 is not None:
keras_hdf5.save_weights_to_hdf5_group = _patched_call(
keras_hdf5.save_weights_to_hdf5_group, PatchKerasModelIO._save_weights)
keras_hdf5.load_weights_from_hdf5_group = _patched_call(
keras_hdf5.load_weights_from_hdf5_group, PatchKerasModelIO._load_weights)
keras_hdf5.load_weights_from_hdf5_group_by_name = _patched_call(
keras_hdf5.load_weights_from_hdf5_group_by_name, PatchKerasModelIO._load_weights)
except Exception as ex: except Exception as ex:
LoggerRoot.get_base_logger(TensorflowBinding).warning(str(ex)) LoggerRoot.get_base_logger(TensorflowBinding).warning(str(ex))
@ -1549,6 +1566,8 @@ class PatchKerasModelIO(object):
return config return config
try: try:
# there is no actual file, so we create the OutputModel without one
# check if object already has InputModel # check if object already has InputModel
if not hasattr(self, 'trains_out_model'): if not hasattr(self, 'trains_out_model'):
self.trains_out_model = [] self.trains_out_model = []
@ -1628,7 +1647,11 @@ class PatchKerasModelIO(object):
return original_fn(self, *args, **kwargs) return original_fn(self, *args, **kwargs)
# get filepath # get filepath
filepath = kwargs['filepath'] if 'filepath' in kwargs else args[0] if self and getattr(self, 'filename', None):
filepath = getattr(self, 'filename', None)
else:
filepath = kwargs['filepath'] if 'filepath' in kwargs else args[0]
# Hack: disabled # Hack: disabled
if False and running_remotely(): if False and running_remotely():
# register/load model weights # register/load model weights
@ -1672,7 +1695,10 @@ class PatchKerasModelIO(object):
try: try:
# get filepath # get filepath
filepath = kwargs['filepath'] if 'filepath' in kwargs else args[0] if self and getattr(self, 'filename', None):
filepath = getattr(self, 'filename', None)
else:
filepath = kwargs['filepath'] if 'filepath' in kwargs else args[0]
# this will already generate an output model # this will already generate an output model
# noinspection PyBroadException # noinspection PyBroadException
@ -1682,40 +1708,11 @@ class PatchKerasModelIO(object):
# we failed to convert the network to json, for some reason (most likely internal keras error) # we failed to convert the network to json, for some reason (most likely internal keras error)
config = {} config = {}
# check if object already has InputModel if filepath:
if not hasattr(self, 'trains_out_model'): WeightsFileHandler.create_output_model(
self.trains_out_model = [] self, filepath, Framework.keras, PatchKerasModelIO.__main_task,
config_obj=config or None, singlefile=True)
# check if object already has InputModel, and whether it has the same filename
# (notice we use Path on the url for conformity)
matched = None
if self.trains_out_model:
# find the right model
# noinspection PyProtectedMember
matched = [m for m in self.trains_out_model if m._get_last_uploaded_filename() == Path(filepath).name]
if matched:
self.trains_out_model.remove(matched[0])
self.trains_out_model.append(matched[0])
self.trains_out_model[-1].config_dict = config
if not matched:
model_name_id = getattr(self, 'name', 'unknown')
# todo: support multiple models for the same task
self.trains_out_model.append(OutputModel(
task=PatchKerasModelIO.__main_task,
config_dict=config,
name=PatchKerasModelIO.__main_task.name + ' ' + model_name_id,
label_enumeration=PatchKerasModelIO.__main_task.get_labels_enumeration(),
framework=Framework.keras,
))
# check if we have output storage
if self.trains_out_model[-1].upload_storage_uri:
self.trains_out_model[-1].update_weights(weights_filename=filepath, auto_delete_file=False)
else:
self.trains_out_model[-1].update_weights(weights_filename=None, register_uri=filepath)
# if anyone asks, we were here
# noinspection PyProtectedMember
self.trains_out_model[-1]._processed = True
except Exception as ex: except Exception as ex:
LoggerRoot.get_base_logger(TensorflowBinding).warning(str(ex)) LoggerRoot.get_base_logger(TensorflowBinding).warning(str(ex))