Fix Keras h5 model storage
This commit is contained in:
parent 20f10f8fbb
commit 56825f9e7a
@@ -292,8 +292,16 @@ class WeightsFileHandler(object):
         return filepath
 
     @staticmethod
-    def create_output_model(model, saved_path, framework, task, singlefile=False, model_name=None):
-        # type: (Optional[Any], Optional[str], Optional[str], Optional[Task], bool, Optional[str]) -> str
+    def create_output_model(
+            model,  # type: Optional[Any]
+            saved_path,  # type: Optional[str]
+            framework,  # type: Optional[str]
+            task,  # type: Optional[Task]
+            singlefile=False,  # type: bool
+            model_name=None,  # type: Optional[str]
+            config_obj=None  # type: Optional[Union[str, dict]]
+    ):
+        # type: (...) -> str
         if task is None:
             return saved_path
 
@@ -384,7 +392,8 @@ class WeightsFileHandler(object):
 
         trains_out_model = OutputModel(
             task=task,
-            # config_dict=config,
+            config_dict=config_obj if isinstance(config_obj, dict) else None,
+            config_text=config_obj if isinstance(config_obj, str) else None,
             name=(task.name + ' - ' + model_name) if model_name else None,
             label_enumeration=task.get_labels_enumeration(),
             framework=framework,
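
Note on the two hunks above: the new config_obj argument accepts either a dict or a string, and create_output_model() routes it to OutputModel(config_dict=...) or OutputModel(config_text=...) respectively. A minimal standalone sketch of that dispatch (the split_config_obj helper below is illustrative, not part of the patch):

from typing import Any, Dict, Optional, Union


def split_config_obj(config_obj):
    # type: (Optional[Union[str, dict]]) -> Dict[str, Any]
    # Illustrative helper: mirrors how the patched create_output_model() splits
    # config_obj between OutputModel's config_dict / config_text keyword arguments.
    return dict(
        config_dict=config_obj if isinstance(config_obj, dict) else None,
        config_text=config_obj if isinstance(config_obj, str) else None,
    )


print(split_config_obj({'input_layers': ['dense_input']}))    # config_dict is set
print(split_config_obj('{"input_layers": ["dense_input"]}'))  # config_text is set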
@@ -1440,9 +1440,10 @@ class PatchKerasModelIO(object):
                     keras_saving if PatchKerasModelIO.__patched_tensorflow[2] != keras_saving else None,
                     Functional if PatchKerasModelIO.__patched_tensorflow[3] != Functional else None,
                     None,
+                    None,
                 ]
             else:
-                PatchKerasModelIO.__patched_keras = [Network, Sequential, keras_saving, Functional, None]
+                PatchKerasModelIO.__patched_keras = [Network, Sequential, keras_saving, Functional, None, None]
             PatchKerasModelIO._patch_io_calls(*PatchKerasModelIO.__patched_keras)
 
         if 'tensorflow' in sys.modules and not PatchKerasModelIO.__patched_tensorflow:
@@ -1478,6 +1479,13 @@ class PatchKerasModelIO(object):
             except ImportError:
                 keras_saving = None
 
+            try:
+                # hack: make sure tensorflow.__init__ is called
+                import tensorflow  # noqa
+                from tensorflow.python.keras.saving import hdf5_format as keras_hdf5  # noqa
+            except ImportError:
+                keras_hdf5 = None
+
             if PatchKerasModelIO.__patched_keras:
                 PatchKerasModelIO.__patched_tensorflow = [
                     Network if PatchKerasModelIO.__patched_keras[0] != Network else None,
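
The added import block above assumes a TensorFlow build that still exposes tensorflow.python.keras.saving.hdf5_format; importing tensorflow first forces package initialization so the nested module resolves. A quick standalone probe for the three functions this commit hooks (illustrative, and it degrades gracefully if the module is absent):

try:
    import tensorflow  # noqa: F401  (force tensorflow.__init__ so submodules resolve)
    from tensorflow.python.keras.saving import hdf5_format as keras_hdf5
except ImportError:
    keras_hdf5 = None

if keras_hdf5 is not None:
    # confirm the hook targets exist in this TensorFlow build
    for name in ('save_weights_to_hdf5_group',
                 'load_weights_from_hdf5_group',
                 'load_weights_from_hdf5_group_by_name'):
        print(name, hasattr(keras_hdf5, name))
else:
    print('tensorflow.python.keras.saving.hdf5_format not available')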
@@ -1485,14 +1493,15 @@ class PatchKerasModelIO(object):
                     keras_saving if PatchKerasModelIO.__patched_keras[2] != keras_saving else None,
                     Functional if PatchKerasModelIO.__patched_keras[3] != Functional else None,
                     keras_saving_legacy if PatchKerasModelIO.__patched_keras[4] != keras_saving_legacy else None,
+                    keras_hdf5 if PatchKerasModelIO.__patched_keras[5] != keras_hdf5 else None,
                 ]
             else:
                 PatchKerasModelIO.__patched_tensorflow = [
-                    Network, Sequential, keras_saving, Functional, keras_saving_legacy]
+                    Network, Sequential, keras_saving, Functional, keras_saving_legacy, keras_hdf5]
             PatchKerasModelIO._patch_io_calls(*PatchKerasModelIO.__patched_tensorflow)
 
     @staticmethod
-    def _patch_io_calls(Network, Sequential, keras_saving, Functional, keras_saving_legacy=None):
+    def _patch_io_calls(Network, Sequential, keras_saving, Functional, keras_saving_legacy=None, keras_hdf5=None):
         try:
             if Sequential is not None:
                 Sequential._updated_config = _patched_call(Sequential._updated_config,
@@ -1538,6 +1547,14 @@ class PatchKerasModelIO(object):
                 keras_saving_legacy.load_model = _patched_call(
                     keras_saving_legacy.load_model, PatchKerasModelIO._load_model)
 
+            if keras_hdf5 is not None:
+                keras_hdf5.save_weights_to_hdf5_group = _patched_call(
+                    keras_hdf5.save_weights_to_hdf5_group, PatchKerasModelIO._save_weights)
+                keras_hdf5.load_weights_from_hdf5_group = _patched_call(
+                    keras_hdf5.load_weights_from_hdf5_group, PatchKerasModelIO._load_weights)
+                keras_hdf5.load_weights_from_hdf5_group_by_name = _patched_call(
+                    keras_hdf5.load_weights_from_hdf5_group_by_name, PatchKerasModelIO._load_weights)
+
         except Exception as ex:
             LoggerRoot.get_base_logger(TensorflowBinding).warning(str(ex))
 
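
For context, _patched_call(original, hook) is the module's monkey-patching helper: it returns a wrapper that forwards each call to the hook with the original function as the first argument, which is why the hooks above have the shape _save_weights(original_fn, self, *args, **kwargs). A rough standalone sketch of the idea, not the project's actual implementation:

import functools


def patched_call_sketch(original_fn, hook_fn):
    # Illustrative only: the wrapper hands original_fn to the hook so the hook
    # can run its bookkeeping and then delegate to the real function.
    @functools.wraps(original_fn)
    def _inner(*args, **kwargs):
        return hook_fn(original_fn, *args, **kwargs)
    return _inner


def save_hook(original_fn, f, *args, **kwargs):
    # hypothetical hook: log the target, then delegate
    print('saving weights into', getattr(f, 'filename', f))
    return original_fn(f, *args, **kwargs)

# keras_hdf5.save_weights_to_hdf5_group = patched_call_sketch(
#     keras_hdf5.save_weights_to_hdf5_group, save_hook)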
@@ -1549,6 +1566,8 @@ class PatchKerasModelIO(object):
             return config
 
         try:
+            # there is no actual file, so we create the OutputModel without one
+
             # check if object already has InputModel
             if not hasattr(self, 'trains_out_model'):
                 self.trains_out_model = []
@@ -1628,7 +1647,11 @@ class PatchKerasModelIO(object):
             return original_fn(self, *args, **kwargs)
 
         # get filepath
-        filepath = kwargs['filepath'] if 'filepath' in kwargs else args[0]
+        if self and getattr(self, 'filename', None):
+            filepath = getattr(self, 'filename', None)
+        else:
+            filepath = kwargs['filepath'] if 'filepath' in kwargs else args[0]
+
         # Hack: disabled
         if False and running_remotely():
             # register/load model weights
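
The new getattr(self, 'filename', None) branch exists because the hdf5_format functions receive an open h5py File (or Group) rather than a file path in args/kwargs; an h5py File carries its on-disk path in .filename. A small check, assuming h5py and numpy are installed:

import h5py
import numpy as np

with h5py.File('demo_weights.h5', 'w') as f:
    f.create_dataset('kernel', data=np.zeros((3, 3)))
    # the File object exposes the backing path, which is what the patch picks up
    print(f.filename)                                # -> demo_weights.h5
    print(getattr(f, 'filename', None) is not None)  # same probe as the patch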
@@ -1672,7 +1695,10 @@ class PatchKerasModelIO(object):
 
         try:
             # get filepath
-            filepath = kwargs['filepath'] if 'filepath' in kwargs else args[0]
+            if self and getattr(self, 'filename', None):
+                filepath = getattr(self, 'filename', None)
+            else:
+                filepath = kwargs['filepath'] if 'filepath' in kwargs else args[0]
 
             # this will already generate an output model
             # noinspection PyBroadException
@@ -1682,40 +1708,11 @@ class PatchKerasModelIO(object):
                 # we failed to convert the network to json, for some reason (most likely internal keras error)
                 config = {}
 
-            # check if object already has InputModel
-            if not hasattr(self, 'trains_out_model'):
-                self.trains_out_model = []
-
-            # check if object already has InputModel, and we this has the same filename
-            # (notice we Use Ptah on url for conforming)
-            matched = None
-            if self.trains_out_model:
-                # find the right model
-                # noinspection PyProtectedMember
-                matched = [m for m in self.trains_out_model if m._get_last_uploaded_filename() == Path(filepath).name]
-                if matched:
-                    self.trains_out_model.remove(matched[0])
-                    self.trains_out_model.append(matched[0])
-                    self.trains_out_model[-1].config_dict = config
-
-            if not matched:
-                model_name_id = getattr(self, 'name', 'unknown')
-                # todo: support multiple models for the same task
-                self.trains_out_model.append(OutputModel(
-                    task=PatchKerasModelIO.__main_task,
-                    config_dict=config,
-                    name=PatchKerasModelIO.__main_task.name + ' ' + model_name_id,
-                    label_enumeration=PatchKerasModelIO.__main_task.get_labels_enumeration(),
-                    framework=Framework.keras,
-                ))
-            # check if we have output storage
-            if self.trains_out_model[-1].upload_storage_uri:
-                self.trains_out_model[-1].update_weights(weights_filename=filepath, auto_delete_file=False)
-            else:
-                self.trains_out_model[-1].update_weights(weights_filename=None, register_uri=filepath)
-            # if anyone asks, we were here
-            # noinspection PyProtectedMember
-            self.trains_out_model[-1]._processed = True
+            if filepath:
+                WeightsFileHandler.create_output_model(
+                    self, filepath, Framework.keras, PatchKerasModelIO.__main_task,
+                    config_obj=config or None, singlefile=True)
+
         except Exception as ex:
             LoggerRoot.get_base_logger(TensorflowBinding).warning(str(ex))
 
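
Taken together, saving Keras weights to an .h5 file now flows through WeightsFileHandler.create_output_model() instead of the removed inline OutputModel bookkeeping. A minimal end-to-end sketch of the path this commit fixes (assumes clearml and tensorflow are installed; project and task names are placeholders):

from clearml import Task
from tensorflow import keras

task = Task.init(project_name='examples', task_name='keras h5 storage check')

model = keras.Sequential([keras.layers.Dense(4, input_shape=(8,))])
model.compile(optimizer='adam', loss='mse')

# save_weights() with an .h5 target goes through save_weights_to_hdf5_group,
# which is hooked by this commit, so the file should be registered as an
# output model on the task
model.save_weights('dense_weights.h5')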