Mirror of https://github.com/clearml/clearml (synced 2025-05-03 20:41:00 +00:00)
Fix keras reusing model object only if the filename is the same (issue #252)
Commit f3638a7e5d (parent c75de848f5)
@@ -11,6 +11,7 @@ from typing import Any
 import numpy as np
 import six
 from PIL import Image
+from pathlib2 import Path

 from ...debugging.log import LoggerRoot
 from ..frameworks import _patched_call, WeightsFileHandler, _Empty
@@ -19,7 +20,7 @@ from ...config import running_remotely
 from ...model import InputModel, OutputModel, Framework

 try:
-    from google.protobuf.json_format import MessageToDict
+    from google.protobuf.json_format import MessageToDict  # noqa
 except ImportError:
     MessageToDict = None

@@ -840,7 +841,7 @@ class PatchSummaryToEventTransformer(object):
     def _patch_summary_to_event_transformer():
         if 'tensorflow' in sys.modules:
             try:
-                from tensorflow.python.summary.writer.writer import SummaryToEventTransformer
+                from tensorflow.python.summary.writer.writer import SummaryToEventTransformer  # noqa
                 # only patch once
                 if PatchSummaryToEventTransformer.__original_getattribute is None:
                     PatchSummaryToEventTransformer.__original_getattribute = SummaryToEventTransformer.__getattribute__
@@ -855,7 +856,7 @@ class PatchSummaryToEventTransformer(object):
                 # only patch once
                 if PatchSummaryToEventTransformer._original_add_eventT is None:
                     # noinspection PyUnresolvedReferences
-                    from torch.utils.tensorboard.writer import FileWriter as FileWriterT
+                    from torch.utils.tensorboard.writer import FileWriter as FileWriterT  # noqa
                     PatchSummaryToEventTransformer._original_add_eventT = FileWriterT.add_event
                     FileWriterT.add_event = PatchSummaryToEventTransformer._patched_add_eventT
                     setattr(FileWriterT, 'trains', None)
@@ -870,7 +871,7 @@ class PatchSummaryToEventTransformer(object):
                 # only patch once
                 if PatchSummaryToEventTransformer.__original_getattributeX is None:
                     # noinspection PyUnresolvedReferences
-                    from tensorboardX.writer import SummaryToEventTransformer as SummaryToEventTransformerX
+                    from tensorboardX.writer import SummaryToEventTransformer as SummaryToEventTransformerX  # noqa
                     PatchSummaryToEventTransformer.__original_getattributeX = \
                         SummaryToEventTransformerX.__getattribute__
                     SummaryToEventTransformerX.__getattribute__ = PatchSummaryToEventTransformer._patched_getattributeX
@@ -886,7 +887,7 @@ class PatchSummaryToEventTransformer(object):
             try:
                 # only patch once
                 if PatchSummaryToEventTransformer._original_add_eventX is None:
-                    from tensorboardX.writer import FileWriter as FileWriterX
+                    from tensorboardX.writer import FileWriter as FileWriterX  # noqa
                     PatchSummaryToEventTransformer._original_add_eventX = FileWriterX.add_event
                     FileWriterX.add_event = PatchSummaryToEventTransformer._patched_add_eventX
                     setattr(FileWriterX, 'trains', None)
@@ -1041,14 +1042,14 @@ class PatchModelCheckPointCallback(object):
         callbacks = None
         if is_keras:
             try:
-                import keras.callbacks as callbacks  # noqa: F401
+                import keras.callbacks as callbacks  # noqa
             except ImportError:
                 is_keras = False
         if not is_keras and is_tf_keras:
             try:
                 # hack: make sure tensorflow.__init__ is called
-                import tensorflow  # noqa: F401
-                import tensorflow.python.keras.callbacks as callbacks  # noqa: F811
+                import tensorflow  # noqa
+                import tensorflow.python.keras.callbacks as callbacks  # noqa
             except ImportError:
                 is_tf_keras = False
                 callbacks = None
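The callbacks hunk above follows the optional-import pattern used throughout this file: try the standalone keras package first, fall back to the Keras bundled under tensorflow.python, and disable the corresponding patch when neither import succeeds. A minimal, standalone sketch of that detection; the helper name and the returned flavour strings are illustrative, not part of the codebase:

```python
# Illustrative sketch of the keras / tf.keras detection used above.
# Runs even when neither package is installed (it simply returns (None, None)).

def detect_keras_callbacks():
    """Return (callbacks_module, flavour) for whichever Keras flavour is importable."""
    try:
        import keras.callbacks as callbacks  # standalone Keras
        return callbacks, 'keras'
    except ImportError:
        pass
    try:
        # hack from the patch above: make sure tensorflow.__init__ runs first
        import tensorflow  # noqa: F401
        import tensorflow.python.keras.callbacks as callbacks  # Keras bundled with TF
        return callbacks, 'tf.keras'
    except ImportError:
        return None, None


if __name__ == '__main__':
    print('detected:', detect_keras_callbacks()[1])
```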
@@ -1129,8 +1130,8 @@ class PatchTensorFlowEager(object):
         if 'tensorflow' in sys.modules:
             try:
                 # hack: make sure tensorflow.__init__ is called
-                import tensorflow  # noqa: F401
-                from tensorflow.python.ops import gen_summary_ops  # noqa: F401
+                import tensorflow  # noqa
+                from tensorflow.python.ops import gen_summary_ops  # noqa
                 PatchTensorFlowEager.__original_fn_scalar = gen_summary_ops.write_scalar_summary
                 gen_summary_ops.write_scalar_summary = PatchTensorFlowEager._write_scalar_summary
                 PatchTensorFlowEager.__original_fn_image = gen_summary_ops.write_image_summary
@@ -1160,12 +1161,12 @@ class PatchTensorFlowEager(object):
             # check if we are in eager mode, let's get the global context logdir
             # noinspection PyBroadException
             try:
-                from tensorflow.python.eager import context
+                from tensorflow.python.eager import context  # noqa
                 logdir = context.context().summary_writer._init_op_fn.keywords.get('logdir')
             except Exception:
                 # noinspection PyBroadException
                 try:
-                    from tensorflow.python.ops.summary_ops_v2 import _summary_state
+                    from tensorflow.python.ops.summary_ops_v2 import _summary_state  # noqa
                     logdir = _summary_state.writer._init_op_fn.keywords.get('logdir')
                 except Exception:
                     logdir = None
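The two nested try blocks above exist because the location of the active summary writer moved between TensorFlow releases: older eager-mode builds expose it through `context.context().summary_writer`, while newer ones keep it in `summary_ops_v2._summary_state`. A hedged, standalone sketch of that probing order; the function name is invented for the illustration and every failure path simply yields None:

```python
# Illustrative sketch of the logdir lookup performed above; safe to run without TensorFlow.

def get_eager_logdir():
    """Best-effort lookup of the active TensorBoard logdir, or None if unavailable."""
    # noinspection PyBroadException
    try:
        from tensorflow.python.eager import context
        return context.context().summary_writer._init_op_fn.keywords.get('logdir')
    except Exception:
        pass
    # noinspection PyBroadException
    try:
        from tensorflow.python.ops.summary_ops_v2 import _summary_state
        return _summary_state.writer._init_op_fn.keywords.get('logdir')
    except Exception:
        return None


print(get_eager_logdir())
```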
@@ -1300,19 +1301,19 @@ class PatchKerasModelIO(object):
     def _patch_model_checkpoint():
         if 'keras' in sys.modules and not PatchKerasModelIO.__patched_keras:
             try:
-                from keras.engine.network import Network
+                from keras.engine.network import Network  # noqa
             except ImportError:
                 Network = None
             try:
-                from keras.engine.functional import Functional
+                from keras.engine.functional import Functional  # noqa
             except ImportError:
                 Functional = None
             try:
-                from keras.engine.sequential import Sequential
+                from keras.engine.sequential import Sequential  # noqa
             except ImportError:
                 Sequential = None
             try:
-                from keras import models as keras_saving
+                from keras import models as keras_saving  # noqa
             except ImportError:
                 keras_saving = None
             # check that we are not patching anything twice
@@ -1329,26 +1330,26 @@ class PatchKerasModelIO(object):
         if 'tensorflow' in sys.modules and not PatchKerasModelIO.__patched_tensorflow:
             try:
                 # hack: make sure tensorflow.__init__ is called
-                import tensorflow  # noqa: F401, F811
-                from tensorflow.python.keras.engine.network import Network
+                import tensorflow  # noqa
+                from tensorflow.python.keras.engine.network import Network  # noqa
             except ImportError:
                 Network = None
             try:
                 # hack: make sure tensorflow.__init__ is called
-                import tensorflow  # noqa: F401, F811
-                from tensorflow.python.keras.engine.functional import Functional
+                import tensorflow  # noqa
+                from tensorflow.python.keras.engine.functional import Functional  # noqa
             except ImportError:
                 Functional = None
             try:
                 # hack: make sure tensorflow.__init__ is called
-                import tensorflow  # noqa: F401, F811
-                from tensorflow.python.keras.engine.sequential import Sequential
+                import tensorflow  # noqa
+                from tensorflow.python.keras.engine.sequential import Sequential  # noqa
             except ImportError:
                 Sequential = None
             try:
                 # hack: make sure tensorflow.__init__ is called
-                import tensorflow  # noqa: F401, F811
-                from tensorflow.python.keras import models as keras_saving
+                import tensorflow  # noqa
+                from tensorflow.python.keras import models as keras_saving  # noqa
             except ImportError:
                 keras_saving = None

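Across the keras and tf.keras branches above, the same four symbols (Network, Functional, Sequential and the models module used for saving) are resolved from whichever package layout is importable, with None as the fallback. The patch keeps the keras and tf.keras results separate so each package can be patched on its own; the hypothetical helper below merely collapses the two lookups to show the fallback idea (`optional_import` is not part of the codebase):

```python
import importlib


def optional_import(module_path, attribute=None):
    """Import module_path (and optionally one attribute from it), or return None."""
    try:
        module = importlib.import_module(module_path)
        return getattr(module, attribute) if attribute else module
    except (ImportError, AttributeError):
        return None


# Same resolution order as the hunks above: standalone keras first, then tf.keras.
Network = (optional_import('keras.engine.network', 'Network')
           or optional_import('tensorflow.python.keras.engine.network', 'Network'))
Functional = (optional_import('keras.engine.functional', 'Functional')
              or optional_import('tensorflow.python.keras.engine.functional', 'Functional'))
Sequential = (optional_import('keras.engine.sequential', 'Sequential')
              or optional_import('tensorflow.python.keras.engine.sequential', 'Sequential'))
keras_saving = (optional_import('keras.models')
                or optional_import('tensorflow.python.keras.models'))
```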
@@ -1387,7 +1388,8 @@ class PatchKerasModelIO(object):
             Network.save_weights = _patched_call(Network.save_weights, PatchKerasModelIO._save_weights)
             Network.load_weights = _patched_call(Network.load_weights, PatchKerasModelIO._load_weights)
         elif Functional is not None:
-            Functional._updated_config = _patched_call(Functional._updated_config, PatchKerasModelIO._updated_config)
+            Functional._updated_config = _patched_call(
+                Functional._updated_config, PatchKerasModelIO._updated_config)
             if hasattr(Sequential.from_config, '__func__'):
                 # noinspection PyUnresolvedReferences
                 Functional.from_config = classmethod(_patched_call(Functional.from_config.__func__,
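Every `_patched_call(...)` assignment in this file uses the same wrap-and-replace scheme: the original unbound method is captured, and a static hook receives it as its first argument so it can do its bookkeeping and still delegate to the real implementation. `_patched_call` itself comes from the `..frameworks` import at the top of the file; the sketch below is an illustration of the pattern, assuming the hook signature `(original_fn, self, *args, **kwargs)` used by `_save` in the hunks that follow, not the library's actual implementation:

```python
import functools


def patched_call(original_fn, patched_fn):
    """Return a wrapper that hands the original callable to the hook and lets it delegate."""
    @functools.wraps(original_fn)
    def _inner(*args, **kwargs):
        # for an instance method, args[0] is `self`
        return patched_fn(original_fn, *args, **kwargs)
    return _inner


class DummyModel:
    def save_weights(self, filepath):
        print('saving weights to', filepath)


def _save_hook(original_fn, self, filepath, **kwargs):
    print('bookkeeping before save:', filepath)
    result = original_fn(self, filepath, **kwargs)
    print('bookkeeping after save:', filepath)
    return result


DummyModel.save_weights = patched_call(DummyModel.save_weights, _save_hook)
DummyModel().save_weights('weights.h5')
```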
@@ -1414,21 +1416,21 @@ class PatchKerasModelIO(object):
         try:
             # check if object already has InputModel
             if not hasattr(self, 'trains_out_model'):
-                self.trains_out_model = None
+                self.trains_out_model = []

             # check if object already has InputModel
             model_name_id = config.get('name', getattr(self, 'name', 'unknown'))
-            if self.trains_out_model is not None:
-                self.trains_out_model.config_dict = config
+            if self.trains_out_model:
+                self.trains_out_model[-1].config_dict = config
             else:
                 # todo: support multiple models for the same task
-                self.trains_out_model = OutputModel(
+                self.trains_out_model.append(OutputModel(
                     task=PatchKerasModelIO.__main_task,
                     config_dict=config,
                     name=PatchKerasModelIO.__main_task.name + ' ' + model_name_id,
                     label_enumeration=PatchKerasModelIO.__main_task.get_labels_enumeration(),
                     framework=Framework.keras,
-                )
+                ))
         except Exception as ex:
             LoggerRoot.get_base_logger(TensorflowBinding).warning(str(ex))

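With this hunk `_updated_config` keeps `trains_out_model` as a list rather than a single object: while a model already exists the newest entry's configuration is refreshed in place, otherwise a fresh OutputModel is appended. The toy snippet below shows only that list bookkeeping; `FakeOutputModel` is a stand-in, not the real OutputModel class:

```python
class FakeOutputModel:
    """Stand-in for OutputModel, keeping only what this sketch needs."""
    def __init__(self, name):
        self.name = name
        self.config_dict = None


models = []  # plays the role of self.trains_out_model


def updated_config(config):
    if models:
        models[-1].config_dict = config                        # refresh the latest model
    else:
        models.append(FakeOutputModel(config.get('name', 'unknown')))
        models[-1].config_dict = config                        # first model for this task


updated_config({'name': 'my_net', 'layers': []})
updated_config({'name': 'my_net', 'layers': ['dense']})
assert len(models) == 1 and models[-1].config_dict['layers'] == ['dense']
```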
@@ -1512,11 +1514,14 @@ class PatchKerasModelIO(object):

     @staticmethod
     def _save(original_fn, self, *args, **kwargs):
-        if hasattr(self, 'trains_out_model'):
-            self.trains_out_model._processed = False
+        if hasattr(self, 'trains_out_model') and self.trains_out_model:
+            # noinspection PyProtectedMember
+            self.trains_out_model[-1]._processed = False
         original_fn(self, *args, **kwargs)
         # no need to specially call, because the original save uses "save_model" which we overload
-        if not hasattr(self, 'trains_out_model') or not self.trains_out_model._processed:
+        # noinspection PyProtectedMember
+        if not hasattr(self, 'trains_out_model') or not self.trains_out_model or \
+                not hasattr(self.trains_out_model[-1], '_processed') or not self.trains_out_model[-1]._processed:
             PatchKerasModelIO._update_outputmodel(self, *args, **kwargs)

     @staticmethod
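`_save` wraps `Model.save`: it clears the `_processed` flag on the most recent output model, runs the original save (which normally flows through the overloaded save-model path and registers the weights itself), and only falls back to `_update_outputmodel` when nothing downstream marked the model as processed. A toy sketch of that guard; every name other than `_processed` is invented for the illustration:

```python
class FakeOutputModel:
    _processed = False


out_models = [FakeOutputModel()]  # plays the role of self.trains_out_model


def downstream_save_model():
    """Stands in for the overloaded save-model path, which may or may not run."""
    out_models[-1]._processed = True  # the normal path registers the weights itself


def save(run_downstream):
    out_models[-1]._processed = False
    if run_downstream:
        downstream_save_model()  # original save -> overloaded save_model
    if not out_models or not getattr(out_models[-1], '_processed', False):
        print('fallback: _update_outputmodel would run here')


save(run_downstream=True)   # downstream handled it, no fallback
save(run_downstream=False)  # nothing processed it, fallback triggers
```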
@@ -1544,28 +1549,38 @@ class PatchKerasModelIO(object):

             # check if object already has InputModel
             if not hasattr(self, 'trains_out_model'):
-                self.trains_out_model = None
+                self.trains_out_model = []

-            # check if object already has InputModel
-            if self.trains_out_model is not None:
-                self.trains_out_model.config_dict = config
-            else:
+            # check if the object already has an InputModel with the same filename
+            # (note: we use Path on the url so the comparison is on the basename only)
+            matched = None
+            if self.trains_out_model:
+                # find the right model
+                # noinspection PyProtectedMember
+                matched = [m for m in self.trains_out_model if m._get_last_uploaded_filename() == Path(filepath).name]
+                if matched:
+                    self.trains_out_model.remove(matched[0])
+                    self.trains_out_model.append(matched[0])
+                    self.trains_out_model[-1].config_dict = config
+
+            if not matched:
                 model_name_id = getattr(self, 'name', 'unknown')
                 # todo: support multiple models for the same task
-                self.trains_out_model = OutputModel(
+                self.trains_out_model.append(OutputModel(
                     task=PatchKerasModelIO.__main_task,
                     config_dict=config,
                     name=PatchKerasModelIO.__main_task.name + ' ' + model_name_id,
                     label_enumeration=PatchKerasModelIO.__main_task.get_labels_enumeration(),
                     framework=Framework.keras,
-                )
+                ))
             # check if we have output storage
-            if self.trains_out_model.upload_storage_uri:
-                self.trains_out_model.update_weights(weights_filename=filepath, auto_delete_file=False)
+            if self.trains_out_model[-1].upload_storage_uri:
+                self.trains_out_model[-1].update_weights(weights_filename=filepath, auto_delete_file=False)
             else:
-                self.trains_out_model.update_weights(weights_filename=None, register_uri=filepath)
+                self.trains_out_model[-1].update_weights(weights_filename=None, register_uri=filepath)
             # if anyone asks, we were here
-            self.trains_out_model._processed = True
+            # noinspection PyProtectedMember
+            self.trains_out_model[-1]._processed = True
         except Exception as ex:
             LoggerRoot.get_base_logger(TensorflowBinding).warning(str(ex))

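This hunk carries the actual fix for issue #252: instead of always reusing one OutputModel, `_update_outputmodel` searches the list for a model whose last uploaded filename matches the basename of the file being written, moves that model to the end of the list and reuses it, and only appends a new OutputModel when no filename matches. A Keras checkpoint callback that alternates between, say, best.h5 and last.h5 therefore gets one model per filename instead of one model silently overwriting the other. A standalone sketch of the matching step; the class and helper here are stand-ins, not the library API:

```python
from pathlib import Path  # the patch itself imports Path from pathlib2 for Python 2 support


class FakeOutputModel:
    """Stand-in for OutputModel, reduced to what the matching logic needs."""
    def __init__(self, filename=None):
        self._last_uploaded_filename = filename

    def _get_last_uploaded_filename(self):
        return self._last_uploaded_filename


def reuse_or_create(models, filepath):
    """Reuse the model that last wrote this filename, otherwise append a new one."""
    matched = [m for m in models
               if m._get_last_uploaded_filename() == Path(filepath).name]
    if matched:
        # move the matched model to the end so "the last model" is the active one
        models.remove(matched[0])
        models.append(matched[0])
    else:
        models.append(FakeOutputModel(Path(filepath).name))
    return models[-1]


models = []
reuse_or_create(models, '/tmp/run1/best.h5')
reuse_or_create(models, '/tmp/run1/last.h5')
reuse_or_create(models, '/tmp/run2/best.h5')  # same filename as the first call -> reused
assert len(models) == 2
```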
@@ -1624,9 +1639,9 @@ class PatchTensorflowModelIO(object):
        # noinspection PyBroadException
        try:
            # hack: make sure tensorflow.__init__ is called
-            import tensorflow
+            import tensorflow  # noqa
            # noinspection PyUnresolvedReferences
-            from tensorflow.python.training.saver import Saver
+            from tensorflow.python.training.saver import Saver  # noqa
            # noinspection PyBroadException
            try:
                Saver.save = _patched_call(Saver.save, PatchTensorflowModelIO._save)
@@ -1645,18 +1660,18 @@ class PatchTensorflowModelIO(object):
        # noinspection PyBroadException
        try:
            # make sure we import the correct version of save
-            import tensorflow  # noqa: F811
-            from tensorflow.saved_model import save
+            import tensorflow  # noqa
+            from tensorflow.saved_model import save  # noqa
            # actual import
-            from tensorflow.python.saved_model import save as saved_model
+            from tensorflow.python.saved_model import save as saved_model  # noqa
        except ImportError:
            # noinspection PyBroadException
            try:
                # make sure we import the correct version of save
-                import tensorflow
-                from tensorflow.saved_model.experimental import save  # noqa: F401
+                import tensorflow  # noqa
+                from tensorflow.saved_model.experimental import save  # noqa
                # actual import
-                import tensorflow.saved_model.experimental as saved_model
+                import tensorflow.saved_model.experimental as saved_model  # noqa
            except ImportError:
                saved_model = None
        except Exception:
@@ -1671,11 +1686,11 @@ class PatchTensorflowModelIO(object):
        # noinspection PyBroadException
        try:
            # make sure we import the correct version of save
-            import tensorflow  # noqa: F811
+            import tensorflow  # noqa
            # actual import
-            from tensorflow.saved_model import load  # noqa: F401
+            from tensorflow.saved_model import load  # noqa
            # noinspection PyUnresolvedReferences
-            import tensorflow.saved_model as saved_model_load
+            import tensorflow.saved_model as saved_model_load  # noqa
            saved_model_load.load = _patched_call(saved_model_load.load, PatchTensorflowModelIO._load)
        except ImportError:
            pass
@@ -1685,10 +1700,10 @@ class PatchTensorflowModelIO(object):
        # noinspection PyBroadException
        try:
            # make sure we import the correct version of save
-            import tensorflow  # noqa: F811
+            import tensorflow  # noqa
            # actual import
            # noinspection PyUnresolvedReferences
-            from tensorflow.saved_model import loader as loader1
+            from tensorflow.saved_model import loader as loader1  # noqa
            loader1.load = _patched_call(loader1.load, PatchTensorflowModelIO._load)
        except ImportError:
            pass
@@ -1698,10 +1713,10 @@ class PatchTensorflowModelIO(object):
        # noinspection PyBroadException
        try:
            # make sure we import the correct version of save
-            import tensorflow  # noqa: F811
+            import tensorflow  # noqa
            # actual import
            # noinspection PyUnresolvedReferences
-            from tensorflow.compat.v1.saved_model import loader as loader2
+            from tensorflow.compat.v1.saved_model import loader as loader2  # noqa
            loader2.load = _patched_call(loader2.load, PatchTensorflowModelIO._load)
        except ImportError:
            pass
@@ -1710,8 +1725,8 @@ class PatchTensorflowModelIO(object):

        # noinspection PyBroadException
        try:
-            import tensorflow  # noqa: F401, F811
-            from tensorflow.train import Checkpoint
+            import tensorflow  # noqa
+            from tensorflow.train import Checkpoint  # noqa
            # noinspection PyBroadException
            try:
                Checkpoint.save = _patched_call(Checkpoint.save, PatchTensorflowModelIO._ckpt_save)
@@ -1861,8 +1876,8 @@ class PatchTensorflow2ModelIO(object):
        # noinspection PyBroadException
        try:
            # hack: make sure tensorflow.__init__ is called
-            import tensorflow  # noqa: F401
-            from tensorflow.python.training.tracking import util
+            import tensorflow  # noqa
+            from tensorflow.python.training.tracking import util  # noqa
            # noinspection PyBroadException
            try:
                util.TrackableSaver.save = _patched_call(util.TrackableSaver.save,

@@ -945,6 +945,7 @@ class OutputModel(BaseModel):
        config_text = self._resolve_config(config_text=config_text, config_dict=config_dict)

        self._model_local_filename = None
+        self._last_uploaded_url = None
        self._base_model = None
        # noinspection PyProtectedMember
        self._floating_data = create_dummy_model(
@@ -1205,6 +1206,8 @@ class OutputModel(BaseModel):
        else:
            output_uri = None

+        self._last_uploaded_url = output_uri
+
        if is_package:
            self._set_package_tag()

@@ -1433,6 +1436,9 @@ class OutputModel(BaseModel):

        return True

+    def _get_last_uploaded_filename(self):
+        return Path(self._last_uploaded_url or self.url).name
+

 class Waitable(object):
     def wait(self, *_, **__):
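On the OutputModel side, the patch records the destination of the most recent weights upload in `_last_uploaded_url` and exposes it through the new `_get_last_uploaded_filename`, which reduces that URL (or, failing that, the model's registered url) to a plain basename so the Keras patcher can compare it against `Path(filepath).name`. A small illustration of that reduction; standard-library pathlib is used here for brevity, and the function is a standalone mimic rather than the method itself:

```python
from pathlib import Path


def last_uploaded_filename(last_uploaded_url, fallback_url):
    """Mimic OutputModel._get_last_uploaded_filename: basename of the last upload."""
    return Path(last_uploaded_url or fallback_url).name


# an uploaded artifact URL and a locally registered path reduce to comparable basenames
print(last_uploaded_filename('s3://bucket/project/task.best.h5', None))  # -> task.best.h5
print(last_uploaded_filename(None, '/tmp/checkpoints/last.h5'))          # -> last.h5
```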