From 56825f9e7af26a43da0f2f12454e23a43e995a25 Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Thu, 18 Mar 2021 09:49:53 +0200 Subject: [PATCH] Fix Keras h5 model storage --- clearml/binding/frameworks/__init__.py | 15 +++- clearml/binding/frameworks/tensorflow_bind.py | 73 +++++++++---------- 2 files changed, 47 insertions(+), 41 deletions(-) diff --git a/clearml/binding/frameworks/__init__.py b/clearml/binding/frameworks/__init__.py index 2a24ab83..81fb1f36 100644 --- a/clearml/binding/frameworks/__init__.py +++ b/clearml/binding/frameworks/__init__.py @@ -292,8 +292,16 @@ class WeightsFileHandler(object): return filepath @staticmethod - def create_output_model(model, saved_path, framework, task, singlefile=False, model_name=None): - # type: (Optional[Any], Optional[str], Optional[str], Optional[Task], bool, Optional[str]) -> str + def create_output_model( + model, # type: Optional[Any] + saved_path, # type: Optional[str] + framework, # type: Optional[str] + task, # type: Optional[Task] + singlefile=False, # type: bool + model_name=None, # type: Optional[str] + config_obj=None # type: Optional[Union[str, dict]] + ): + # type: (...) -> str if task is None: return saved_path @@ -384,7 +392,8 @@ class WeightsFileHandler(object): trains_out_model = OutputModel( task=task, - # config_dict=config, + config_dict=config_obj if isinstance(config_obj, dict) else None, + config_text=config_obj if isinstance(config_obj, str) else None, name=(task.name + ' - ' + model_name) if model_name else None, label_enumeration=task.get_labels_enumeration(), framework=framework, diff --git a/clearml/binding/frameworks/tensorflow_bind.py b/clearml/binding/frameworks/tensorflow_bind.py index f7ae2822..c854b159 100644 --- a/clearml/binding/frameworks/tensorflow_bind.py +++ b/clearml/binding/frameworks/tensorflow_bind.py @@ -1440,9 +1440,10 @@ class PatchKerasModelIO(object): keras_saving if PatchKerasModelIO.__patched_tensorflow[2] != keras_saving else None, Functional if PatchKerasModelIO.__patched_tensorflow[3] != Functional else None, None, + None, ] else: - PatchKerasModelIO.__patched_keras = [Network, Sequential, keras_saving, Functional, None] + PatchKerasModelIO.__patched_keras = [Network, Sequential, keras_saving, Functional, None, None] PatchKerasModelIO._patch_io_calls(*PatchKerasModelIO.__patched_keras) if 'tensorflow' in sys.modules and not PatchKerasModelIO.__patched_tensorflow: @@ -1478,6 +1479,13 @@ class PatchKerasModelIO(object): except ImportError: keras_saving = None + try: + # hack: make sure tensorflow.__init__ is called + import tensorflow # noqa + from tensorflow.python.keras.saving import hdf5_format as keras_hdf5 # noqa + except ImportError: + keras_hdf5 = None + if PatchKerasModelIO.__patched_keras: PatchKerasModelIO.__patched_tensorflow = [ Network if PatchKerasModelIO.__patched_keras[0] != Network else None, @@ -1485,14 +1493,15 @@ class PatchKerasModelIO(object): keras_saving if PatchKerasModelIO.__patched_keras[2] != keras_saving else None, Functional if PatchKerasModelIO.__patched_keras[3] != Functional else None, keras_saving_legacy if PatchKerasModelIO.__patched_keras[4] != keras_saving_legacy else None, + keras_hdf5 if PatchKerasModelIO.__patched_keras[5] != keras_hdf5 else None, ] else: PatchKerasModelIO.__patched_tensorflow = [ - Network, Sequential, keras_saving, Functional, keras_saving_legacy] + Network, Sequential, keras_saving, Functional, keras_saving_legacy, keras_hdf5] PatchKerasModelIO._patch_io_calls(*PatchKerasModelIO.__patched_tensorflow) @staticmethod - def _patch_io_calls(Network, Sequential, keras_saving, Functional, keras_saving_legacy=None): + def _patch_io_calls(Network, Sequential, keras_saving, Functional, keras_saving_legacy=None, keras_hdf5=None): try: if Sequential is not None: Sequential._updated_config = _patched_call(Sequential._updated_config, @@ -1538,6 +1547,14 @@ class PatchKerasModelIO(object): keras_saving_legacy.load_model = _patched_call( keras_saving_legacy.load_model, PatchKerasModelIO._load_model) + if keras_hdf5 is not None: + keras_hdf5.save_weights_to_hdf5_group = _patched_call( + keras_hdf5.save_weights_to_hdf5_group, PatchKerasModelIO._save_weights) + keras_hdf5.load_weights_from_hdf5_group = _patched_call( + keras_hdf5.load_weights_from_hdf5_group, PatchKerasModelIO._load_weights) + keras_hdf5.load_weights_from_hdf5_group_by_name = _patched_call( + keras_hdf5.load_weights_from_hdf5_group_by_name, PatchKerasModelIO._load_weights) + except Exception as ex: LoggerRoot.get_base_logger(TensorflowBinding).warning(str(ex)) @@ -1549,6 +1566,8 @@ class PatchKerasModelIO(object): return config try: + # there is no actual file, so we create the OutputModel without one + # check if object already has InputModel if not hasattr(self, 'trains_out_model'): self.trains_out_model = [] @@ -1628,7 +1647,11 @@ class PatchKerasModelIO(object): return original_fn(self, *args, **kwargs) # get filepath - filepath = kwargs['filepath'] if 'filepath' in kwargs else args[0] + if self and getattr(self, 'filename', None): + filepath = getattr(self, 'filename', None) + else: + filepath = kwargs['filepath'] if 'filepath' in kwargs else args[0] + # Hack: disabled if False and running_remotely(): # register/load model weights @@ -1672,7 +1695,10 @@ class PatchKerasModelIO(object): try: # get filepath - filepath = kwargs['filepath'] if 'filepath' in kwargs else args[0] + if self and getattr(self, 'filename', None): + filepath = getattr(self, 'filename', None) + else: + filepath = kwargs['filepath'] if 'filepath' in kwargs else args[0] # this will already generate an output model # noinspection PyBroadException @@ -1682,40 +1708,11 @@ class PatchKerasModelIO(object): # we failed to convert the network to json, for some reason (most likely internal keras error) config = {} - # check if object already has InputModel - if not hasattr(self, 'trains_out_model'): - self.trains_out_model = [] + if filepath: + WeightsFileHandler.create_output_model( + self, filepath, Framework.keras, PatchKerasModelIO.__main_task, + config_obj=config or None, singlefile=True) - # check if object already has InputModel, and we this has the same filename - # (notice we Use Ptah on url for conforming) - matched = None - if self.trains_out_model: - # find the right model - # noinspection PyProtectedMember - matched = [m for m in self.trains_out_model if m._get_last_uploaded_filename() == Path(filepath).name] - if matched: - self.trains_out_model.remove(matched[0]) - self.trains_out_model.append(matched[0]) - self.trains_out_model[-1].config_dict = config - - if not matched: - model_name_id = getattr(self, 'name', 'unknown') - # todo: support multiple models for the same task - self.trains_out_model.append(OutputModel( - task=PatchKerasModelIO.__main_task, - config_dict=config, - name=PatchKerasModelIO.__main_task.name + ' ' + model_name_id, - label_enumeration=PatchKerasModelIO.__main_task.get_labels_enumeration(), - framework=Framework.keras, - )) - # check if we have output storage - if self.trains_out_model[-1].upload_storage_uri: - self.trains_out_model[-1].update_weights(weights_filename=filepath, auto_delete_file=False) - else: - self.trains_out_model[-1].update_weights(weights_filename=None, register_uri=filepath) - # if anyone asks, we were here - # noinspection PyProtectedMember - self.trains_out_model[-1]._processed = True except Exception as ex: LoggerRoot.get_base_logger(TensorflowBinding).warning(str(ex))