Mirror of https://github.com/clearml/clearml (synced 2025-05-07 06:14:31 +00:00)
Fix keras reusing model object only if the filename is the same (issue #252)
Commit f3638a7e5d (parent c75de848f5)
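In outline, the patch below turns the per-model `trains_out_model` attribute kept by the Keras bindings from a single OutputModel into a list, and reuses an existing entry only when the filename being saved matches the filename that entry last uploaded; otherwise a new OutputModel is appended. A minimal standalone sketch of that reuse rule, in plain Python rather than the clearml API (`registered`, `make_entry` and `register` are illustrative names, not part of the library):

from pathlib import Path

registered = []  # stands in for self.trains_out_model, which the patch turns into a list


def make_entry(name):
    # stand-in for creating a new OutputModel for this filename
    return {'filename': name}


def register(filepath):
    # reuse the entry whose last uploaded filename matches, otherwise append a new one
    name = Path(filepath).name
    matched = [entry for entry in registered if entry['filename'] == name]
    if matched:
        # move the matched entry to the end, so registered[-1] is always the latest save
        registered.remove(matched[0])
        registered.append(matched[0])
    else:
        registered.append(make_entry(name))
    return registered[-1]


register('/tmp/best_model.h5')    # new entry
register('./best_model.h5')       # same base filename -> the first entry is reused
register('/tmp/last_model.h5')    # different filename -> a second entry is added
print(len(registered))            # 2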
@@ -11,6 +11,7 @@ from typing import Any
 import numpy as np
 import six
 from PIL import Image
+from pathlib2 import Path
 
 from ...debugging.log import LoggerRoot
 from ..frameworks import _patched_call, WeightsFileHandler, _Empty
@@ -19,7 +20,7 @@ from ...config import running_remotely
 from ...model import InputModel, OutputModel, Framework
 
 try:
-    from google.protobuf.json_format import MessageToDict
+    from google.protobuf.json_format import MessageToDict  # noqa
 except ImportError:
     MessageToDict = None
 
@@ -840,7 +841,7 @@ class PatchSummaryToEventTransformer(object):
     def _patch_summary_to_event_transformer():
         if 'tensorflow' in sys.modules:
             try:
-                from tensorflow.python.summary.writer.writer import SummaryToEventTransformer
+                from tensorflow.python.summary.writer.writer import SummaryToEventTransformer  # noqa
                 # only patch once
                 if PatchSummaryToEventTransformer.__original_getattribute is None:
                     PatchSummaryToEventTransformer.__original_getattribute = SummaryToEventTransformer.__getattribute__
@@ -855,7 +856,7 @@ class PatchSummaryToEventTransformer(object):
                 # only patch once
                 if PatchSummaryToEventTransformer._original_add_eventT is None:
                     # noinspection PyUnresolvedReferences
-                    from torch.utils.tensorboard.writer import FileWriter as FileWriterT
+                    from torch.utils.tensorboard.writer import FileWriter as FileWriterT  # noqa
                     PatchSummaryToEventTransformer._original_add_eventT = FileWriterT.add_event
                     FileWriterT.add_event = PatchSummaryToEventTransformer._patched_add_eventT
                     setattr(FileWriterT, 'trains', None)
@@ -870,7 +871,7 @@ class PatchSummaryToEventTransformer(object):
                 # only patch once
                 if PatchSummaryToEventTransformer.__original_getattributeX is None:
                     # noinspection PyUnresolvedReferences
-                    from tensorboardX.writer import SummaryToEventTransformer as SummaryToEventTransformerX
+                    from tensorboardX.writer import SummaryToEventTransformer as SummaryToEventTransformerX  # noqa
                     PatchSummaryToEventTransformer.__original_getattributeX = \
                         SummaryToEventTransformerX.__getattribute__
                     SummaryToEventTransformerX.__getattribute__ = PatchSummaryToEventTransformer._patched_getattributeX
@@ -886,7 +887,7 @@ class PatchSummaryToEventTransformer(object):
             try:
                 # only patch once
                 if PatchSummaryToEventTransformer._original_add_eventX is None:
-                    from tensorboardX.writer import FileWriter as FileWriterX
+                    from tensorboardX.writer import FileWriter as FileWriterX  # noqa
                     PatchSummaryToEventTransformer._original_add_eventX = FileWriterX.add_event
                     FileWriterX.add_event = PatchSummaryToEventTransformer._patched_add_eventX
                     setattr(FileWriterX, 'trains', None)
@@ -1041,14 +1042,14 @@ class PatchModelCheckPointCallback(object):
         callbacks = None
         if is_keras:
             try:
-                import keras.callbacks as callbacks  # noqa: F401
+                import keras.callbacks as callbacks  # noqa
             except ImportError:
                 is_keras = False
         if not is_keras and is_tf_keras:
             try:
                 # hack: make sure tensorflow.__init__ is called
-                import tensorflow  # noqa: F401
-                import tensorflow.python.keras.callbacks as callbacks  # noqa: F811
+                import tensorflow  # noqa
+                import tensorflow.python.keras.callbacks as callbacks  # noqa
             except ImportError:
                 is_tf_keras = False
                 callbacks = None
@@ -1129,8 +1130,8 @@ class PatchTensorFlowEager(object):
         if 'tensorflow' in sys.modules:
             try:
                 # hack: make sure tensorflow.__init__ is called
-                import tensorflow  # noqa: F401
-                from tensorflow.python.ops import gen_summary_ops  # noqa: F401
+                import tensorflow  # noqa
+                from tensorflow.python.ops import gen_summary_ops  # noqa
                 PatchTensorFlowEager.__original_fn_scalar = gen_summary_ops.write_scalar_summary
                 gen_summary_ops.write_scalar_summary = PatchTensorFlowEager._write_scalar_summary
                 PatchTensorFlowEager.__original_fn_image = gen_summary_ops.write_image_summary
@@ -1160,12 +1161,12 @@ class PatchTensorFlowEager(object):
         # check if we are in eager mode, let's get the global context lopdir
         # noinspection PyBroadException
         try:
-            from tensorflow.python.eager import context
+            from tensorflow.python.eager import context  # noqa
             logdir = context.context().summary_writer._init_op_fn.keywords.get('logdir')
         except Exception:
             # noinspection PyBroadException
             try:
-                from tensorflow.python.ops.summary_ops_v2 import _summary_state
+                from tensorflow.python.ops.summary_ops_v2 import _summary_state  # noqa
                 logdir = _summary_state.writer._init_op_fn.keywords.get('logdir')
             except Exception:
                 logdir = None
@@ -1300,19 +1301,19 @@ class PatchKerasModelIO(object):
     def _patch_model_checkpoint():
         if 'keras' in sys.modules and not PatchKerasModelIO.__patched_keras:
             try:
-                from keras.engine.network import Network
+                from keras.engine.network import Network  # noqa
             except ImportError:
                 Network = None
             try:
-                from keras.engine.functional import Functional
+                from keras.engine.functional import Functional  # noqa
             except ImportError:
                 Functional = None
             try:
-                from keras.engine.sequential import Sequential
+                from keras.engine.sequential import Sequential  # noqa
             except ImportError:
                 Sequential = None
             try:
-                from keras import models as keras_saving
+                from keras import models as keras_saving  # noqa
             except ImportError:
                 keras_saving = None
             # check that we are not patching anything twice
@@ -1329,26 +1330,26 @@ class PatchKerasModelIO(object):
         if 'tensorflow' in sys.modules and not PatchKerasModelIO.__patched_tensorflow:
             try:
                 # hack: make sure tensorflow.__init__ is called
-                import tensorflow  # noqa: F401, F811
-                from tensorflow.python.keras.engine.network import Network
+                import tensorflow  # noqa
+                from tensorflow.python.keras.engine.network import Network  # noqa
             except ImportError:
                 Network = None
             try:
                 # hack: make sure tensorflow.__init__ is called
-                import tensorflow  # noqa: F401, F811
-                from tensorflow.python.keras.engine.functional import Functional
+                import tensorflow  # noqa
+                from tensorflow.python.keras.engine.functional import Functional  # noqa
             except ImportError:
                 Functional = None
             try:
                 # hack: make sure tensorflow.__init__ is called
-                import tensorflow  # noqa: F401, F811
-                from tensorflow.python.keras.engine.sequential import Sequential
+                import tensorflow  # noqa
+                from tensorflow.python.keras.engine.sequential import Sequential  # noqa
             except ImportError:
                 Sequential = None
             try:
                 # hack: make sure tensorflow.__init__ is called
-                import tensorflow  # noqa: F401, F811
-                from tensorflow.python.keras import models as keras_saving
+                import tensorflow  # noqa
+                from tensorflow.python.keras import models as keras_saving  # noqa
             except ImportError:
                 keras_saving = None
 
@@ -1387,7 +1388,8 @@ class PatchKerasModelIO(object):
             Network.save_weights = _patched_call(Network.save_weights, PatchKerasModelIO._save_weights)
             Network.load_weights = _patched_call(Network.load_weights, PatchKerasModelIO._load_weights)
         elif Functional is not None:
-            Functional._updated_config = _patched_call(Functional._updated_config, PatchKerasModelIO._updated_config)
+            Functional._updated_config = _patched_call(
+                Functional._updated_config, PatchKerasModelIO._updated_config)
             if hasattr(Sequential.from_config, '__func__'):
                 # noinspection PyUnresolvedReferences
                 Functional.from_config = classmethod(_patched_call(Functional.from_config.__func__,
@@ -1414,21 +1416,21 @@ class PatchKerasModelIO(object):
         try:
             # check if object already has InputModel
             if not hasattr(self, 'trains_out_model'):
-                self.trains_out_model = None
+                self.trains_out_model = []
 
             # check if object already has InputModel
             model_name_id = config.get('name', getattr(self, 'name', 'unknown'))
-            if self.trains_out_model is not None:
-                self.trains_out_model.config_dict = config
+            if self.trains_out_model:
+                self.trains_out_model[-1].config_dict = config
             else:
                 # todo: support multiple models for the same task
-                self.trains_out_model = OutputModel(
+                self.trains_out_model.append(OutputModel(
                     task=PatchKerasModelIO.__main_task,
                     config_dict=config,
                     name=PatchKerasModelIO.__main_task.name + ' ' + model_name_id,
                     label_enumeration=PatchKerasModelIO.__main_task.get_labels_enumeration(),
                     framework=Framework.keras,
-                )
+                ))
         except Exception as ex:
             LoggerRoot.get_base_logger(TensorflowBinding).warning(str(ex))
 
@@ -1512,11 +1514,14 @@ class PatchKerasModelIO(object):
 
     @staticmethod
    def _save(original_fn, self, *args, **kwargs):
-        if hasattr(self, 'trains_out_model'):
-            self.trains_out_model._processed = False
+        if hasattr(self, 'trains_out_model') and self.trains_out_model:
+            # noinspection PyProtectedMember
+            self.trains_out_model[-1]._processed = False
         original_fn(self, *args, **kwargs)
         # no need to specially call, because the original save uses "save_model" which we overload
-        if not hasattr(self, 'trains_out_model') or not self.trains_out_model._processed:
+        # noinspection PyProtectedMember
+        if not hasattr(self, 'trains_out_model') or not self.trains_out_model or \
+                not hasattr(self.trains_out_model[-1], '_processed') or not self.trains_out_model[-1]._processed:
             PatchKerasModelIO._update_outputmodel(self, *args, **kwargs)
 
     @staticmethod
@@ -1544,28 +1549,38 @@ class PatchKerasModelIO(object):
 
             # check if object already has InputModel
             if not hasattr(self, 'trains_out_model'):
-                self.trains_out_model = None
+                self.trains_out_model = []
 
-            # check if object already has InputModel
-            if self.trains_out_model is not None:
-                self.trains_out_model.config_dict = config
-            else:
+            # check if object already has an InputModel with the same filename
+            # (note: we use Path on the url so the filenames can be compared)
+            matched = None
+            if self.trains_out_model:
+                # find the right model
+                # noinspection PyProtectedMember
+                matched = [m for m in self.trains_out_model if m._get_last_uploaded_filename() == Path(filepath).name]
+                if matched:
+                    self.trains_out_model.remove(matched[0])
+                    self.trains_out_model.append(matched[0])
+                    self.trains_out_model[-1].config_dict = config
+
+            if not matched:
                 model_name_id = getattr(self, 'name', 'unknown')
                 # todo: support multiple models for the same task
-                self.trains_out_model = OutputModel(
+                self.trains_out_model.append(OutputModel(
                     task=PatchKerasModelIO.__main_task,
                     config_dict=config,
                     name=PatchKerasModelIO.__main_task.name + ' ' + model_name_id,
                     label_enumeration=PatchKerasModelIO.__main_task.get_labels_enumeration(),
                     framework=Framework.keras,
-                )
+                ))
             # check if we have output storage
-            if self.trains_out_model.upload_storage_uri:
-                self.trains_out_model.update_weights(weights_filename=filepath, auto_delete_file=False)
+            if self.trains_out_model[-1].upload_storage_uri:
+                self.trains_out_model[-1].update_weights(weights_filename=filepath, auto_delete_file=False)
             else:
-                self.trains_out_model.update_weights(weights_filename=None, register_uri=filepath)
+                self.trains_out_model[-1].update_weights(weights_filename=None, register_uri=filepath)
             # if anyone asks, we were here
-            self.trains_out_model._processed = True
+            # noinspection PyProtectedMember
+            self.trains_out_model[-1]._processed = True
         except Exception as ex:
             LoggerRoot.get_base_logger(TensorflowBinding).warning(str(ex))
 
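The matching condition above compares only base filenames: `_get_last_uploaded_filename()` (added to OutputModel further down) reduces the last uploaded URL, or the registered url, to its final path component, and the saved `filepath` is reduced the same way with `Path(filepath).name`. A small illustration of why that comparison still matches once the weights have been uploaded to remote storage (the URL and local path are made up, and standard `pathlib` is used here in place of the `pathlib2` backport the patch imports):

from pathlib import Path

uploaded_url = 'https://files.example.com/examples/task_abc123/models/best_model.h5'  # hypothetical upload URL
local_filepath = '/tmp/checkpoints/best_model.h5'                                     # hypothetical local save path

print(Path(uploaded_url).name)                                 # best_model.h5
print(Path(local_filepath).name)                               # best_model.h5
print(Path(uploaded_url).name == Path(local_filepath).name)    # True -> reuse the existing OutputModel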
@@ -1624,9 +1639,9 @@ class PatchTensorflowModelIO(object):
         # noinspection PyBroadException
         try:
             # hack: make sure tensorflow.__init__ is called
-            import tensorflow
+            import tensorflow  # noqa
             # noinspection PyUnresolvedReferences
-            from tensorflow.python.training.saver import Saver
+            from tensorflow.python.training.saver import Saver  # noqa
             # noinspection PyBroadException
             try:
                 Saver.save = _patched_call(Saver.save, PatchTensorflowModelIO._save)
@@ -1645,18 +1660,18 @@ class PatchTensorflowModelIO(object):
         # noinspection PyBroadException
         try:
             # make sure we import the correct version of save
-            import tensorflow  # noqa: F811
-            from tensorflow.saved_model import save
+            import tensorflow  # noqa
+            from tensorflow.saved_model import save  # noqa
             # actual import
-            from tensorflow.python.saved_model import save as saved_model
+            from tensorflow.python.saved_model import save as saved_model  # noqa
         except ImportError:
             # noinspection PyBroadException
             try:
                 # make sure we import the correct version of save
-                import tensorflow
-                from tensorflow.saved_model.experimental import save  # noqa: F401
+                import tensorflow  # noqa
+                from tensorflow.saved_model.experimental import save  # noqa
                 # actual import
-                import tensorflow.saved_model.experimental as saved_model
+                import tensorflow.saved_model.experimental as saved_model  # noqa
             except ImportError:
                 saved_model = None
         except Exception:
@@ -1671,11 +1686,11 @@ class PatchTensorflowModelIO(object):
         # noinspection PyBroadException
         try:
             # make sure we import the correct version of save
-            import tensorflow  # noqa: F811
+            import tensorflow  # noqa
             # actual import
-            from tensorflow.saved_model import load  # noqa: F401
+            from tensorflow.saved_model import load  # noqa
             # noinspection PyUnresolvedReferences
-            import tensorflow.saved_model as saved_model_load
+            import tensorflow.saved_model as saved_model_load  # noqa
             saved_model_load.load = _patched_call(saved_model_load.load, PatchTensorflowModelIO._load)
         except ImportError:
             pass
@@ -1685,10 +1700,10 @@ class PatchTensorflowModelIO(object):
         # noinspection PyBroadException
         try:
             # make sure we import the correct version of save
-            import tensorflow  # noqa: F811
+            import tensorflow  # noqa
             # actual import
             # noinspection PyUnresolvedReferences
-            from tensorflow.saved_model import loader as loader1
+            from tensorflow.saved_model import loader as loader1  # noqa
             loader1.load = _patched_call(loader1.load, PatchTensorflowModelIO._load)
         except ImportError:
             pass
@@ -1698,10 +1713,10 @@ class PatchTensorflowModelIO(object):
         # noinspection PyBroadException
         try:
             # make sure we import the correct version of save
-            import tensorflow  # noqa: F811
+            import tensorflow  # noqa
             # actual import
             # noinspection PyUnresolvedReferences
-            from tensorflow.compat.v1.saved_model import loader as loader2
+            from tensorflow.compat.v1.saved_model import loader as loader2  # noqa
             loader2.load = _patched_call(loader2.load, PatchTensorflowModelIO._load)
         except ImportError:
             pass
@@ -1710,8 +1725,8 @@ class PatchTensorflowModelIO(object):
 
         # noinspection PyBroadException
         try:
-            import tensorflow  # noqa: F401, F811
-            from tensorflow.train import Checkpoint
+            import tensorflow  # noqa
+            from tensorflow.train import Checkpoint  # noqa
             # noinspection PyBroadException
             try:
                 Checkpoint.save = _patched_call(Checkpoint.save, PatchTensorflowModelIO._ckpt_save)
@@ -1861,8 +1876,8 @@ class PatchTensorflow2ModelIO(object):
         # noinspection PyBroadException
         try:
             # hack: make sure tensorflow.__init__ is called
-            import tensorflow  # noqa: F401
-            from tensorflow.python.training.tracking import util
+            import tensorflow  # noqa
+            from tensorflow.python.training.tracking import util  # noqa
             # noinspection PyBroadException
             try:
                 util.TrackableSaver.save = _patched_call(util.TrackableSaver.save,
@@ -945,6 +945,7 @@ class OutputModel(BaseModel):
         config_text = self._resolve_config(config_text=config_text, config_dict=config_dict)
 
         self._model_local_filename = None
+        self._last_uploaded_url = None
         self._base_model = None
         # noinspection PyProtectedMember
         self._floating_data = create_dummy_model(
@@ -1205,6 +1206,8 @@ class OutputModel(BaseModel):
         else:
             output_uri = None
 
+        self._last_uploaded_url = output_uri
+
         if is_package:
             self._set_package_tag()
 
@@ -1433,6 +1436,9 @@ class OutputModel(BaseModel):
 
         return True
 
+    def _get_last_uploaded_filename(self):
+        return Path(self._last_uploaded_url or self.url).name
+
 
 class Waitable(object):
     def wait(self, *_, **__):
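If this reading of the diff is right, the user-visible effect is that repeated Keras saves to the same filename keep updating one registered model, while a save under a new filename now registers an additional model instead of overwriting the first. A rough usage sketch under that assumption (project and task names are arbitrary; it needs clearml and tensorflow installed, and the auto-binding enabled by Task.init):

from clearml import Task
from tensorflow import keras

task = Task.init(project_name='examples', task_name='keras filename reuse demo')

# tiny model, just enough to have something to save
model = keras.Sequential([keras.layers.Dense(1, input_shape=(4,))])
model.compile(optimizer='sgd', loss='mse')

model.save('best_model.h5')   # registers an OutputModel on the task
model.save('best_model.h5')   # same filename -> the existing OutputModel is updated
model.save('last_model.h5')   # different filename -> a second OutputModel is created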