Fix Keras reusing the model object only if the filename is the same (issue #252)

This commit is contained in:
allegroai 2020-11-25 14:55:45 +02:00
parent c75de848f5
commit f3638a7e5d
2 changed files with 84 additions and 63 deletions

View File

@ -11,6 +11,7 @@ from typing import Any
import numpy as np import numpy as np
import six import six
from PIL import Image from PIL import Image
from pathlib2 import Path
from ...debugging.log import LoggerRoot from ...debugging.log import LoggerRoot
from ..frameworks import _patched_call, WeightsFileHandler, _Empty from ..frameworks import _patched_call, WeightsFileHandler, _Empty
@ -19,7 +20,7 @@ from ...config import running_remotely
from ...model import InputModel, OutputModel, Framework from ...model import InputModel, OutputModel, Framework
try: try:
from google.protobuf.json_format import MessageToDict from google.protobuf.json_format import MessageToDict # noqa
except ImportError: except ImportError:
MessageToDict = None MessageToDict = None
@ -840,7 +841,7 @@ class PatchSummaryToEventTransformer(object):
def _patch_summary_to_event_transformer(): def _patch_summary_to_event_transformer():
if 'tensorflow' in sys.modules: if 'tensorflow' in sys.modules:
try: try:
from tensorflow.python.summary.writer.writer import SummaryToEventTransformer from tensorflow.python.summary.writer.writer import SummaryToEventTransformer # noqa
# only patch once # only patch once
if PatchSummaryToEventTransformer.__original_getattribute is None: if PatchSummaryToEventTransformer.__original_getattribute is None:
PatchSummaryToEventTransformer.__original_getattribute = SummaryToEventTransformer.__getattribute__ PatchSummaryToEventTransformer.__original_getattribute = SummaryToEventTransformer.__getattribute__
@ -855,7 +856,7 @@ class PatchSummaryToEventTransformer(object):
# only patch once # only patch once
if PatchSummaryToEventTransformer._original_add_eventT is None: if PatchSummaryToEventTransformer._original_add_eventT is None:
# noinspection PyUnresolvedReferences # noinspection PyUnresolvedReferences
from torch.utils.tensorboard.writer import FileWriter as FileWriterT from torch.utils.tensorboard.writer import FileWriter as FileWriterT # noqa
PatchSummaryToEventTransformer._original_add_eventT = FileWriterT.add_event PatchSummaryToEventTransformer._original_add_eventT = FileWriterT.add_event
FileWriterT.add_event = PatchSummaryToEventTransformer._patched_add_eventT FileWriterT.add_event = PatchSummaryToEventTransformer._patched_add_eventT
setattr(FileWriterT, 'trains', None) setattr(FileWriterT, 'trains', None)
@ -870,7 +871,7 @@ class PatchSummaryToEventTransformer(object):
# only patch once # only patch once
if PatchSummaryToEventTransformer.__original_getattributeX is None: if PatchSummaryToEventTransformer.__original_getattributeX is None:
# noinspection PyUnresolvedReferences # noinspection PyUnresolvedReferences
from tensorboardX.writer import SummaryToEventTransformer as SummaryToEventTransformerX from tensorboardX.writer import SummaryToEventTransformer as SummaryToEventTransformerX # noqa
PatchSummaryToEventTransformer.__original_getattributeX = \ PatchSummaryToEventTransformer.__original_getattributeX = \
SummaryToEventTransformerX.__getattribute__ SummaryToEventTransformerX.__getattribute__
SummaryToEventTransformerX.__getattribute__ = PatchSummaryToEventTransformer._patched_getattributeX SummaryToEventTransformerX.__getattribute__ = PatchSummaryToEventTransformer._patched_getattributeX
@ -886,7 +887,7 @@ class PatchSummaryToEventTransformer(object):
try: try:
# only patch once # only patch once
if PatchSummaryToEventTransformer._original_add_eventX is None: if PatchSummaryToEventTransformer._original_add_eventX is None:
from tensorboardX.writer import FileWriter as FileWriterX from tensorboardX.writer import FileWriter as FileWriterX # noqa
PatchSummaryToEventTransformer._original_add_eventX = FileWriterX.add_event PatchSummaryToEventTransformer._original_add_eventX = FileWriterX.add_event
FileWriterX.add_event = PatchSummaryToEventTransformer._patched_add_eventX FileWriterX.add_event = PatchSummaryToEventTransformer._patched_add_eventX
setattr(FileWriterX, 'trains', None) setattr(FileWriterX, 'trains', None)
@ -1041,14 +1042,14 @@ class PatchModelCheckPointCallback(object):
callbacks = None callbacks = None
if is_keras: if is_keras:
try: try:
import keras.callbacks as callbacks # noqa: F401 import keras.callbacks as callbacks # noqa
except ImportError: except ImportError:
is_keras = False is_keras = False
if not is_keras and is_tf_keras: if not is_keras and is_tf_keras:
try: try:
# hack: make sure tensorflow.__init__ is called # hack: make sure tensorflow.__init__ is called
import tensorflow # noqa: F401 import tensorflow # noqa
import tensorflow.python.keras.callbacks as callbacks # noqa: F811 import tensorflow.python.keras.callbacks as callbacks # noqa
except ImportError: except ImportError:
is_tf_keras = False is_tf_keras = False
callbacks = None callbacks = None
@ -1129,8 +1130,8 @@ class PatchTensorFlowEager(object):
if 'tensorflow' in sys.modules: if 'tensorflow' in sys.modules:
try: try:
# hack: make sure tensorflow.__init__ is called # hack: make sure tensorflow.__init__ is called
import tensorflow # noqa: F401 import tensorflow # noqa
from tensorflow.python.ops import gen_summary_ops # noqa: F401 from tensorflow.python.ops import gen_summary_ops # noqa
PatchTensorFlowEager.__original_fn_scalar = gen_summary_ops.write_scalar_summary PatchTensorFlowEager.__original_fn_scalar = gen_summary_ops.write_scalar_summary
gen_summary_ops.write_scalar_summary = PatchTensorFlowEager._write_scalar_summary gen_summary_ops.write_scalar_summary = PatchTensorFlowEager._write_scalar_summary
PatchTensorFlowEager.__original_fn_image = gen_summary_ops.write_image_summary PatchTensorFlowEager.__original_fn_image = gen_summary_ops.write_image_summary
@ -1160,12 +1161,12 @@ class PatchTensorFlowEager(object):
# check if we are in eager mode, let's get the global context logdir # check if we are in eager mode, let's get the global context logdir
# noinspection PyBroadException # noinspection PyBroadException
try: try:
from tensorflow.python.eager import context from tensorflow.python.eager import context # noqa
logdir = context.context().summary_writer._init_op_fn.keywords.get('logdir') logdir = context.context().summary_writer._init_op_fn.keywords.get('logdir')
except Exception: except Exception:
# noinspection PyBroadException # noinspection PyBroadException
try: try:
from tensorflow.python.ops.summary_ops_v2 import _summary_state from tensorflow.python.ops.summary_ops_v2 import _summary_state # noqa
logdir = _summary_state.writer._init_op_fn.keywords.get('logdir') logdir = _summary_state.writer._init_op_fn.keywords.get('logdir')
except Exception: except Exception:
logdir = None logdir = None
@ -1300,19 +1301,19 @@ class PatchKerasModelIO(object):
def _patch_model_checkpoint(): def _patch_model_checkpoint():
if 'keras' in sys.modules and not PatchKerasModelIO.__patched_keras: if 'keras' in sys.modules and not PatchKerasModelIO.__patched_keras:
try: try:
from keras.engine.network import Network from keras.engine.network import Network # noqa
except ImportError: except ImportError:
Network = None Network = None
try: try:
from keras.engine.functional import Functional from keras.engine.functional import Functional # noqa
except ImportError: except ImportError:
Functional = None Functional = None
try: try:
from keras.engine.sequential import Sequential from keras.engine.sequential import Sequential # noqa
except ImportError: except ImportError:
Sequential = None Sequential = None
try: try:
from keras import models as keras_saving from keras import models as keras_saving # noqa
except ImportError: except ImportError:
keras_saving = None keras_saving = None
# check that we are not patching anything twice # check that we are not patching anything twice
@ -1329,26 +1330,26 @@ class PatchKerasModelIO(object):
if 'tensorflow' in sys.modules and not PatchKerasModelIO.__patched_tensorflow: if 'tensorflow' in sys.modules and not PatchKerasModelIO.__patched_tensorflow:
try: try:
# hack: make sure tensorflow.__init__ is called # hack: make sure tensorflow.__init__ is called
import tensorflow # noqa: F401, F811 import tensorflow # noqa
from tensorflow.python.keras.engine.network import Network from tensorflow.python.keras.engine.network import Network # noqa
except ImportError: except ImportError:
Network = None Network = None
try: try:
# hack: make sure tensorflow.__init__ is called # hack: make sure tensorflow.__init__ is called
import tensorflow # noqa: F401, F811 import tensorflow # noqa
from tensorflow.python.keras.engine.functional import Functional from tensorflow.python.keras.engine.functional import Functional # noqa
except ImportError: except ImportError:
Functional = None Functional = None
try: try:
# hack: make sure tensorflow.__init__ is called # hack: make sure tensorflow.__init__ is called
import tensorflow # noqa: F401, F811 import tensorflow # noqa
from tensorflow.python.keras.engine.sequential import Sequential from tensorflow.python.keras.engine.sequential import Sequential # noqa
except ImportError: except ImportError:
Sequential = None Sequential = None
try: try:
# hack: make sure tensorflow.__init__ is called # hack: make sure tensorflow.__init__ is called
import tensorflow # noqa: F401, F811 import tensorflow # noqa
from tensorflow.python.keras import models as keras_saving from tensorflow.python.keras import models as keras_saving # noqa
except ImportError: except ImportError:
keras_saving = None keras_saving = None
@ -1387,7 +1388,8 @@ class PatchKerasModelIO(object):
Network.save_weights = _patched_call(Network.save_weights, PatchKerasModelIO._save_weights) Network.save_weights = _patched_call(Network.save_weights, PatchKerasModelIO._save_weights)
Network.load_weights = _patched_call(Network.load_weights, PatchKerasModelIO._load_weights) Network.load_weights = _patched_call(Network.load_weights, PatchKerasModelIO._load_weights)
elif Functional is not None: elif Functional is not None:
Functional._updated_config = _patched_call(Functional._updated_config, PatchKerasModelIO._updated_config) Functional._updated_config = _patched_call(
Functional._updated_config, PatchKerasModelIO._updated_config)
if hasattr(Sequential.from_config, '__func__'): if hasattr(Sequential.from_config, '__func__'):
# noinspection PyUnresolvedReferences # noinspection PyUnresolvedReferences
Functional.from_config = classmethod(_patched_call(Functional.from_config.__func__, Functional.from_config = classmethod(_patched_call(Functional.from_config.__func__,
@ -1414,21 +1416,21 @@ class PatchKerasModelIO(object):
try: try:
# check if object already has InputModel # check if object already has InputModel
if not hasattr(self, 'trains_out_model'): if not hasattr(self, 'trains_out_model'):
self.trains_out_model = None self.trains_out_model = []
# check if object already has InputModel # check if object already has InputModel
model_name_id = config.get('name', getattr(self, 'name', 'unknown')) model_name_id = config.get('name', getattr(self, 'name', 'unknown'))
if self.trains_out_model is not None: if self.trains_out_model:
self.trains_out_model.config_dict = config self.trains_out_model[-1].config_dict = config
else: else:
# todo: support multiple models for the same task # todo: support multiple models for the same task
self.trains_out_model = OutputModel( self.trains_out_model.append(OutputModel(
task=PatchKerasModelIO.__main_task, task=PatchKerasModelIO.__main_task,
config_dict=config, config_dict=config,
name=PatchKerasModelIO.__main_task.name + ' ' + model_name_id, name=PatchKerasModelIO.__main_task.name + ' ' + model_name_id,
label_enumeration=PatchKerasModelIO.__main_task.get_labels_enumeration(), label_enumeration=PatchKerasModelIO.__main_task.get_labels_enumeration(),
framework=Framework.keras, framework=Framework.keras,
) ))
except Exception as ex: except Exception as ex:
LoggerRoot.get_base_logger(TensorflowBinding).warning(str(ex)) LoggerRoot.get_base_logger(TensorflowBinding).warning(str(ex))
@ -1512,11 +1514,14 @@ class PatchKerasModelIO(object):
@staticmethod @staticmethod
def _save(original_fn, self, *args, **kwargs): def _save(original_fn, self, *args, **kwargs):
if hasattr(self, 'trains_out_model'): if hasattr(self, 'trains_out_model') and self.trains_out_model:
self.trains_out_model._processed = False # noinspection PyProtectedMember
self.trains_out_model[-1]._processed = False
original_fn(self, *args, **kwargs) original_fn(self, *args, **kwargs)
# no need to specially call, because the original save uses "save_model" which we overload # no need to specially call, because the original save uses "save_model" which we overload
if not hasattr(self, 'trains_out_model') or not self.trains_out_model._processed: # noinspection PyProtectedMember
if not hasattr(self, 'trains_out_model') or not self.trains_out_model or \
not hasattr(self.trains_out_model[-1], '_processed') or not self.trains_out_model[-1]._processed:
PatchKerasModelIO._update_outputmodel(self, *args, **kwargs) PatchKerasModelIO._update_outputmodel(self, *args, **kwargs)
@staticmethod @staticmethod
@ -1544,28 +1549,38 @@ class PatchKerasModelIO(object):
# check if object already has InputModel # check if object already has InputModel
if not hasattr(self, 'trains_out_model'): if not hasattr(self, 'trains_out_model'):
self.trains_out_model = None self.trains_out_model = []
# check if object already has InputModel # check if object already has InputModel, and that it has the same filename
if self.trains_out_model is not None: # (notice we use Path on the url for comparing)
self.trains_out_model.config_dict = config matched = None
else: if self.trains_out_model:
# find the right model
# noinspection PyProtectedMember
matched = [m for m in self.trains_out_model if m._get_last_uploaded_filename() == Path(filepath).name]
if matched:
self.trains_out_model.remove(matched[0])
self.trains_out_model.append(matched[0])
self.trains_out_model[-1].config_dict = config
if not matched:
model_name_id = getattr(self, 'name', 'unknown') model_name_id = getattr(self, 'name', 'unknown')
# todo: support multiple models for the same task # todo: support multiple models for the same task
self.trains_out_model = OutputModel( self.trains_out_model.append(OutputModel(
task=PatchKerasModelIO.__main_task, task=PatchKerasModelIO.__main_task,
config_dict=config, config_dict=config,
name=PatchKerasModelIO.__main_task.name + ' ' + model_name_id, name=PatchKerasModelIO.__main_task.name + ' ' + model_name_id,
label_enumeration=PatchKerasModelIO.__main_task.get_labels_enumeration(), label_enumeration=PatchKerasModelIO.__main_task.get_labels_enumeration(),
framework=Framework.keras, framework=Framework.keras,
) ))
# check if we have output storage # check if we have output storage
if self.trains_out_model.upload_storage_uri: if self.trains_out_model[-1].upload_storage_uri:
self.trains_out_model.update_weights(weights_filename=filepath, auto_delete_file=False) self.trains_out_model[-1].update_weights(weights_filename=filepath, auto_delete_file=False)
else: else:
self.trains_out_model.update_weights(weights_filename=None, register_uri=filepath) self.trains_out_model[-1].update_weights(weights_filename=None, register_uri=filepath)
# if anyone asks, we were here # if anyone asks, we were here
self.trains_out_model._processed = True # noinspection PyProtectedMember
self.trains_out_model[-1]._processed = True
except Exception as ex: except Exception as ex:
LoggerRoot.get_base_logger(TensorflowBinding).warning(str(ex)) LoggerRoot.get_base_logger(TensorflowBinding).warning(str(ex))
@ -1624,9 +1639,9 @@ class PatchTensorflowModelIO(object):
# noinspection PyBroadException # noinspection PyBroadException
try: try:
# hack: make sure tensorflow.__init__ is called # hack: make sure tensorflow.__init__ is called
import tensorflow import tensorflow # noqa
# noinspection PyUnresolvedReferences # noinspection PyUnresolvedReferences
from tensorflow.python.training.saver import Saver from tensorflow.python.training.saver import Saver # noqa
# noinspection PyBroadException # noinspection PyBroadException
try: try:
Saver.save = _patched_call(Saver.save, PatchTensorflowModelIO._save) Saver.save = _patched_call(Saver.save, PatchTensorflowModelIO._save)
@ -1645,18 +1660,18 @@ class PatchTensorflowModelIO(object):
# noinspection PyBroadException # noinspection PyBroadException
try: try:
# make sure we import the correct version of save # make sure we import the correct version of save
import tensorflow # noqa: F811 import tensorflow # noqa
from tensorflow.saved_model import save from tensorflow.saved_model import save # noqa
# actual import # actual import
from tensorflow.python.saved_model import save as saved_model from tensorflow.python.saved_model import save as saved_model # noqa
except ImportError: except ImportError:
# noinspection PyBroadException # noinspection PyBroadException
try: try:
# make sure we import the correct version of save # make sure we import the correct version of save
import tensorflow import tensorflow # noqa
from tensorflow.saved_model.experimental import save # noqa: F401 from tensorflow.saved_model.experimental import save # noqa
# actual import # actual import
import tensorflow.saved_model.experimental as saved_model import tensorflow.saved_model.experimental as saved_model # noqa
except ImportError: except ImportError:
saved_model = None saved_model = None
except Exception: except Exception:
@ -1671,11 +1686,11 @@ class PatchTensorflowModelIO(object):
# noinspection PyBroadException # noinspection PyBroadException
try: try:
# make sure we import the correct version of save # make sure we import the correct version of save
import tensorflow # noqa: F811 import tensorflow # noqa
# actual import # actual import
from tensorflow.saved_model import load # noqa: F401 from tensorflow.saved_model import load # noqa
# noinspection PyUnresolvedReferences # noinspection PyUnresolvedReferences
import tensorflow.saved_model as saved_model_load import tensorflow.saved_model as saved_model_load # noqa
saved_model_load.load = _patched_call(saved_model_load.load, PatchTensorflowModelIO._load) saved_model_load.load = _patched_call(saved_model_load.load, PatchTensorflowModelIO._load)
except ImportError: except ImportError:
pass pass
@ -1685,10 +1700,10 @@ class PatchTensorflowModelIO(object):
# noinspection PyBroadException # noinspection PyBroadException
try: try:
# make sure we import the correct version of save # make sure we import the correct version of save
import tensorflow # noqa: F811 import tensorflow # noqa
# actual import # actual import
# noinspection PyUnresolvedReferences # noinspection PyUnresolvedReferences
from tensorflow.saved_model import loader as loader1 from tensorflow.saved_model import loader as loader1 # noqa
loader1.load = _patched_call(loader1.load, PatchTensorflowModelIO._load) loader1.load = _patched_call(loader1.load, PatchTensorflowModelIO._load)
except ImportError: except ImportError:
pass pass
@ -1698,10 +1713,10 @@ class PatchTensorflowModelIO(object):
# noinspection PyBroadException # noinspection PyBroadException
try: try:
# make sure we import the correct version of save # make sure we import the correct version of save
import tensorflow # noqa: F811 import tensorflow # noqa
# actual import # actual import
# noinspection PyUnresolvedReferences # noinspection PyUnresolvedReferences
from tensorflow.compat.v1.saved_model import loader as loader2 from tensorflow.compat.v1.saved_model import loader as loader2 # noqa
loader2.load = _patched_call(loader2.load, PatchTensorflowModelIO._load) loader2.load = _patched_call(loader2.load, PatchTensorflowModelIO._load)
except ImportError: except ImportError:
pass pass
@ -1710,8 +1725,8 @@ class PatchTensorflowModelIO(object):
# noinspection PyBroadException # noinspection PyBroadException
try: try:
import tensorflow # noqa: F401, F811 import tensorflow # noqa
from tensorflow.train import Checkpoint from tensorflow.train import Checkpoint # noqa
# noinspection PyBroadException # noinspection PyBroadException
try: try:
Checkpoint.save = _patched_call(Checkpoint.save, PatchTensorflowModelIO._ckpt_save) Checkpoint.save = _patched_call(Checkpoint.save, PatchTensorflowModelIO._ckpt_save)
@ -1861,8 +1876,8 @@ class PatchTensorflow2ModelIO(object):
# noinspection PyBroadException # noinspection PyBroadException
try: try:
# hack: make sure tensorflow.__init__ is called # hack: make sure tensorflow.__init__ is called
import tensorflow # noqa: F401 import tensorflow # noqa
from tensorflow.python.training.tracking import util from tensorflow.python.training.tracking import util # noqa
# noinspection PyBroadException # noinspection PyBroadException
try: try:
util.TrackableSaver.save = _patched_call(util.TrackableSaver.save, util.TrackableSaver.save = _patched_call(util.TrackableSaver.save,

View File

@ -945,6 +945,7 @@ class OutputModel(BaseModel):
config_text = self._resolve_config(config_text=config_text, config_dict=config_dict) config_text = self._resolve_config(config_text=config_text, config_dict=config_dict)
self._model_local_filename = None self._model_local_filename = None
self._last_uploaded_url = None
self._base_model = None self._base_model = None
# noinspection PyProtectedMember # noinspection PyProtectedMember
self._floating_data = create_dummy_model( self._floating_data = create_dummy_model(
@ -1205,6 +1206,8 @@ class OutputModel(BaseModel):
else: else:
output_uri = None output_uri = None
self._last_uploaded_url = output_uri
if is_package: if is_package:
self._set_package_tag() self._set_package_tag()
@ -1433,6 +1436,9 @@ class OutputModel(BaseModel):
return True return True
def _get_last_uploaded_filename(self):
return Path(self._last_uploaded_url or self.url).name
class Waitable(object): class Waitable(object):
def wait(self, *_, **__): def wait(self, *_, **__):