From a77b470500a9d360586e9159725c80915aae5412 Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Thu, 20 Jun 2019 01:50:40 +0300 Subject: [PATCH] Refactored binding, better support for matplotlib jupyter binding --- trains/{utilities => binding}/absl_bind.py | 2 +- trains/binding/frameworks/__init__.py | 178 +++++++++++ trains/binding/frameworks/pytorch_bind.py | 101 ++++++ .../frameworks/tensorflow_bind.py} | 295 +++--------------- trains/binding/import_bind.py | 73 +++++ .../{utilities => binding}/matplotlib_bind.py | 35 ++- trains/task.py | 9 +- 7 files changed, 428 insertions(+), 265 deletions(-) rename trains/{utilities => binding}/absl_bind.py (98%) create mode 100644 trains/binding/frameworks/__init__.py create mode 100644 trains/binding/frameworks/pytorch_bind.py rename trains/{utilities/frameworks.py => binding/frameworks/tensorflow_bind.py} (85%) create mode 100644 trains/binding/import_bind.py rename trains/{utilities => binding}/matplotlib_bind.py (85%) diff --git a/trains/utilities/absl_bind.py b/trains/binding/absl_bind.py similarity index 98% rename from trains/utilities/absl_bind.py rename to trains/binding/absl_bind.py index dc191345..c94791af 100644 --- a/trains/utilities/absl_bind.py +++ b/trains/binding/absl_bind.py @@ -1,5 +1,5 @@ """ absl-py FLAGS binding utility functions """ -from trains.backend_interface.task.args import _Arguments +from ..backend_interface.task.args import _Arguments from ..config import running_remotely diff --git a/trains/binding/frameworks/__init__.py b/trains/binding/frameworks/__init__.py new file mode 100644 index 00000000..ea6b887c --- /dev/null +++ b/trains/binding/frameworks/__init__.py @@ -0,0 +1,178 @@ +import threading +import weakref +from logging import getLogger + +import six +from pathlib2 import Path + +from ...config import running_remotely +from ...model import InputModel, OutputModel + +TrainsFrameworkAdapter = 'TrainsFrameworkAdapter' +_recursion_guard = {} + + +def _patched_call(original_fn, patched_fn): + def _inner_patch(*args, **kwargs): + ident = threading._get_ident() if six.PY2 else threading.get_ident() + if ident in _recursion_guard: + return original_fn(*args, **kwargs) + _recursion_guard[ident] = 1 + ret = None + try: + ret = patched_fn(original_fn, *args, **kwargs) + except Exception as ex: + raise ex + finally: + try: + _recursion_guard.pop(ident) + except KeyError: + pass + return ret + + return _inner_patch + + +class _Empty(object): + def __init__(self): + self.trains_in_model = None + + +class WeightsFileHandler(object): + _model_out_store_lookup = {} + _model_in_store_lookup = {} + _model_store_lookup_lock = threading.Lock() + + @staticmethod + def restore_weights_file(model, filepath, framework, task): + if task is None: + return filepath + + if not filepath: + getLogger(TrainsFrameworkAdapter).warning("Could retrieve model location, model not restored") + return filepath + + try: + WeightsFileHandler._model_store_lookup_lock.acquire() + + # check if object already has InputModel + trains_in_model, ref_model = WeightsFileHandler._model_in_store_lookup.get(id(model), (None, None)) + if ref_model is not None and model != ref_model(): + # old id pop it - it was probably reused because the object is dead + WeightsFileHandler._model_in_store_lookup.pop(id(model)) + trains_in_model, ref_model = None, None + + # check if object already has InputModel + model_name_id = getattr(model, 'name', '') + # noinspection PyBroadException + try: + config_text = None + config_dict = trains_in_model.config_dict if 
trains_in_model else None + except Exception: + config_dict = None + # noinspection PyBroadException + try: + config_text = trains_in_model.config_text if trains_in_model else None + except Exception: + config_text = None + trains_in_model = InputModel.import_model( + weights_url=filepath, + config_dict=config_dict, + config_text=config_text, + name=task.name + ' ' + model_name_id, + label_enumeration=task.get_labels_enumeration(), + framework=framework, + create_as_published=False, + ) + # noinspection PyBroadException + try: + ref_model = weakref.ref(model) + except Exception: + ref_model = None + WeightsFileHandler._model_in_store_lookup[id(model)] = (trains_in_model, ref_model) + # todo: support multiple models for the same task + task.connect(trains_in_model) + # if we are running remotely we should deserialize the object + # because someone might have changed the config_dict + if running_remotely(): + # reload the model + model_config = trains_in_model.config_dict + # verify that this is the same model so we are not deserializing a diff model + if (config_dict and config_dict.get('config') and model_config and model_config.get('config') and + config_dict.get('config').get('name') == model_config.get('config').get('name')) or \ + (not config_dict and not model_config): + filepath = trains_in_model.get_weights() + # update filepath to point to downloaded weights file + # actual model weights loading will be done outside the try/exception block + except Exception as ex: + getLogger(TrainsFrameworkAdapter).warning(str(ex)) + finally: + WeightsFileHandler._model_store_lookup_lock.release() + + return filepath + + @staticmethod + def create_output_model(model, saved_path, framework, task, singlefile=False, model_name=None): + if task is None: + return saved_path + + try: + WeightsFileHandler._model_store_lookup_lock.acquire() + + # check if object already has InputModel + trains_out_model, ref_model = WeightsFileHandler._model_out_store_lookup.get(id(model), (None, None)) + if ref_model is not None and model != ref_model(): + # old id pop it - it was probably reused because the object is dead + WeightsFileHandler._model_out_store_lookup.pop(id(model)) + trains_out_model, ref_model = None, None + + # check if object already has InputModel + if trains_out_model is None: + trains_out_model = OutputModel( + task=task, + # config_dict=config, + name=(task.name + ' - ' + model_name) if model_name else None, + label_enumeration=task.get_labels_enumeration(), + framework=framework, ) + # noinspection PyBroadException + try: + ref_model = weakref.ref(model) + except Exception: + ref_model = None + WeightsFileHandler._model_out_store_lookup[id(model)] = (trains_out_model, ref_model) + + if not saved_path: + getLogger(TrainsFrameworkAdapter).warning("Could retrieve model location, stored as unknown ") + return saved_path + + # check if we have output storage, and generate list of files to upload + if trains_out_model.upload_storage_uri: + if Path(saved_path).is_dir(): + files = [str(f) for f in Path(saved_path).rglob('*') if f.is_file()] + elif singlefile: + files = [str(Path(saved_path).absolute())] + else: + files = [str(f) for f in Path(saved_path).parent.glob(str(Path(saved_path).name) + '.*')] + else: + files = None + + # upload files if we found them, or just register the original path + if files: + if len(files) > 1: + # noinspection PyBroadException + try: + target_filename = Path(saved_path).stem + except Exception: + target_filename = None + 
trains_out_model.update_weights_package(weights_filenames=files, auto_delete_file=False, + target_filename=target_filename) + else: + trains_out_model.update_weights(weights_filename=files[0], auto_delete_file=False) + else: + trains_out_model.update_weights(weights_filename=None, register_uri=saved_path) + except Exception as ex: + getLogger(TrainsFrameworkAdapter).warning(str(ex)) + finally: + WeightsFileHandler._model_store_lookup_lock.release() + + return saved_path diff --git a/trains/binding/frameworks/pytorch_bind.py b/trains/binding/frameworks/pytorch_bind.py new file mode 100644 index 00000000..c382e2ab --- /dev/null +++ b/trains/binding/frameworks/pytorch_bind.py @@ -0,0 +1,101 @@ +import sys + +import six +from pathlib2 import Path + +from ..frameworks import _patched_call, WeightsFileHandler, _Empty +from ..import_bind import PostImportHookPatching +from ...config import running_remotely +from ...model import Framework + + +class PatchPyTorchModelIO(object): + __main_task = None + __patched = None + + @staticmethod + def update_current_task(task, **_): + PatchPyTorchModelIO.__main_task = task + PatchPyTorchModelIO._patch_model_io() + PostImportHookPatching.add_on_import('torch', PatchPyTorchModelIO._patch_model_io) + + @staticmethod + def _patch_model_io(): + if PatchPyTorchModelIO.__patched: + return + + if 'torch' not in sys.modules: + return + + PatchPyTorchModelIO.__patched = True + + # noinspection PyBroadException + try: + # hack: make sure tensorflow.__init__ is called + import torch + torch.save = _patched_call(torch.save, PatchPyTorchModelIO._save) + torch.load = _patched_call(torch.load, PatchPyTorchModelIO._load) + except ImportError: + pass + except Exception: + pass # print('Failed patching pytorch') + + @staticmethod + def _save(original_fn, obj, f, *args, **kwargs): + ret = original_fn(obj, f, *args, **kwargs) + if not PatchPyTorchModelIO.__main_task: + return ret + + if isinstance(f, six.string_types): + filename = f + elif hasattr(f, 'name'): + filename = f.name + # noinspection PyBroadException + try: + f.flush() + except Exception: + pass + else: + filename = None + + # give the model a descriptive name based on the file name + # noinspection PyBroadException + try: + model_name = Path(filename).stem + except Exception: + model_name = None + WeightsFileHandler.create_output_model(obj, filename, Framework.pytorch, PatchPyTorchModelIO.__main_task, + singlefile=True, model_name=model_name) + return ret + + @staticmethod + def _load(original_fn, f, *args, **kwargs): + if isinstance(f, six.string_types): + filename = f + elif hasattr(f, 'name'): + filename = f.name + else: + filename = None + + if not PatchPyTorchModelIO.__main_task: + return original_fn(f, *args, **kwargs) + + # register input model + empty = _Empty() + if running_remotely(): + filename = WeightsFileHandler.restore_weights_file(empty, filename, Framework.pytorch, + PatchPyTorchModelIO.__main_task) + model = original_fn(filename or f, *args, **kwargs) + else: + # try to load model before registering, in case we fail + model = original_fn(filename or f, *args, **kwargs) + WeightsFileHandler.restore_weights_file(empty, filename, Framework.pytorch, + PatchPyTorchModelIO.__main_task) + + if empty.trains_in_model: + # noinspection PyBroadException + try: + model.trains_in_model = empty.trains_in_model + except Exception: + pass + return model diff --git a/trains/utilities/frameworks.py b/trains/binding/frameworks/tensorflow_bind.py similarity index 85% rename from trains/utilities/frameworks.py rename 
to trains/binding/frameworks/tensorflow_bind.py index 3da630bc..19d9616a 100644 --- a/trains/utilities/frameworks.py +++ b/trains/binding/frameworks/tensorflow_bind.py @@ -1,256 +1,25 @@ import base64 import sys import threading -import weakref from collections import defaultdict from logging import ERROR, WARNING, getLogger -from pathlib2 import Path +from typing import Any import cv2 import numpy as np import six +from pathlib2 import Path -from ..config import running_remotely -from ..model import InputModel, OutputModel, Framework +from ..frameworks import _patched_call, WeightsFileHandler, _Empty, TrainsFrameworkAdapter +from ..import_bind import PostImportHookPatching +from ...config import running_remotely +from ...model import InputModel, OutputModel, Framework try: from google.protobuf.json_format import MessageToDict except ImportError: MessageToDict = None -if six.PY2: - # python2.x - import __builtin__ as builtins -else: - # python3.x - import builtins - - -TrainsFrameworkAdapter = 'TrainsFrameworkAdapter' -_recursion_guard = {} - - -class _Empty(object): - def __init__(self): - self.trains_in_model = None - - -class PostImportHookPatching(object): - _patched = False - _post_import_hooks = defaultdict(list) - - @staticmethod - def _init_hook(): - if PostImportHookPatching._patched: - return - PostImportHookPatching._patched = True - - if six.PY2: - # python2.x - builtins.__org_import__ = builtins.__import__ - builtins.__import__ = PostImportHookPatching._patched_import2 - else: - # python3.x - builtins.__org_import__ = builtins.__import__ - builtins.__import__ = PostImportHookPatching._patched_import3 - - @staticmethod - def _patched_import2(name, globals={}, locals={}, fromlist=[], level=-1): - already_imported = name in sys.modules - mod = builtins.__org_import__( - name, - globals=globals, - locals=locals, - fromlist=fromlist, - level=level) - - if not already_imported and name in PostImportHookPatching._post_import_hooks: - for hook in PostImportHookPatching._post_import_hooks[name]: - hook() - return mod - - @staticmethod - def _patched_import3(name, globals=None, locals=None, fromlist=(), level=0): - already_imported = name in sys.modules - mod = builtins.__org_import__( - name, - globals=globals, - locals=locals, - fromlist=fromlist, - level=level) - - if not already_imported and name in PostImportHookPatching._post_import_hooks: - for hook in PostImportHookPatching._post_import_hooks[name]: - hook() - return mod - - @staticmethod - def add_on_import(name, func): - PostImportHookPatching._init_hook() - if not name in PostImportHookPatching._post_import_hooks or \ - func not in PostImportHookPatching._post_import_hooks[name]: - PostImportHookPatching._post_import_hooks[name].append(func) - - @staticmethod - def remove_on_import(name, func): - if name in PostImportHookPatching._post_import_hooks and func in PostImportHookPatching._post_import_hooks[name]: - PostImportHookPatching._post_import_hooks[name].remove(func) - - -def _patched_call(original_fn, patched_fn): - def _inner_patch(*args, **kwargs): - ident = threading._get_ident() if six.PY2 else threading.get_ident() - if ident in _recursion_guard: - return original_fn(*args, **kwargs) - _recursion_guard[ident] = 1 - ret = None - try: - ret = patched_fn(original_fn, *args, **kwargs) - except Exception as ex: - raise ex - finally: - try: - _recursion_guard.pop(ident) - except KeyError: - pass - return ret - return _inner_patch - - -class WeightsFileHandler(object): - _model_out_store_lookup = {} - _model_in_store_lookup 
= {} - _model_store_lookup_lock = threading.Lock() - - @staticmethod - def restore_weights_file(model, filepath, framework, task): - if task is None: - return filepath - - if not filepath: - getLogger(TrainsFrameworkAdapter).warning("Could retrieve model location, model not restored") - return filepath - - try: - WeightsFileHandler._model_store_lookup_lock.acquire() - - # check if object already has InputModel - trains_in_model, ref_model = WeightsFileHandler._model_in_store_lookup.get(id(model), (None, None)) - if ref_model is not None and model != ref_model(): - # old id pop it - it was probably reused because the object is dead - WeightsFileHandler._model_in_store_lookup.pop(id(model)) - trains_in_model, ref_model = None, None - - # check if object already has InputModel - model_name_id = getattr(model, 'name', '') - try: - config_text = None - config_dict = trains_in_model.config_dict if trains_in_model else None - except Exception: - config_dict = None - try: - config_text = trains_in_model.config_text if trains_in_model else None - except Exception: - config_text = None - trains_in_model = InputModel.import_model( - weights_url=filepath, - config_dict=config_dict, - config_text=config_text, - name=task.name + ' ' + model_name_id, - label_enumeration=task.get_labels_enumeration(), - framework=framework, - create_as_published=False, - ) - try: - ref_model = weakref.ref(model) - except Exception: - ref_model = None - WeightsFileHandler._model_in_store_lookup[id(model)] = (trains_in_model, ref_model) - # todo: support multiple models for the same task - task.connect(trains_in_model) - # if we are running remotely we should deserialize the object - # because someone might have changed the config_dict - if running_remotely(): - # reload the model - model_config = trains_in_model.config_dict - # verify that this is the same model so we are not deserializing a diff model - if (config_dict and config_dict.get('config') and model_config and model_config.get('config') and - config_dict.get('config').get('name') == model_config.get('config').get('name')) or \ - (not config_dict and not model_config): - filepath = trains_in_model.get_weights() - # update filepath to point to downloaded weights file - # actual model weights loading will be done outside the try/exception block - except Exception as ex: - getLogger(TrainsFrameworkAdapter).warning(str(ex)) - finally: - WeightsFileHandler._model_store_lookup_lock.release() - - return filepath - - @staticmethod - def create_output_model(model, saved_path, framework, task, singlefile=False, model_name=None): - if task is None: - return saved_path - - try: - WeightsFileHandler._model_store_lookup_lock.acquire() - - # check if object already has InputModel - trains_out_model, ref_model = WeightsFileHandler._model_out_store_lookup.get(id(model), (None, None)) - if ref_model is not None and model != ref_model(): - # old id pop it - it was probably reused because the object is dead - WeightsFileHandler._model_out_store_lookup.pop(id(model)) - trains_out_model, ref_model = None, None - - # check if object already has InputModel - if trains_out_model is None: - trains_out_model = OutputModel( - task=task, - # config_dict=config, - name=(task.name + ' - ' + model_name) if model_name else None, - label_enumeration=task.get_labels_enumeration(), - framework=framework,) - try: - ref_model = weakref.ref(model) - except Exception: - ref_model = None - WeightsFileHandler._model_out_store_lookup[id(model)] = (trains_out_model, ref_model) - - if not saved_path: - 
getLogger(TrainsFrameworkAdapter).warning("Could retrieve model location, stored as unknown ") - return saved_path - - # check if we have output storage, and generate list of files to upload - if trains_out_model.upload_storage_uri: - if Path(saved_path).is_dir(): - files = [str(f) for f in Path(saved_path).rglob('*') if f.is_file()] - elif singlefile: - files = [str(Path(saved_path).absolute())] - else: - files = [str(f) for f in Path(saved_path).parent.glob(str(Path(saved_path).name)+'.*')] - else: - files = None - - # upload files if we found them, or just register the original path - if files: - if len(files) > 1: - try: - target_filename = Path(saved_path).stem - except Exception: - target_filename = None - trains_out_model.update_weights_package(weights_filenames=files, auto_delete_file=False, - target_filename=target_filename) - else: - trains_out_model.update_weights(weights_filename=files[0], auto_delete_file=False) - else: - trains_out_model.update_weights(weights_filename=None, register_uri=saved_path) - except Exception as ex: - getLogger(TrainsFrameworkAdapter).warning(str(ex)) - finally: - WeightsFileHandler._model_store_lookup_lock.release() - - return saved_path - class EventTrainsWriter(object): """ @@ -271,7 +40,7 @@ class EventTrainsWriter(object): def tag_splitter(tag, num_split_parts, split_char='/', join_char='_', default_title='variant'): """ Split a tf.summary tag line to variant and metric. - Variant is the first part of the splitted tag, metric is the second. + Variant is the first part of the split tag, metric is the second. :param str tag: :param int num_split_parts: :param str split_char: a character to split the tag on @@ -313,6 +82,7 @@ class EventTrainsWriter(object): self._max_step = 0 def _decode_image(self, img_str, width, height, color_channels): + # noinspection PyBroadException try: image_string = np.asarray(bytearray(base64.b64decode(img_str)), dtype=np.uint8) image = cv2.imdecode(image_string, cv2.IMREAD_COLOR) @@ -345,7 +115,7 @@ class EventTrainsWriter(object): title, series = self.tag_splitter(tag, num_split_parts=3, default_title='Images') if img_data_np.dtype != np.uint8: # assume scale 0-1 - img_data_np = (img_data_np*255).astype(np.uint8) + img_data_np = (img_data_np * 255).astype(np.uint8) # if 3d, pack into one big image if img_data_np.ndim == 4: @@ -433,7 +203,7 @@ class EventTrainsWriter(object): hist_list, hist_iters, minmax = self._hist_report_cache.get((title, series), ([], np.array([]), None)) # resample data so we are always constrained in number of histogram we keep - if hist_iters.size >= self.histogram_granularity**2: + if hist_iters.size >= self.histogram_granularity ** 2: idx = _sample_histograms(hist_iters, self.histogram_granularity) hist_iters = hist_iters[idx] hist_list = [hist_list[i] for i in idx] @@ -464,7 +234,7 @@ class EventTrainsWriter(object): # resample histograms on a unified bin axis _minmax = minmax[0] - 1, minmax[1] + 1 prev_xedge = np.arange(start=_minmax[0], - step=(_minmax[1]-_minmax[0])/(self._hist_x_granularity-2), stop=_minmax[1]) + step=(_minmax[1] - _minmax[0]) / (self._hist_x_granularity - 2), stop=_minmax[1]) # uniformly select histograms and the last one cur_idx = _sample_histograms(hist_iters, self.histogram_granularity) report_hist = np.zeros(shape=(len(cur_idx), prev_xedge.size), dtype=np.float32) @@ -495,6 +265,7 @@ class EventTrainsWriter(object): camera=(-0.1, +1.3, 1.4)) def _add_plot(self, tag, step, values, vdict): + # noinspection PyBroadException try: plot_values = 
np.frombuffer(base64.b64decode(values['tensorContent'].encode('utf-8')), dtype=np.float32) @@ -506,7 +277,7 @@ class EventTrainsWriter(object): vdict['metadata']['pluginData']['pluginName'])] else: # this should not happen, maybe it's another run, let increase the value - self._series_name_lookup[tag] += [(tag+'_%d' % len(self._series_name_lookup[tag])+1, + self._series_name_lookup[tag] += [(tag + '_%d' % len(self._series_name_lookup[tag]) + 1, vdict['metadata']['displayName'], vdict['metadata']['pluginData']['pluginName'])] @@ -749,7 +520,8 @@ class PatchSummaryToEventTransformer(object): # only patch once if PatchSummaryToEventTransformer.__original_getattributeX is None: from tensorboardX.writer import SummaryToEventTransformer as SummaryToEventTransformerX - PatchSummaryToEventTransformer.__original_getattributeX = SummaryToEventTransformerX.__getattribute__ + PatchSummaryToEventTransformer.__original_getattributeX = \ + SummaryToEventTransformerX.__getattribute__ SummaryToEventTransformerX.__getattribute__ = PatchSummaryToEventTransformer._patched_getattributeX setattr(SummaryToEventTransformerX, 'trains', property(PatchSummaryToEventTransformer.trains_object)) @@ -779,7 +551,8 @@ class PatchSummaryToEventTransformer(object): return PatchSummaryToEventTransformer._original_add_eventT(self, *args, **kwargs) if not self.trains: self.trains = EventTrainsWriter(PatchSummaryToEventTransformer.__main_task.get_logger(), - **PatchSummaryToEventTransformer.defaults_dict) + **PatchSummaryToEventTransformer.defaults_dict) + # noinspection PyBroadException try: self.trains.add_event(*args, **kwargs) except Exception: @@ -792,7 +565,8 @@ class PatchSummaryToEventTransformer(object): return PatchSummaryToEventTransformer._original_add_eventX(self, *args, **kwargs) if not self.trains: self.trains = EventTrainsWriter(PatchSummaryToEventTransformer.__main_task.get_logger(), - **PatchSummaryToEventTransformer.defaults_dict) + **PatchSummaryToEventTransformer.defaults_dict) + # noinspection PyBroadException try: self.trains.add_event(*args, **kwargs) except Exception: @@ -1077,7 +851,7 @@ class PatchKerasModelIO(object): PatchKerasModelIO.__patched_keras = [ Network if PatchKerasModelIO.__patched_tensorflow[0] != Network else None, Sequential if PatchKerasModelIO.__patched_tensorflow[1] != Sequential else None, - keras_saving if PatchKerasModelIO.__patched_tensorflow[2] != keras_saving else None,] + keras_saving if PatchKerasModelIO.__patched_tensorflow[2] != keras_saving else None, ] else: PatchKerasModelIO.__patched_keras = [Network, Sequential, keras_saving] PatchKerasModelIO._patch_io_calls(*PatchKerasModelIO.__patched_keras) @@ -1106,7 +880,7 @@ class PatchKerasModelIO(object): PatchKerasModelIO.__patched_tensorflow = [ Network if PatchKerasModelIO.__patched_keras[0] != Network else None, Sequential if PatchKerasModelIO.__patched_keras[1] != Sequential else None, - keras_saving if PatchKerasModelIO.__patched_keras[2] != keras_saving else None,] + keras_saving if PatchKerasModelIO.__patched_keras[2] != keras_saving else None, ] else: PatchKerasModelIO.__patched_tensorflow = [Network, Sequential, keras_saving] PatchKerasModelIO._patch_io_calls(*PatchKerasModelIO.__patched_tensorflow) @@ -1313,6 +1087,7 @@ class PatchKerasModelIO(object): WeightsFileHandler.restore_weights_file(empty, filepath, Framework.keras, PatchKerasModelIO.__main_task) # update the input model object if empty.trains_in_model: + # noinspection PyBroadException try: model.trains_in_model = empty.trains_in_model except Exception: 
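Note on the mechanism shared by the hunks above and below: every Patch* class in these bindings relies on the same `_patched_call` wrapper now defined in `trains/binding/frameworks/__init__.py`. It swaps a library function for a hook while a per-thread recursion guard makes sure the hook cannot re-enter itself (if the hook ends up calling the patched symbol again, the original runs directly). A minimal, self-contained sketch of that pattern follows; the `save` / `_on_save` names are illustrative stand-ins, not part of the patch:

    import threading

    _recursion_guard = {}

    def _patched_call(original_fn, patched_fn):
        # Wrap original_fn so patched_fn runs in its place; a per-thread
        # guard falls back to original_fn if the hook re-enters itself.
        def _inner_patch(*args, **kwargs):
            ident = threading.get_ident()
            if ident in _recursion_guard:
                return original_fn(*args, **kwargs)
            _recursion_guard[ident] = 1
            try:
                return patched_fn(original_fn, *args, **kwargs)
            finally:
                _recursion_guard.pop(ident, None)
        return _inner_patch

    # hypothetical save function and hook, for illustration only
    def save(obj, path):
        print('saving', obj, 'to', path)

    def _on_save(original_fn, obj, path, *args, **kwargs):
        ret = original_fn(obj, path, *args, **kwargs)
        # the real bindings would call WeightsFileHandler.create_output_model(...) here
        print('registering output model for', path)
        return ret

    save = _patched_call(save, _on_save)
    save({'w': 1}, '/tmp/model.pt')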
@@ -1340,15 +1115,17 @@ class PatchTensorflowModelIO(object): return PatchTensorflowModelIO.__patched = True - + # noinspection PyBroadException try: # hack: make sure tensorflow.__init__ is called import tensorflow from tensorflow.python.training.saver import Saver + # noinspection PyBroadException try: Saver.save = _patched_call(Saver.save, PatchTensorflowModelIO._save) except Exception: pass + # noinspection PyBroadException try: Saver.restore = _patched_call(Saver.restore, PatchTensorflowModelIO._restore) except Exception: @@ -1358,6 +1135,7 @@ class PatchTensorflowModelIO(object): except Exception: pass # print('Failed patching tensorflow') + # noinspection PyBroadException try: # make sure we import the correct version of save import tensorflow @@ -1365,6 +1143,7 @@ class PatchTensorflowModelIO(object): # actual import import tensorflow.saved_model.experimental as saved_model except ImportError: + # noinspection PyBroadException try: # make sure we import the correct version of save import tensorflow @@ -1383,6 +1162,7 @@ class PatchTensorflowModelIO(object): if saved_model is not None: saved_model.save = _patched_call(saved_model.save, PatchTensorflowModelIO._save_model) + # noinspection PyBroadException try: # make sure we import the correct version of save import tensorflow @@ -1395,6 +1175,7 @@ class PatchTensorflowModelIO(object): except Exception: pass # print('Failed patching tensorflow') + # noinspection PyBroadException try: # make sure we import the correct version of save import tensorflow @@ -1406,6 +1187,7 @@ class PatchTensorflowModelIO(object): except Exception: pass # print('Failed patching tensorflow') + # noinspection PyBroadException try: # make sure we import the correct version of save import tensorflow @@ -1417,17 +1199,21 @@ class PatchTensorflowModelIO(object): except Exception: pass # print('Failed patching tensorflow') + # noinspection PyBroadException try: import tensorflow from tensorflow.train import Checkpoint + # noinspection PyBroadException try: Checkpoint.save = _patched_call(Checkpoint.save, PatchTensorflowModelIO._ckpt_save) except Exception: pass + # noinspection PyBroadException try: Checkpoint.restore = _patched_call(Checkpoint.restore, PatchTensorflowModelIO._ckpt_restore) except Exception: pass + # noinspection PyBroadException try: Checkpoint.write = _patched_call(Checkpoint.write, PatchTensorflowModelIO._ckpt_write) except Exception: @@ -1447,8 +1233,8 @@ class PatchTensorflowModelIO(object): PatchTensorflowModelIO.__main_task) @staticmethod - def _save_model(original_fn, obj, export_dir, *args, **kwargs): - original_fn(obj, export_dir, *args, **kwargs) + def _save_model(original_fn, obj, export_dir, *args, **kwargs): + original_fn(obj, export_dir, *args, **kwargs) # store output Model WeightsFileHandler.create_output_model(obj, export_dir, Framework.tensorflow, PatchTensorflowModelIO.__main_task) @@ -1490,6 +1276,7 @@ class PatchTensorflowModelIO(object): PatchTensorflowModelIO.__main_task) if empty.trains_in_model: + # noinspection PyBroadException try: model.trains_in_model = empty.trains_in_model except Exception: @@ -1532,6 +1319,7 @@ class PatchTensorflowModelIO(object): PatchTensorflowModelIO.__main_task) if empty.trains_in_model: + # noinspection PyBroadException try: model.trains_in_model = empty.trains_in_model except Exception: @@ -1558,7 +1346,7 @@ class PatchPyTorchModelIO(object): return PatchPyTorchModelIO.__patched = True - + # noinspection PyBroadException try: # hack: make sure tensorflow.__init__ is called import torch @@ 
-1579,6 +1367,7 @@ class PatchPyTorchModelIO(object): filename = f elif hasattr(f, 'name'): filename = f.name + # noinspection PyBroadException try: f.flush() except Exception: @@ -1586,7 +1375,8 @@ class PatchPyTorchModelIO(object): else: filename = None - # if the model a screptive name based on the file name + # give the model a descriptive name based on the file name + # noinspection PyBroadException try: model_name = Path(filename).stem except Exception: @@ -1620,6 +1410,7 @@ class PatchPyTorchModelIO(object): PatchPyTorchModelIO.__main_task) if empty.trains_in_model: + # noinspection PyBroadException try: model.trains_in_model = empty.trains_in_model except Exception: diff --git a/trains/binding/import_bind.py b/trains/binding/import_bind.py new file mode 100644 index 00000000..53ac5fb5 --- /dev/null +++ b/trains/binding/import_bind.py @@ -0,0 +1,73 @@ +import sys +from collections import defaultdict +import six + +if six.PY2: + # python2.x + import __builtin__ as builtins +else: + # python3.x + import builtins + + +class PostImportHookPatching(object): + _patched = False + _post_import_hooks = defaultdict(list) + + @staticmethod + def _init_hook(): + if PostImportHookPatching._patched: + return + PostImportHookPatching._patched = True + + if six.PY2: + # python2.x + builtins.__org_import__ = builtins.__import__ + builtins.__import__ = PostImportHookPatching._patched_import2 + else: + # python3.x + builtins.__org_import__ = builtins.__import__ + builtins.__import__ = PostImportHookPatching._patched_import3 + + @staticmethod + def _patched_import2(name, globals={}, locals={}, fromlist=[], level=-1): + already_imported = name in sys.modules + mod = builtins.__org_import__( + name, + globals=globals, + locals=locals, + fromlist=fromlist, + level=level) + + if not already_imported and name in PostImportHookPatching._post_import_hooks: + for hook in PostImportHookPatching._post_import_hooks[name]: + hook() + return mod + + @staticmethod + def _patched_import3(name, globals=None, locals=None, fromlist=(), level=0): + already_imported = name in sys.modules + mod = builtins.__org_import__( + name, + globals=globals, + locals=locals, + fromlist=fromlist, + level=level) + + if not already_imported and name in PostImportHookPatching._post_import_hooks: + for hook in PostImportHookPatching._post_import_hooks[name]: + hook() + return mod + + @staticmethod + def add_on_import(name, func): + PostImportHookPatching._init_hook() + if not name in PostImportHookPatching._post_import_hooks or \ + func not in PostImportHookPatching._post_import_hooks[name]: + PostImportHookPatching._post_import_hooks[name].append(func) + + @staticmethod + def remove_on_import(name, func): + if name in PostImportHookPatching._post_import_hooks and func in PostImportHookPatching._post_import_hooks[name]: + PostImportHookPatching._post_import_hooks[name].remove(func) + diff --git a/trains/utilities/matplotlib_bind.py b/trains/binding/matplotlib_bind.py similarity index 85% rename from trains/utilities/matplotlib_bind.py rename to trains/binding/matplotlib_bind.py index 2d4e6391..626fd3a8 100644 --- a/trains/utilities/matplotlib_bind.py +++ b/trains/binding/matplotlib_bind.py @@ -12,6 +12,8 @@ from ..config import running_remotely class PatchedMatplotlib: _patched_original_plot = None __patched_original_imshow = None + __patched_original_draw_all = None + __patched_draw_all_recursion_guard = False _global_plot_counter = -1 _global_image_counter = -1 _current_task = None @@ -45,18 +47,18 @@ class PatchedMatplotlib: if 
running_remotely(): # disable GUI backend - make headless - sys.modules['matplotlib'].rcParams['backend'] = 'agg' + matplotlib.rcParams['backend'] = 'agg' import matplotlib.pyplot - sys.modules['matplotlib'].pyplot.switch_backend('agg') + matplotlib.pyplot.switch_backend('agg') import matplotlib.pyplot as plt from matplotlib import _pylab_helpers if six.PY2: - PatchedMatplotlib._patched_original_plot = staticmethod(sys.modules['matplotlib'].pyplot.show) - PatchedMatplotlib._patched_original_imshow = staticmethod(sys.modules['matplotlib'].pyplot.imshow) + PatchedMatplotlib._patched_original_plot = staticmethod(plt.show) + PatchedMatplotlib._patched_original_imshow = staticmethod(plt.imshow) else: - PatchedMatplotlib._patched_original_plot = sys.modules['matplotlib'].pyplot.show - PatchedMatplotlib._patched_original_imshow = sys.modules['matplotlib'].pyplot.imshow - sys.modules['matplotlib'].pyplot.show = PatchedMatplotlib.patched_show + PatchedMatplotlib._patched_original_plot = plt.show + PatchedMatplotlib._patched_original_imshow = plt.imshow + plt.show = PatchedMatplotlib.patched_show # sys.modules['matplotlib'].pyplot.imshow = PatchedMatplotlib.patched_imshow # patch plotly so we know it failed us. from plotly.matplotlylib import renderer @@ -71,7 +73,11 @@ class PatchedMatplotlib: from IPython import get_ipython ip = get_ipython() if ip and matplotlib.is_interactive(): - ip.events.register('post_execute', PatchedMatplotlib.ipython_post_execute_hook) + # instead of hooking ipython, we should hook the matplotlib + import matplotlib.pyplot as plt + PatchedMatplotlib.__patched_original_draw_all = plt.draw_all + plt.draw_all = PatchedMatplotlib.__patched_draw_all + # ip.events.register('post_execute', PatchedMatplotlib.ipython_post_execute_hook) except Exception: pass @@ -188,6 +194,19 @@ class PatchedMatplotlib: return + @staticmethod + def __patched_draw_all(*args, **kwargs): + recursion_guard = PatchedMatplotlib.__patched_draw_all_recursion_guard + if not recursion_guard: + PatchedMatplotlib.__patched_draw_all_recursion_guard = True + + ret = PatchedMatplotlib.__patched_original_draw_all(*args, **kwargs) + + if not recursion_guard: + PatchedMatplotlib.ipython_post_execute_hook() + PatchedMatplotlib.__patched_draw_all_recursion_guard = False + return ret + @staticmethod def ipython_post_execute_hook(): # noinspection PyBroadException diff --git a/trains/task.py b/trains/task.py index d407bde9..02675c2c 100644 --- a/trains/task.py +++ b/trains/task.py @@ -27,12 +27,13 @@ from .errors import UsageError from .logger import Logger from .model import InputModel, OutputModel, ARCHIVED_TAG from .task_parameters import TaskParameters -from .utilities.absl_bind import PatchAbsl +from .binding.absl_bind import PatchAbsl from .utilities.args import argparser_parseargs_called, get_argparser_last_args, \ argparser_update_currenttask -from .utilities.frameworks import PatchSummaryToEventTransformer, PatchTensorFlowEager, PatchKerasModelIO, \ - PatchTensorflowModelIO, PatchPyTorchModelIO -from .utilities.matplotlib_bind import PatchedMatplotlib +from .binding.frameworks.pytorch_bind import PatchPyTorchModelIO +from .binding.frameworks.tensorflow_bind import PatchSummaryToEventTransformer, PatchTensorFlowEager, \ + PatchKerasModelIO, PatchTensorflowModelIO +from .binding.matplotlib_bind import PatchedMatplotlib from .utilities.seed import make_deterministic NotSet = object()
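For context on the Jupyter improvement in `matplotlib_bind.py`: instead of registering an IPython `post_execute` hook, the patch wraps `matplotlib.pyplot.draw_all`, so figures are reported when matplotlib actually redraws in interactive mode, and a recursion guard keeps the reporting code's own drawing from triggering another report. A rough sketch of that shape is below; `report_open_figures` is a stand-in for `PatchedMatplotlib.ipython_post_execute_hook`, not an actual API of the library:

    import matplotlib.pyplot as plt

    _original_draw_all = plt.draw_all
    _in_draw_all = False  # mirrors __patched_draw_all_recursion_guard

    def report_open_figures():
        # stand-in for the real reporting logic (send figures to the task logger)
        for num in plt.get_fignums():
            print('would report figure', num)

    def _patched_draw_all(*args, **kwargs):
        global _in_draw_all
        outermost = not _in_draw_all
        if outermost:
            _in_draw_all = True
        ret = _original_draw_all(*args, **kwargs)
        if outermost:
            # any drawing done while reporting re-enters _patched_draw_all,
            # but the guard keeps it from reporting a second time
            report_open_figures()
            _in_draw_all = False
        return ret

    plt.draw_all = _patched_draw_all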