mirror of
https://github.com/clearml/clearml
synced 2025-04-26 17:30:20 +00:00
Fix Keras h5 model storage
This commit is contained in:
parent
20f10f8fbb
commit
56825f9e7a
@ -292,8 +292,16 @@ class WeightsFileHandler(object):
|
||||
return filepath
|
||||
|
||||
@staticmethod
|
||||
def create_output_model(model, saved_path, framework, task, singlefile=False, model_name=None):
|
||||
# type: (Optional[Any], Optional[str], Optional[str], Optional[Task], bool, Optional[str]) -> str
|
||||
def create_output_model(
|
||||
model, # type: Optional[Any]
|
||||
saved_path, # type: Optional[str]
|
||||
framework, # type: Optional[str]
|
||||
task, # type: Optional[Task]
|
||||
singlefile=False, # type: bool
|
||||
model_name=None, # type: Optional[str]
|
||||
config_obj=None # type: Optional[Union[str, dict]]
|
||||
):
|
||||
# type: (...) -> str
|
||||
if task is None:
|
||||
return saved_path
|
||||
|
||||
@ -384,7 +392,8 @@ class WeightsFileHandler(object):
|
||||
|
||||
trains_out_model = OutputModel(
|
||||
task=task,
|
||||
# config_dict=config,
|
||||
config_dict=config_obj if isinstance(config_obj, dict) else None,
|
||||
config_text=config_obj if isinstance(config_obj, str) else None,
|
||||
name=(task.name + ' - ' + model_name) if model_name else None,
|
||||
label_enumeration=task.get_labels_enumeration(),
|
||||
framework=framework,
|
||||
|
@ -1440,9 +1440,10 @@ class PatchKerasModelIO(object):
|
||||
keras_saving if PatchKerasModelIO.__patched_tensorflow[2] != keras_saving else None,
|
||||
Functional if PatchKerasModelIO.__patched_tensorflow[3] != Functional else None,
|
||||
None,
|
||||
None,
|
||||
]
|
||||
else:
|
||||
PatchKerasModelIO.__patched_keras = [Network, Sequential, keras_saving, Functional, None]
|
||||
PatchKerasModelIO.__patched_keras = [Network, Sequential, keras_saving, Functional, None, None]
|
||||
PatchKerasModelIO._patch_io_calls(*PatchKerasModelIO.__patched_keras)
|
||||
|
||||
if 'tensorflow' in sys.modules and not PatchKerasModelIO.__patched_tensorflow:
|
||||
@ -1478,6 +1479,13 @@ class PatchKerasModelIO(object):
|
||||
except ImportError:
|
||||
keras_saving = None
|
||||
|
||||
try:
|
||||
# hack: make sure tensorflow.__init__ is called
|
||||
import tensorflow # noqa
|
||||
from tensorflow.python.keras.saving import hdf5_format as keras_hdf5 # noqa
|
||||
except ImportError:
|
||||
keras_hdf5 = None
|
||||
|
||||
if PatchKerasModelIO.__patched_keras:
|
||||
PatchKerasModelIO.__patched_tensorflow = [
|
||||
Network if PatchKerasModelIO.__patched_keras[0] != Network else None,
|
||||
@ -1485,14 +1493,15 @@ class PatchKerasModelIO(object):
|
||||
keras_saving if PatchKerasModelIO.__patched_keras[2] != keras_saving else None,
|
||||
Functional if PatchKerasModelIO.__patched_keras[3] != Functional else None,
|
||||
keras_saving_legacy if PatchKerasModelIO.__patched_keras[4] != keras_saving_legacy else None,
|
||||
keras_hdf5 if PatchKerasModelIO.__patched_keras[5] != keras_hdf5 else None,
|
||||
]
|
||||
else:
|
||||
PatchKerasModelIO.__patched_tensorflow = [
|
||||
Network, Sequential, keras_saving, Functional, keras_saving_legacy]
|
||||
Network, Sequential, keras_saving, Functional, keras_saving_legacy, keras_hdf5]
|
||||
PatchKerasModelIO._patch_io_calls(*PatchKerasModelIO.__patched_tensorflow)
|
||||
|
||||
@staticmethod
|
||||
def _patch_io_calls(Network, Sequential, keras_saving, Functional, keras_saving_legacy=None):
|
||||
def _patch_io_calls(Network, Sequential, keras_saving, Functional, keras_saving_legacy=None, keras_hdf5=None):
|
||||
try:
|
||||
if Sequential is not None:
|
||||
Sequential._updated_config = _patched_call(Sequential._updated_config,
|
||||
@ -1538,6 +1547,14 @@ class PatchKerasModelIO(object):
|
||||
keras_saving_legacy.load_model = _patched_call(
|
||||
keras_saving_legacy.load_model, PatchKerasModelIO._load_model)
|
||||
|
||||
if keras_hdf5 is not None:
|
||||
keras_hdf5.save_weights_to_hdf5_group = _patched_call(
|
||||
keras_hdf5.save_weights_to_hdf5_group, PatchKerasModelIO._save_weights)
|
||||
keras_hdf5.load_weights_from_hdf5_group = _patched_call(
|
||||
keras_hdf5.load_weights_from_hdf5_group, PatchKerasModelIO._load_weights)
|
||||
keras_hdf5.load_weights_from_hdf5_group_by_name = _patched_call(
|
||||
keras_hdf5.load_weights_from_hdf5_group_by_name, PatchKerasModelIO._load_weights)
|
||||
|
||||
except Exception as ex:
|
||||
LoggerRoot.get_base_logger(TensorflowBinding).warning(str(ex))
|
||||
|
||||
@ -1549,6 +1566,8 @@ class PatchKerasModelIO(object):
|
||||
return config
|
||||
|
||||
try:
|
||||
# there is no actual file, so we create the OutputModel without one
|
||||
|
||||
# check if object already has InputModel
|
||||
if not hasattr(self, 'trains_out_model'):
|
||||
self.trains_out_model = []
|
||||
@ -1628,7 +1647,11 @@ class PatchKerasModelIO(object):
|
||||
return original_fn(self, *args, **kwargs)
|
||||
|
||||
# get filepath
|
||||
filepath = kwargs['filepath'] if 'filepath' in kwargs else args[0]
|
||||
if self and getattr(self, 'filename', None):
|
||||
filepath = getattr(self, 'filename', None)
|
||||
else:
|
||||
filepath = kwargs['filepath'] if 'filepath' in kwargs else args[0]
|
||||
|
||||
# Hack: disabled
|
||||
if False and running_remotely():
|
||||
# register/load model weights
|
||||
@ -1672,7 +1695,10 @@ class PatchKerasModelIO(object):
|
||||
|
||||
try:
|
||||
# get filepath
|
||||
filepath = kwargs['filepath'] if 'filepath' in kwargs else args[0]
|
||||
if self and getattr(self, 'filename', None):
|
||||
filepath = getattr(self, 'filename', None)
|
||||
else:
|
||||
filepath = kwargs['filepath'] if 'filepath' in kwargs else args[0]
|
||||
|
||||
# this will already generate an output model
|
||||
# noinspection PyBroadException
|
||||
@ -1682,40 +1708,11 @@ class PatchKerasModelIO(object):
|
||||
# we failed to convert the network to json, for some reason (most likely internal keras error)
|
||||
config = {}
|
||||
|
||||
# check if object already has InputModel
|
||||
if not hasattr(self, 'trains_out_model'):
|
||||
self.trains_out_model = []
|
||||
if filepath:
|
||||
WeightsFileHandler.create_output_model(
|
||||
self, filepath, Framework.keras, PatchKerasModelIO.__main_task,
|
||||
config_obj=config or None, singlefile=True)
|
||||
|
||||
# check if object already has InputModel, and we this has the same filename
|
||||
# (notice we Use Ptah on url for conforming)
|
||||
matched = None
|
||||
if self.trains_out_model:
|
||||
# find the right model
|
||||
# noinspection PyProtectedMember
|
||||
matched = [m for m in self.trains_out_model if m._get_last_uploaded_filename() == Path(filepath).name]
|
||||
if matched:
|
||||
self.trains_out_model.remove(matched[0])
|
||||
self.trains_out_model.append(matched[0])
|
||||
self.trains_out_model[-1].config_dict = config
|
||||
|
||||
if not matched:
|
||||
model_name_id = getattr(self, 'name', 'unknown')
|
||||
# todo: support multiple models for the same task
|
||||
self.trains_out_model.append(OutputModel(
|
||||
task=PatchKerasModelIO.__main_task,
|
||||
config_dict=config,
|
||||
name=PatchKerasModelIO.__main_task.name + ' ' + model_name_id,
|
||||
label_enumeration=PatchKerasModelIO.__main_task.get_labels_enumeration(),
|
||||
framework=Framework.keras,
|
||||
))
|
||||
# check if we have output storage
|
||||
if self.trains_out_model[-1].upload_storage_uri:
|
||||
self.trains_out_model[-1].update_weights(weights_filename=filepath, auto_delete_file=False)
|
||||
else:
|
||||
self.trains_out_model[-1].update_weights(weights_filename=None, register_uri=filepath)
|
||||
# if anyone asks, we were here
|
||||
# noinspection PyProtectedMember
|
||||
self.trains_out_model[-1]._processed = True
|
||||
except Exception as ex:
|
||||
LoggerRoot.get_base_logger(TensorflowBinding).warning(str(ex))
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user