diff --git a/trains/binding/frameworks/tensorflow_bind.py b/trains/binding/frameworks/tensorflow_bind.py index 21135010..a9b00267 100644 --- a/trains/binding/frameworks/tensorflow_bind.py +++ b/trains/binding/frameworks/tensorflow_bind.py @@ -11,6 +11,7 @@ from typing import Any import numpy as np import six from PIL import Image +from pathlib2 import Path from ...debugging.log import LoggerRoot from ..frameworks import _patched_call, WeightsFileHandler, _Empty @@ -19,7 +20,7 @@ from ...config import running_remotely from ...model import InputModel, OutputModel, Framework try: - from google.protobuf.json_format import MessageToDict + from google.protobuf.json_format import MessageToDict # noqa except ImportError: MessageToDict = None @@ -840,7 +841,7 @@ class PatchSummaryToEventTransformer(object): def _patch_summary_to_event_transformer(): if 'tensorflow' in sys.modules: try: - from tensorflow.python.summary.writer.writer import SummaryToEventTransformer + from tensorflow.python.summary.writer.writer import SummaryToEventTransformer # noqa # only patch once if PatchSummaryToEventTransformer.__original_getattribute is None: PatchSummaryToEventTransformer.__original_getattribute = SummaryToEventTransformer.__getattribute__ @@ -855,7 +856,7 @@ class PatchSummaryToEventTransformer(object): # only patch once if PatchSummaryToEventTransformer._original_add_eventT is None: # noinspection PyUnresolvedReferences - from torch.utils.tensorboard.writer import FileWriter as FileWriterT + from torch.utils.tensorboard.writer import FileWriter as FileWriterT # noqa PatchSummaryToEventTransformer._original_add_eventT = FileWriterT.add_event FileWriterT.add_event = PatchSummaryToEventTransformer._patched_add_eventT setattr(FileWriterT, 'trains', None) @@ -870,7 +871,7 @@ class PatchSummaryToEventTransformer(object): # only patch once if PatchSummaryToEventTransformer.__original_getattributeX is None: # noinspection PyUnresolvedReferences - from tensorboardX.writer import 
SummaryToEventTransformer as SummaryToEventTransformerX + from tensorboardX.writer import SummaryToEventTransformer as SummaryToEventTransformerX # noqa PatchSummaryToEventTransformer.__original_getattributeX = \ SummaryToEventTransformerX.__getattribute__ SummaryToEventTransformerX.__getattribute__ = PatchSummaryToEventTransformer._patched_getattributeX @@ -886,7 +887,7 @@ class PatchSummaryToEventTransformer(object): try: # only patch once if PatchSummaryToEventTransformer._original_add_eventX is None: - from tensorboardX.writer import FileWriter as FileWriterX + from tensorboardX.writer import FileWriter as FileWriterX # noqa PatchSummaryToEventTransformer._original_add_eventX = FileWriterX.add_event FileWriterX.add_event = PatchSummaryToEventTransformer._patched_add_eventX setattr(FileWriterX, 'trains', None) @@ -1041,14 +1042,14 @@ class PatchModelCheckPointCallback(object): callbacks = None if is_keras: try: - import keras.callbacks as callbacks # noqa: F401 + import keras.callbacks as callbacks # noqa except ImportError: is_keras = False if not is_keras and is_tf_keras: try: # hack: make sure tensorflow.__init__ is called - import tensorflow # noqa: F401 - import tensorflow.python.keras.callbacks as callbacks # noqa: F811 + import tensorflow # noqa + import tensorflow.python.keras.callbacks as callbacks # noqa except ImportError: is_tf_keras = False callbacks = None @@ -1129,8 +1130,8 @@ class PatchTensorFlowEager(object): if 'tensorflow' in sys.modules: try: # hack: make sure tensorflow.__init__ is called - import tensorflow # noqa: F401 - from tensorflow.python.ops import gen_summary_ops # noqa: F401 + import tensorflow # noqa + from tensorflow.python.ops import gen_summary_ops # noqa PatchTensorFlowEager.__original_fn_scalar = gen_summary_ops.write_scalar_summary gen_summary_ops.write_scalar_summary = PatchTensorFlowEager._write_scalar_summary PatchTensorFlowEager.__original_fn_image = gen_summary_ops.write_image_summary @@ -1160,12 +1161,12 @@ class 
PatchTensorFlowEager(object): # check if we are in eager mode, let's get the global context lopdir # noinspection PyBroadException try: - from tensorflow.python.eager import context + from tensorflow.python.eager import context # noqa logdir = context.context().summary_writer._init_op_fn.keywords.get('logdir') except Exception: # noinspection PyBroadException try: - from tensorflow.python.ops.summary_ops_v2 import _summary_state + from tensorflow.python.ops.summary_ops_v2 import _summary_state # noqa logdir = _summary_state.writer._init_op_fn.keywords.get('logdir') except Exception: logdir = None @@ -1300,19 +1301,19 @@ class PatchKerasModelIO(object): def _patch_model_checkpoint(): if 'keras' in sys.modules and not PatchKerasModelIO.__patched_keras: try: - from keras.engine.network import Network + from keras.engine.network import Network # noqa except ImportError: Network = None try: - from keras.engine.functional import Functional + from keras.engine.functional import Functional # noqa except ImportError: Functional = None try: - from keras.engine.sequential import Sequential + from keras.engine.sequential import Sequential # noqa except ImportError: Sequential = None try: - from keras import models as keras_saving + from keras import models as keras_saving # noqa except ImportError: keras_saving = None # check that we are not patching anything twice @@ -1329,26 +1330,26 @@ class PatchKerasModelIO(object): if 'tensorflow' in sys.modules and not PatchKerasModelIO.__patched_tensorflow: try: # hack: make sure tensorflow.__init__ is called - import tensorflow # noqa: F401, F811 - from tensorflow.python.keras.engine.network import Network + import tensorflow # noqa + from tensorflow.python.keras.engine.network import Network # noqa except ImportError: Network = None try: # hack: make sure tensorflow.__init__ is called - import tensorflow # noqa: F401, F811 - from tensorflow.python.keras.engine.functional import Functional + import tensorflow # noqa + from 
tensorflow.python.keras.engine.functional import Functional # noqa except ImportError: Functional = None try: # hack: make sure tensorflow.__init__ is called - import tensorflow # noqa: F401, F811 - from tensorflow.python.keras.engine.sequential import Sequential + import tensorflow # noqa + from tensorflow.python.keras.engine.sequential import Sequential # noqa except ImportError: Sequential = None try: # hack: make sure tensorflow.__init__ is called - import tensorflow # noqa: F401, F811 - from tensorflow.python.keras import models as keras_saving + import tensorflow # noqa + from tensorflow.python.keras import models as keras_saving # noqa except ImportError: keras_saving = None @@ -1387,7 +1388,8 @@ class PatchKerasModelIO(object): Network.save_weights = _patched_call(Network.save_weights, PatchKerasModelIO._save_weights) Network.load_weights = _patched_call(Network.load_weights, PatchKerasModelIO._load_weights) elif Functional is not None: - Functional._updated_config = _patched_call(Functional._updated_config, PatchKerasModelIO._updated_config) + Functional._updated_config = _patched_call( + Functional._updated_config, PatchKerasModelIO._updated_config) if hasattr(Sequential.from_config, '__func__'): # noinspection PyUnresolvedReferences Functional.from_config = classmethod(_patched_call(Functional.from_config.__func__, @@ -1414,21 +1416,21 @@ class PatchKerasModelIO(object): try: # check if object already has InputModel if not hasattr(self, 'trains_out_model'): - self.trains_out_model = None + self.trains_out_model = [] # check if object already has InputModel model_name_id = config.get('name', getattr(self, 'name', 'unknown')) - if self.trains_out_model is not None: - self.trains_out_model.config_dict = config + if self.trains_out_model: + self.trains_out_model[-1].config_dict = config else: # todo: support multiple models for the same task - self.trains_out_model = OutputModel( + self.trains_out_model.append(OutputModel( task=PatchKerasModelIO.__main_task, 
config_dict=config, name=PatchKerasModelIO.__main_task.name + ' ' + model_name_id, label_enumeration=PatchKerasModelIO.__main_task.get_labels_enumeration(), framework=Framework.keras, - ) + )) except Exception as ex: LoggerRoot.get_base_logger(TensorflowBinding).warning(str(ex)) @@ -1512,11 +1514,14 @@ class PatchKerasModelIO(object): @staticmethod def _save(original_fn, self, *args, **kwargs): - if hasattr(self, 'trains_out_model'): - self.trains_out_model._processed = False + if hasattr(self, 'trains_out_model') and self.trains_out_model: + # noinspection PyProtectedMember + self.trains_out_model[-1]._processed = False original_fn(self, *args, **kwargs) # no need to specially call, because the original save uses "save_model" which we overload - if not hasattr(self, 'trains_out_model') or not self.trains_out_model._processed: + # noinspection PyProtectedMember + if not hasattr(self, 'trains_out_model') or not self.trains_out_model or \ + not hasattr(self.trains_out_model[-1], '_processed') or not self.trains_out_model[-1]._processed: PatchKerasModelIO._update_outputmodel(self, *args, **kwargs) @staticmethod @@ -1544,28 +1549,38 @@ class PatchKerasModelIO(object): # check if object already has InputModel if not hasattr(self, 'trains_out_model'): - self.trains_out_model = None + self.trains_out_model = [] - # check if object already has InputModel - if self.trains_out_model is not None: - self.trains_out_model.config_dict = config - else: + # check if object already has an OutputModel, and whether it has the same filename + # (note we use Path on the url so the comparison conforms) + matched = None + if self.trains_out_model: + # find the right model + # noinspection PyProtectedMember + matched = [m for m in self.trains_out_model if m._get_last_uploaded_filename() == Path(filepath).name] + if matched: + self.trains_out_model.remove(matched[0]) + self.trains_out_model.append(matched[0]) + self.trains_out_model[-1].config_dict = config + + if not matched: model_name_id = getattr(self,
'name', 'unknown') # todo: support multiple models for the same task - self.trains_out_model = OutputModel( + self.trains_out_model.append(OutputModel( task=PatchKerasModelIO.__main_task, config_dict=config, name=PatchKerasModelIO.__main_task.name + ' ' + model_name_id, label_enumeration=PatchKerasModelIO.__main_task.get_labels_enumeration(), framework=Framework.keras, - ) + )) # check if we have output storage - if self.trains_out_model.upload_storage_uri: - self.trains_out_model.update_weights(weights_filename=filepath, auto_delete_file=False) + if self.trains_out_model[-1].upload_storage_uri: + self.trains_out_model[-1].update_weights(weights_filename=filepath, auto_delete_file=False) else: - self.trains_out_model.update_weights(weights_filename=None, register_uri=filepath) + self.trains_out_model[-1].update_weights(weights_filename=None, register_uri=filepath) # if anyone asks, we were here - self.trains_out_model._processed = True + # noinspection PyProtectedMember + self.trains_out_model[-1]._processed = True except Exception as ex: LoggerRoot.get_base_logger(TensorflowBinding).warning(str(ex)) @@ -1624,9 +1639,9 @@ class PatchTensorflowModelIO(object): # noinspection PyBroadException try: # hack: make sure tensorflow.__init__ is called - import tensorflow + import tensorflow # noqa # noinspection PyUnresolvedReferences - from tensorflow.python.training.saver import Saver + from tensorflow.python.training.saver import Saver # noqa # noinspection PyBroadException try: Saver.save = _patched_call(Saver.save, PatchTensorflowModelIO._save) @@ -1645,18 +1660,18 @@ class PatchTensorflowModelIO(object): # noinspection PyBroadException try: # make sure we import the correct version of save - import tensorflow # noqa: F811 - from tensorflow.saved_model import save + import tensorflow # noqa + from tensorflow.saved_model import save # noqa # actual import - from tensorflow.python.saved_model import save as saved_model + from tensorflow.python.saved_model import save as 
saved_model # noqa except ImportError: # noinspection PyBroadException try: # make sure we import the correct version of save - import tensorflow - from tensorflow.saved_model.experimental import save # noqa: F401 + import tensorflow # noqa + from tensorflow.saved_model.experimental import save # noqa # actual import - import tensorflow.saved_model.experimental as saved_model + import tensorflow.saved_model.experimental as saved_model # noqa except ImportError: saved_model = None except Exception: @@ -1671,11 +1686,11 @@ class PatchTensorflowModelIO(object): # noinspection PyBroadException try: # make sure we import the correct version of save - import tensorflow # noqa: F811 + import tensorflow # noqa # actual import - from tensorflow.saved_model import load # noqa: F401 + from tensorflow.saved_model import load # noqa # noinspection PyUnresolvedReferences - import tensorflow.saved_model as saved_model_load + import tensorflow.saved_model as saved_model_load # noqa saved_model_load.load = _patched_call(saved_model_load.load, PatchTensorflowModelIO._load) except ImportError: pass @@ -1685,10 +1700,10 @@ class PatchTensorflowModelIO(object): # noinspection PyBroadException try: # make sure we import the correct version of save - import tensorflow # noqa: F811 + import tensorflow # noqa # actual import # noinspection PyUnresolvedReferences - from tensorflow.saved_model import loader as loader1 + from tensorflow.saved_model import loader as loader1 # noqa loader1.load = _patched_call(loader1.load, PatchTensorflowModelIO._load) except ImportError: pass @@ -1698,10 +1713,10 @@ class PatchTensorflowModelIO(object): # noinspection PyBroadException try: # make sure we import the correct version of save - import tensorflow # noqa: F811 + import tensorflow # noqa # actual import # noinspection PyUnresolvedReferences - from tensorflow.compat.v1.saved_model import loader as loader2 + from tensorflow.compat.v1.saved_model import loader as loader2 # noqa loader2.load = 
_patched_call(loader2.load, PatchTensorflowModelIO._load) except ImportError: pass @@ -1710,8 +1725,8 @@ class PatchTensorflowModelIO(object): # noinspection PyBroadException try: - import tensorflow # noqa: F401, F811 - from tensorflow.train import Checkpoint + import tensorflow # noqa + from tensorflow.train import Checkpoint # noqa # noinspection PyBroadException try: Checkpoint.save = _patched_call(Checkpoint.save, PatchTensorflowModelIO._ckpt_save) @@ -1861,8 +1876,8 @@ class PatchTensorflow2ModelIO(object): # noinspection PyBroadException try: # hack: make sure tensorflow.__init__ is called - import tensorflow # noqa: F401 - from tensorflow.python.training.tracking import util + import tensorflow # noqa + from tensorflow.python.training.tracking import util # noqa # noinspection PyBroadException try: util.TrackableSaver.save = _patched_call(util.TrackableSaver.save, diff --git a/trains/model.py b/trains/model.py index d774394a..f13064d4 100644 --- a/trains/model.py +++ b/trains/model.py @@ -945,6 +945,7 @@ class OutputModel(BaseModel): config_text = self._resolve_config(config_text=config_text, config_dict=config_dict) self._model_local_filename = None + self._last_uploaded_url = None self._base_model = None # noinspection PyProtectedMember self._floating_data = create_dummy_model( @@ -1205,6 +1206,8 @@ class OutputModel(BaseModel): else: output_uri = None + self._last_uploaded_url = output_uri + if is_package: self._set_package_tag() @@ -1433,6 +1436,9 @@ class OutputModel(BaseModel): return True + def _get_last_uploaded_filename(self): + return Path(self._last_uploaded_url or self.url).name + class Waitable(object): def wait(self, *_, **__):