diff --git a/trains/binding/environ_bind.py b/trains/binding/environ_bind.py index 32faf8d1..b035d609 100644 --- a/trains/binding/environ_bind.py +++ b/trains/binding/environ_bind.py @@ -66,6 +66,8 @@ class PatchOsFork(object): task = Task.init() task.get_logger().flush() + # Hack: now make sure we setup the reporter thread + task._setup_reporter() # if we got here patch the os._exit of our instance to call us def _at_exit_callback(*args, **kwargs): # call at exit manually diff --git a/trains/binding/frameworks/__init__.py b/trains/binding/frameworks/__init__.py index ea6b887c..b7937ea8 100644 --- a/trains/binding/frameworks/__init__.py +++ b/trains/binding/frameworks/__init__.py @@ -7,6 +7,7 @@ from pathlib2 import Path from ...config import running_remotely from ...model import InputModel, OutputModel +from ...backend_interface.model import Model TrainsFrameworkAdapter = 'TrainsFrameworkAdapter' _recursion_guard = {} @@ -75,15 +76,28 @@ class WeightsFileHandler(object): config_text = trains_in_model.config_text if trains_in_model else None except Exception: config_text = None - trains_in_model = InputModel.import_model( - weights_url=filepath, - config_dict=config_dict, - config_text=config_text, - name=task.name + ' ' + model_name_id, - label_enumeration=task.get_labels_enumeration(), - framework=framework, - create_as_published=False, - ) + + # check if we already have the model object: + model_id, model_uri = Model._local_model_to_id_uri.get(filepath, (None, None)) + if model_id: + # noinspection PyBroadException + try: + trains_in_model = InputModel(model_id) + except Exception: + model_id = None + + # if we do not, we need to import the model + if not model_id: + trains_in_model = InputModel.import_model( + weights_url=filepath, + config_dict=config_dict, + config_text=config_text, + name=task.name + ' ' + model_name_id, + label_enumeration=task.get_labels_enumeration(), + framework=framework, + create_as_published=False, + ) + # noinspection 
PyBroadException try: ref_model = weakref.ref(model) @@ -94,7 +108,8 @@ class WeightsFileHandler(object): task.connect(trains_in_model) # if we are running remotely we should deserialize the object # because someone might have changed the config_dict - if running_remotely(): + # Hack: disabled + if False and running_remotely(): # reload the model model_config = trains_in_model.config_dict # verify that this is the same model so we are not deserializing a diff model diff --git a/trains/binding/frameworks/pytorch_bind.py b/trains/binding/frameworks/pytorch_bind.py index a3354f0f..9a80d4ed 100644 --- a/trains/binding/frameworks/pytorch_bind.py +++ b/trains/binding/frameworks/pytorch_bind.py @@ -47,16 +47,22 @@ class PatchPyTorchModelIO(PatchBaseModelIO): if not PatchPyTorchModelIO.__main_task: return ret - if isinstance(f, six.string_types): - filename = f - elif hasattr(f, 'name'): - filename = f.name - # noinspection PyBroadException - try: - f.flush() - except Exception: - pass - else: + # noinspection PyBroadException + try: + if isinstance(f, six.string_types): + filename = f + elif hasattr(f, 'as_posix'): + filename = f.as_posix() + elif hasattr(f, 'name'): + filename = f.name + # noinspection PyBroadException + try: + f.flush() + except Exception: + pass + else: + filename = None + except Exception: filename = None # give the model a descriptive name based on the file name @@ -65,31 +71,40 @@ class PatchPyTorchModelIO(PatchBaseModelIO): model_name = Path(filename).stem except Exception: model_name = None + WeightsFileHandler.create_output_model(obj, filename, Framework.pytorch, PatchPyTorchModelIO.__main_task, singlefile=True, model_name=model_name) + return ret @staticmethod def _load(original_fn, f, *args, **kwargs): - if isinstance(f, six.string_types): - filename = f - elif hasattr(f, 'name'): - filename = f.name - else: - filename = None - if not PatchPyTorchModelIO.__main_task: return original_fn(f, *args, **kwargs) + # noinspection PyBroadException + 
try: + if isinstance(f, six.string_types): + filename = f + elif hasattr(f, 'as_posix'): + filename = f.as_posix() + elif hasattr(f, 'name'): + filename = f.name + else: + filename = None + except Exception: + filename = None + # register input model empty = _Empty() - if running_remotely(): + # Hack: disabled + if False and running_remotely(): filename = WeightsFileHandler.restore_weights_file(empty, filename, Framework.pytorch, PatchPyTorchModelIO.__main_task) model = original_fn(filename or f, *args, **kwargs) else: # try to load model before registering, in case we fail - model = original_fn(filename or f, *args, **kwargs) + model = original_fn(f, *args, **kwargs) WeightsFileHandler.restore_weights_file(empty, filename, Framework.pytorch, PatchPyTorchModelIO.__main_task) diff --git a/trains/binding/frameworks/tensorflow_bind.py b/trains/binding/frameworks/tensorflow_bind.py index d0a0cccc..1abf13be 100644 --- a/trains/binding/frameworks/tensorflow_bind.py +++ b/trains/binding/frameworks/tensorflow_bind.py @@ -63,7 +63,7 @@ class EventTrainsWriter(object): return self.variants.copy() def tag_splitter(self, tag, num_split_parts, split_char='/', join_char='_', default_title='variant', - logdir_header='series'): + logdir_header='series', auto_reduce_num_split=False): """ Split a tf.summary tag line to variant and metric. Variant is the first part of the split tag, metric is the second. @@ -74,9 +74,13 @@ class EventTrainsWriter(object): :param str default_title: variant to use in case no variant can be inferred automatically :param str logdir_header: if 'series_last' then series=header: series, if 'series then series=series :header, if 'title_last' then title=header title, if 'title' then title=title header + :param boolean auto_reduce_num_split: if True and the tag is split for less parts then requested, + then requested number of split parts is adjusted. 
:return: (str, str) variant and metric """ splitted_tag = tag.split(split_char) + if auto_reduce_num_split and num_split_parts > len(splitted_tag)-1: + num_split_parts = max(1, len(splitted_tag)-1) series = join_char.join(splitted_tag[-num_split_parts:]) title = join_char.join(splitted_tag[:-num_split_parts]) or default_title @@ -195,7 +199,8 @@ class EventTrainsWriter(object): if img_data_np is None: return - title, series = self.tag_splitter(tag, num_split_parts=3, default_title='Images', logdir_header='title') + title, series = self.tag_splitter(tag, num_split_parts=3, default_title='Images', logdir_header='title', + auto_reduce_num_split=True) if img_data_np.dtype != np.uint8: # assume scale 0-1 img_data_np = (img_data_np * 255).astype(np.uint8) @@ -998,11 +1003,19 @@ class PatchKerasModelIO(object): if Sequential is not None: Sequential._updated_config = _patched_call(Sequential._updated_config, PatchKerasModelIO._updated_config) - Sequential.from_config = _patched_call(Sequential.from_config, PatchKerasModelIO._from_config) + if hasattr(Sequential.from_config, '__func__'): + Sequential.from_config.__func__ = _patched_call(Sequential.from_config.__func__, + PatchKerasModelIO._from_config) + else: + Sequential.from_config = _patched_call(Sequential.from_config, PatchKerasModelIO._from_config) if Network is not None: Network._updated_config = _patched_call(Network._updated_config, PatchKerasModelIO._updated_config) - Network.from_config = _patched_call(Network.from_config, PatchKerasModelIO._from_config) + if hasattr(Network.from_config, '__func__'): + Network.from_config.__func__ = _patched_call(Network.from_config.__func__, + PatchKerasModelIO._from_config) + else: + Network.from_config = _patched_call(Network.from_config, PatchKerasModelIO._from_config) Network.save = _patched_call(Network.save, PatchKerasModelIO._save) Network.save_weights = _patched_call(Network.save_weights, PatchKerasModelIO._save_weights) Network.load_weights = 
_patched_call(Network.load_weights, PatchKerasModelIO._load_weights) @@ -1072,7 +1085,8 @@ class PatchKerasModelIO(object): PatchKerasModelIO.__main_task.connect(self.trains_in_model) # if we are running remotely we should deserialize the object # because someone might have changed the configuration - if running_remotely(): + # Hack: disabled + if False and running_remotely(): # reload the model model_config = self.trains_in_model.config_dict # verify that this is the same model so we are not deserializing a diff model @@ -1100,7 +1114,8 @@ class PatchKerasModelIO(object): # get filepath filepath = kwargs['filepath'] if 'filepath' in kwargs else args[0] - if running_remotely(): + # Hack: disabled + if False and running_remotely(): # register/load model weights filepath = WeightsFileHandler.restore_weights_file(self, filepath, Framework.keras, PatchKerasModelIO.__main_task) @@ -1183,7 +1198,8 @@ class PatchKerasModelIO(object): return original_fn(filepath, *args, **kwargs) empty = _Empty() - if running_remotely(): + # Hack: disabled + if False and running_remotely(): # register/load model weights filepath = WeightsFileHandler.restore_weights_file(empty, filepath, Framework.keras, PatchKerasModelIO.__main_task) @@ -1351,7 +1367,8 @@ class PatchTensorflowModelIO(object): if PatchTensorflowModelIO.__main_task is None: return original_fn(self, sess, save_path, *args, **kwargs) - if running_remotely(): + # Hack: disabled + if False and running_remotely(): # register/load model weights save_path = WeightsFileHandler.restore_weights_file(self, save_path, Framework.tensorflow, PatchTensorflowModelIO.__main_task) @@ -1372,7 +1389,8 @@ class PatchTensorflowModelIO(object): # register input model empty = _Empty() - if running_remotely(): + # Hack: disabled + if False and running_remotely(): export_dir = WeightsFileHandler.restore_weights_file(empty, export_dir, Framework.tensorflow, PatchTensorflowModelIO.__main_task) model = original_fn(sess, tags, export_dir, *args, 
**saver_kwargs) @@ -1415,7 +1433,8 @@ class PatchTensorflowModelIO(object): # register input model empty = _Empty() - if running_remotely(): + # Hack: disabled + if False and running_remotely(): save_path = WeightsFileHandler.restore_weights_file(empty, save_path, Framework.tensorflow, PatchTensorflowModelIO.__main_task) model = original_fn(self, save_path, *args, **kwargs) diff --git a/trains/binding/frameworks/xgboost_bind.py b/trains/binding/frameworks/xgboost_bind.py index 03b3321b..a7f14c25 100644 --- a/trains/binding/frameworks/xgboost_bind.py +++ b/trains/binding/frameworks/xgboost_bind.py @@ -82,7 +82,8 @@ class PatchXGBoostModelIO(PatchBaseModelIO): # register input model empty = _Empty() - if running_remotely(): + # Hack: disabled + if False and running_remotely(): filename = WeightsFileHandler.restore_weights_file(empty, filename, Framework.xgboost, PatchXGBoostModelIO.__main_task) model = original_fn(filename or f, *args, **kwargs) diff --git a/trains/binding/joblib_bind.py b/trains/binding/joblib_bind.py index 38e42b26..80dd0ac2 100644 --- a/trains/binding/joblib_bind.py +++ b/trains/binding/joblib_bind.py @@ -90,7 +90,8 @@ class PatchedJoblib(object): # register input model empty = _Empty() - if running_remotely(): + # Hack: disabled + if False and running_remotely(): # we assume scikit-learn, for the time being current_framework = Framework.scikitlearn filename = WeightsFileHandler.restore_weights_file(empty, filename, current_framework,