Fix Keras binding: reuse the existing model object only if the filename is the same (issue #252)

This commit is contained in:
allegroai 2020-11-25 14:55:45 +02:00
parent c75de848f5
commit f3638a7e5d
2 changed files with 84 additions and 63 deletions

View File

@ -11,6 +11,7 @@ from typing import Any
import numpy as np
import six
from PIL import Image
from pathlib2 import Path
from ...debugging.log import LoggerRoot
from ..frameworks import _patched_call, WeightsFileHandler, _Empty
@ -19,7 +20,7 @@ from ...config import running_remotely
from ...model import InputModel, OutputModel, Framework
try:
from google.protobuf.json_format import MessageToDict
from google.protobuf.json_format import MessageToDict # noqa
except ImportError:
MessageToDict = None
@ -840,7 +841,7 @@ class PatchSummaryToEventTransformer(object):
def _patch_summary_to_event_transformer():
if 'tensorflow' in sys.modules:
try:
from tensorflow.python.summary.writer.writer import SummaryToEventTransformer
from tensorflow.python.summary.writer.writer import SummaryToEventTransformer # noqa
# only patch once
if PatchSummaryToEventTransformer.__original_getattribute is None:
PatchSummaryToEventTransformer.__original_getattribute = SummaryToEventTransformer.__getattribute__
@ -855,7 +856,7 @@ class PatchSummaryToEventTransformer(object):
# only patch once
if PatchSummaryToEventTransformer._original_add_eventT is None:
# noinspection PyUnresolvedReferences
from torch.utils.tensorboard.writer import FileWriter as FileWriterT
from torch.utils.tensorboard.writer import FileWriter as FileWriterT # noqa
PatchSummaryToEventTransformer._original_add_eventT = FileWriterT.add_event
FileWriterT.add_event = PatchSummaryToEventTransformer._patched_add_eventT
setattr(FileWriterT, 'trains', None)
@ -870,7 +871,7 @@ class PatchSummaryToEventTransformer(object):
# only patch once
if PatchSummaryToEventTransformer.__original_getattributeX is None:
# noinspection PyUnresolvedReferences
from tensorboardX.writer import SummaryToEventTransformer as SummaryToEventTransformerX
from tensorboardX.writer import SummaryToEventTransformer as SummaryToEventTransformerX # noqa
PatchSummaryToEventTransformer.__original_getattributeX = \
SummaryToEventTransformerX.__getattribute__
SummaryToEventTransformerX.__getattribute__ = PatchSummaryToEventTransformer._patched_getattributeX
@ -886,7 +887,7 @@ class PatchSummaryToEventTransformer(object):
try:
# only patch once
if PatchSummaryToEventTransformer._original_add_eventX is None:
from tensorboardX.writer import FileWriter as FileWriterX
from tensorboardX.writer import FileWriter as FileWriterX # noqa
PatchSummaryToEventTransformer._original_add_eventX = FileWriterX.add_event
FileWriterX.add_event = PatchSummaryToEventTransformer._patched_add_eventX
setattr(FileWriterX, 'trains', None)
@ -1041,14 +1042,14 @@ class PatchModelCheckPointCallback(object):
callbacks = None
if is_keras:
try:
import keras.callbacks as callbacks # noqa: F401
import keras.callbacks as callbacks # noqa
except ImportError:
is_keras = False
if not is_keras and is_tf_keras:
try:
# hack: make sure tensorflow.__init__ is called
import tensorflow # noqa: F401
import tensorflow.python.keras.callbacks as callbacks # noqa: F811
import tensorflow # noqa
import tensorflow.python.keras.callbacks as callbacks # noqa
except ImportError:
is_tf_keras = False
callbacks = None
@ -1129,8 +1130,8 @@ class PatchTensorFlowEager(object):
if 'tensorflow' in sys.modules:
try:
# hack: make sure tensorflow.__init__ is called
import tensorflow # noqa: F401
from tensorflow.python.ops import gen_summary_ops # noqa: F401
import tensorflow # noqa
from tensorflow.python.ops import gen_summary_ops # noqa
PatchTensorFlowEager.__original_fn_scalar = gen_summary_ops.write_scalar_summary
gen_summary_ops.write_scalar_summary = PatchTensorFlowEager._write_scalar_summary
PatchTensorFlowEager.__original_fn_image = gen_summary_ops.write_image_summary
@ -1160,12 +1161,12 @@ class PatchTensorFlowEager(object):
# check if we are in eager mode, let's get the global context logdir
# noinspection PyBroadException
try:
from tensorflow.python.eager import context
from tensorflow.python.eager import context # noqa
logdir = context.context().summary_writer._init_op_fn.keywords.get('logdir')
except Exception:
# noinspection PyBroadException
try:
from tensorflow.python.ops.summary_ops_v2 import _summary_state
from tensorflow.python.ops.summary_ops_v2 import _summary_state # noqa
logdir = _summary_state.writer._init_op_fn.keywords.get('logdir')
except Exception:
logdir = None
@ -1300,19 +1301,19 @@ class PatchKerasModelIO(object):
def _patch_model_checkpoint():
if 'keras' in sys.modules and not PatchKerasModelIO.__patched_keras:
try:
from keras.engine.network import Network
from keras.engine.network import Network # noqa
except ImportError:
Network = None
try:
from keras.engine.functional import Functional
from keras.engine.functional import Functional # noqa
except ImportError:
Functional = None
try:
from keras.engine.sequential import Sequential
from keras.engine.sequential import Sequential # noqa
except ImportError:
Sequential = None
try:
from keras import models as keras_saving
from keras import models as keras_saving # noqa
except ImportError:
keras_saving = None
# check that we are not patching anything twice
@ -1329,26 +1330,26 @@ class PatchKerasModelIO(object):
if 'tensorflow' in sys.modules and not PatchKerasModelIO.__patched_tensorflow:
try:
# hack: make sure tensorflow.__init__ is called
import tensorflow # noqa: F401, F811
from tensorflow.python.keras.engine.network import Network
import tensorflow # noqa
from tensorflow.python.keras.engine.network import Network # noqa
except ImportError:
Network = None
try:
# hack: make sure tensorflow.__init__ is called
import tensorflow # noqa: F401, F811
from tensorflow.python.keras.engine.functional import Functional
import tensorflow # noqa
from tensorflow.python.keras.engine.functional import Functional # noqa
except ImportError:
Functional = None
try:
# hack: make sure tensorflow.__init__ is called
import tensorflow # noqa: F401, F811
from tensorflow.python.keras.engine.sequential import Sequential
import tensorflow # noqa
from tensorflow.python.keras.engine.sequential import Sequential # noqa
except ImportError:
Sequential = None
try:
# hack: make sure tensorflow.__init__ is called
import tensorflow # noqa: F401, F811
from tensorflow.python.keras import models as keras_saving
import tensorflow # noqa
from tensorflow.python.keras import models as keras_saving # noqa
except ImportError:
keras_saving = None
@ -1387,7 +1388,8 @@ class PatchKerasModelIO(object):
Network.save_weights = _patched_call(Network.save_weights, PatchKerasModelIO._save_weights)
Network.load_weights = _patched_call(Network.load_weights, PatchKerasModelIO._load_weights)
elif Functional is not None:
Functional._updated_config = _patched_call(Functional._updated_config, PatchKerasModelIO._updated_config)
Functional._updated_config = _patched_call(
Functional._updated_config, PatchKerasModelIO._updated_config)
if hasattr(Sequential.from_config, '__func__'):
# noinspection PyUnresolvedReferences
Functional.from_config = classmethod(_patched_call(Functional.from_config.__func__,
@ -1414,21 +1416,21 @@ class PatchKerasModelIO(object):
try:
# check if object already has InputModel
if not hasattr(self, 'trains_out_model'):
self.trains_out_model = None
self.trains_out_model = []
# check if object already has InputModel
model_name_id = config.get('name', getattr(self, 'name', 'unknown'))
if self.trains_out_model is not None:
self.trains_out_model.config_dict = config
if self.trains_out_model:
self.trains_out_model[-1].config_dict = config
else:
# todo: support multiple models for the same task
self.trains_out_model = OutputModel(
self.trains_out_model.append(OutputModel(
task=PatchKerasModelIO.__main_task,
config_dict=config,
name=PatchKerasModelIO.__main_task.name + ' ' + model_name_id,
label_enumeration=PatchKerasModelIO.__main_task.get_labels_enumeration(),
framework=Framework.keras,
)
))
except Exception as ex:
LoggerRoot.get_base_logger(TensorflowBinding).warning(str(ex))
@ -1512,11 +1514,14 @@ class PatchKerasModelIO(object):
@staticmethod
def _save(original_fn, self, *args, **kwargs):
if hasattr(self, 'trains_out_model'):
self.trains_out_model._processed = False
if hasattr(self, 'trains_out_model') and self.trains_out_model:
# noinspection PyProtectedMember
self.trains_out_model[-1]._processed = False
original_fn(self, *args, **kwargs)
# no need to specially call, because the original save uses "save_model" which we overload
if not hasattr(self, 'trains_out_model') or not self.trains_out_model._processed:
# noinspection PyProtectedMember
if not hasattr(self, 'trains_out_model') or not self.trains_out_model or \
not hasattr(self.trains_out_model[-1], '_processed') or not self.trains_out_model[-1]._processed:
PatchKerasModelIO._update_outputmodel(self, *args, **kwargs)
@staticmethod
@ -1544,28 +1549,38 @@ class PatchKerasModelIO(object):
# check if object already has InputModel
if not hasattr(self, 'trains_out_model'):
self.trains_out_model = None
self.trains_out_model = []
# check if object already has InputModel
if self.trains_out_model is not None:
self.trains_out_model.config_dict = config
else:
# check if the object already has a model entry with the same filename
# (notice we use Path on the url so the filenames compare consistently)
matched = None
if self.trains_out_model:
# find the right model
# noinspection PyProtectedMember
matched = [m for m in self.trains_out_model if m._get_last_uploaded_filename() == Path(filepath).name]
if matched:
self.trains_out_model.remove(matched[0])
self.trains_out_model.append(matched[0])
self.trains_out_model[-1].config_dict = config
if not matched:
model_name_id = getattr(self, 'name', 'unknown')
# todo: support multiple models for the same task
self.trains_out_model = OutputModel(
self.trains_out_model.append(OutputModel(
task=PatchKerasModelIO.__main_task,
config_dict=config,
name=PatchKerasModelIO.__main_task.name + ' ' + model_name_id,
label_enumeration=PatchKerasModelIO.__main_task.get_labels_enumeration(),
framework=Framework.keras,
)
))
# check if we have output storage
if self.trains_out_model.upload_storage_uri:
self.trains_out_model.update_weights(weights_filename=filepath, auto_delete_file=False)
if self.trains_out_model[-1].upload_storage_uri:
self.trains_out_model[-1].update_weights(weights_filename=filepath, auto_delete_file=False)
else:
self.trains_out_model.update_weights(weights_filename=None, register_uri=filepath)
self.trains_out_model[-1].update_weights(weights_filename=None, register_uri=filepath)
# if anyone asks, we were here
self.trains_out_model._processed = True
# noinspection PyProtectedMember
self.trains_out_model[-1]._processed = True
except Exception as ex:
LoggerRoot.get_base_logger(TensorflowBinding).warning(str(ex))
@ -1624,9 +1639,9 @@ class PatchTensorflowModelIO(object):
# noinspection PyBroadException
try:
# hack: make sure tensorflow.__init__ is called
import tensorflow
import tensorflow # noqa
# noinspection PyUnresolvedReferences
from tensorflow.python.training.saver import Saver
from tensorflow.python.training.saver import Saver # noqa
# noinspection PyBroadException
try:
Saver.save = _patched_call(Saver.save, PatchTensorflowModelIO._save)
@ -1645,18 +1660,18 @@ class PatchTensorflowModelIO(object):
# noinspection PyBroadException
try:
# make sure we import the correct version of save
import tensorflow # noqa: F811
from tensorflow.saved_model import save
import tensorflow # noqa
from tensorflow.saved_model import save # noqa
# actual import
from tensorflow.python.saved_model import save as saved_model
from tensorflow.python.saved_model import save as saved_model # noqa
except ImportError:
# noinspection PyBroadException
try:
# make sure we import the correct version of save
import tensorflow
from tensorflow.saved_model.experimental import save # noqa: F401
import tensorflow # noqa
from tensorflow.saved_model.experimental import save # noqa
# actual import
import tensorflow.saved_model.experimental as saved_model
import tensorflow.saved_model.experimental as saved_model # noqa
except ImportError:
saved_model = None
except Exception:
@ -1671,11 +1686,11 @@ class PatchTensorflowModelIO(object):
# noinspection PyBroadException
try:
# make sure we import the correct version of save
import tensorflow # noqa: F811
import tensorflow # noqa
# actual import
from tensorflow.saved_model import load # noqa: F401
from tensorflow.saved_model import load # noqa
# noinspection PyUnresolvedReferences
import tensorflow.saved_model as saved_model_load
import tensorflow.saved_model as saved_model_load # noqa
saved_model_load.load = _patched_call(saved_model_load.load, PatchTensorflowModelIO._load)
except ImportError:
pass
@ -1685,10 +1700,10 @@ class PatchTensorflowModelIO(object):
# noinspection PyBroadException
try:
# make sure we import the correct version of save
import tensorflow # noqa: F811
import tensorflow # noqa
# actual import
# noinspection PyUnresolvedReferences
from tensorflow.saved_model import loader as loader1
from tensorflow.saved_model import loader as loader1 # noqa
loader1.load = _patched_call(loader1.load, PatchTensorflowModelIO._load)
except ImportError:
pass
@ -1698,10 +1713,10 @@ class PatchTensorflowModelIO(object):
# noinspection PyBroadException
try:
# make sure we import the correct version of save
import tensorflow # noqa: F811
import tensorflow # noqa
# actual import
# noinspection PyUnresolvedReferences
from tensorflow.compat.v1.saved_model import loader as loader2
from tensorflow.compat.v1.saved_model import loader as loader2 # noqa
loader2.load = _patched_call(loader2.load, PatchTensorflowModelIO._load)
except ImportError:
pass
@ -1710,8 +1725,8 @@ class PatchTensorflowModelIO(object):
# noinspection PyBroadException
try:
import tensorflow # noqa: F401, F811
from tensorflow.train import Checkpoint
import tensorflow # noqa
from tensorflow.train import Checkpoint # noqa
# noinspection PyBroadException
try:
Checkpoint.save = _patched_call(Checkpoint.save, PatchTensorflowModelIO._ckpt_save)
@ -1861,8 +1876,8 @@ class PatchTensorflow2ModelIO(object):
# noinspection PyBroadException
try:
# hack: make sure tensorflow.__init__ is called
import tensorflow # noqa: F401
from tensorflow.python.training.tracking import util
import tensorflow # noqa
from tensorflow.python.training.tracking import util # noqa
# noinspection PyBroadException
try:
util.TrackableSaver.save = _patched_call(util.TrackableSaver.save,

View File

@ -945,6 +945,7 @@ class OutputModel(BaseModel):
config_text = self._resolve_config(config_text=config_text, config_dict=config_dict)
self._model_local_filename = None
self._last_uploaded_url = None
self._base_model = None
# noinspection PyProtectedMember
self._floating_data = create_dummy_model(
@ -1205,6 +1206,8 @@ class OutputModel(BaseModel):
else:
output_uri = None
self._last_uploaded_url = output_uri
if is_package:
self._set_package_tag()
@ -1433,6 +1436,9 @@ class OutputModel(BaseModel):
return True
def _get_last_uploaded_filename(self):
return Path(self._last_uploaded_url or self.url).name
class Waitable(object):
def wait(self, *_, **__):