Refactored binding, better support for matplotlib jupyter binding

allegroai 2019-06-20 01:50:40 +03:00
parent ff8652f39f
commit a77b470500
7 changed files with 428 additions and 265 deletions

View File

@@ -1,5 +1,5 @@
""" absl-py FLAGS binding utility functions """
from trains.backend_interface.task.args import _Arguments
from ..backend_interface.task.args import _Arguments
from ..config import running_remotely

View File

@@ -0,0 +1,178 @@
import threading
import weakref
from logging import getLogger
import six
from pathlib2 import Path
from ...config import running_remotely
from ...model import InputModel, OutputModel
TrainsFrameworkAdapter = 'TrainsFrameworkAdapter'
_recursion_guard = {}
def _patched_call(original_fn, patched_fn):
def _inner_patch(*args, **kwargs):
ident = threading._get_ident() if six.PY2 else threading.get_ident()
if ident in _recursion_guard:
return original_fn(*args, **kwargs)
_recursion_guard[ident] = 1
ret = None
try:
ret = patched_fn(original_fn, *args, **kwargs)
except Exception as ex:
raise ex
finally:
try:
_recursion_guard.pop(ident)
except KeyError:
pass
return ret
return _inner_patch
class _Empty(object):
def __init__(self):
self.trains_in_model = None
class WeightsFileHandler(object):
_model_out_store_lookup = {}
_model_in_store_lookup = {}
_model_store_lookup_lock = threading.Lock()
@staticmethod
def restore_weights_file(model, filepath, framework, task):
if task is None:
return filepath
if not filepath:
getLogger(TrainsFrameworkAdapter).warning("Could retrieve model location, model not restored")
return filepath
try:
WeightsFileHandler._model_store_lookup_lock.acquire()
# check if object already has InputModel
trains_in_model, ref_model = WeightsFileHandler._model_in_store_lookup.get(id(model), (None, None))
if ref_model is not None and model != ref_model():
# old id pop it - it was probably reused because the object is dead
WeightsFileHandler._model_in_store_lookup.pop(id(model))
trains_in_model, ref_model = None, None
# check if object already has InputModel
model_name_id = getattr(model, 'name', '')
# noinspection PyBroadException
try:
config_text = None
config_dict = trains_in_model.config_dict if trains_in_model else None
except Exception:
config_dict = None
# noinspection PyBroadException
try:
config_text = trains_in_model.config_text if trains_in_model else None
except Exception:
config_text = None
trains_in_model = InputModel.import_model(
weights_url=filepath,
config_dict=config_dict,
config_text=config_text,
name=task.name + ' ' + model_name_id,
label_enumeration=task.get_labels_enumeration(),
framework=framework,
create_as_published=False,
)
# noinspection PyBroadException
try:
ref_model = weakref.ref(model)
except Exception:
ref_model = None
WeightsFileHandler._model_in_store_lookup[id(model)] = (trains_in_model, ref_model)
# todo: support multiple models for the same task
task.connect(trains_in_model)
# if we are running remotely we should deserialize the object
# because someone might have changed the config_dict
if running_remotely():
# reload the model
model_config = trains_in_model.config_dict
# verify that this is the same model so we are not deserializing a diff model
if (config_dict and config_dict.get('config') and model_config and model_config.get('config') and
config_dict.get('config').get('name') == model_config.get('config').get('name')) or \
(not config_dict and not model_config):
filepath = trains_in_model.get_weights()
# update filepath to point to downloaded weights file
# actual model weights loading will be done outside the try/exception block
except Exception as ex:
getLogger(TrainsFrameworkAdapter).warning(str(ex))
finally:
WeightsFileHandler._model_store_lookup_lock.release()
return filepath
@staticmethod
def create_output_model(model, saved_path, framework, task, singlefile=False, model_name=None):
if task is None:
return saved_path
try:
WeightsFileHandler._model_store_lookup_lock.acquire()
# check if object already has InputModel
trains_out_model, ref_model = WeightsFileHandler._model_out_store_lookup.get(id(model), (None, None))
if ref_model is not None and model != ref_model():
# old id pop it - it was probably reused because the object is dead
WeightsFileHandler._model_out_store_lookup.pop(id(model))
trains_out_model, ref_model = None, None
# check if object already has InputModel
if trains_out_model is None:
trains_out_model = OutputModel(
task=task,
# config_dict=config,
name=(task.name + ' - ' + model_name) if model_name else None,
label_enumeration=task.get_labels_enumeration(),
framework=framework, )
# noinspection PyBroadException
try:
ref_model = weakref.ref(model)
except Exception:
ref_model = None
WeightsFileHandler._model_out_store_lookup[id(model)] = (trains_out_model, ref_model)
if not saved_path:
getLogger(TrainsFrameworkAdapter).warning("Could retrieve model location, stored as unknown ")
return saved_path
# check if we have output storage, and generate list of files to upload
if trains_out_model.upload_storage_uri:
if Path(saved_path).is_dir():
files = [str(f) for f in Path(saved_path).rglob('*') if f.is_file()]
elif singlefile:
files = [str(Path(saved_path).absolute())]
else:
files = [str(f) for f in Path(saved_path).parent.glob(str(Path(saved_path).name) + '.*')]
else:
files = None
# upload files if we found them, or just register the original path
if files:
if len(files) > 1:
# noinspection PyBroadException
try:
target_filename = Path(saved_path).stem
except Exception:
target_filename = None
trains_out_model.update_weights_package(weights_filenames=files, auto_delete_file=False,
target_filename=target_filename)
else:
trains_out_model.update_weights(weights_filename=files[0], auto_delete_file=False)
else:
trains_out_model.update_weights(weights_filename=None, register_uri=saved_path)
except Exception as ex:
getLogger(TrainsFrameworkAdapter).warning(str(ex))
finally:
WeightsFileHandler._model_store_lookup_lock.release()
return saved_path
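For reference, a minimal self-contained sketch of the _patched_call pattern introduced above: the patched callback receives the original function as its first argument, and the per-thread recursion guard makes re-entrant calls fall straight through to the original. The names save_weights and _report_and_save below are hypothetical and not part of this commit.

import threading

_recursion_guard = {}

def _patched_call(original_fn, patched_fn):
    def _inner_patch(*args, **kwargs):
        ident = threading.get_ident()
        if ident in _recursion_guard:
            # re-entrant call on the same thread: skip the patch logic
            return original_fn(*args, **kwargs)
        _recursion_guard[ident] = 1
        try:
            return patched_fn(original_fn, *args, **kwargs)
        finally:
            _recursion_guard.pop(ident, None)
    return _inner_patch

def save_weights(path):
    # stand-in for a framework's save function
    return 'saved to %s' % path

def _report_and_save(original_fn, path):
    # interception logic: report, then call the real implementation
    print('model checkpoint written to', path)
    return original_fn(path)

save_weights = _patched_call(save_weights, _report_and_save)
print(save_weights('/tmp/weights.bin'))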

View File

@@ -0,0 +1,101 @@
import sys
import six
from pathlib2 import Path
from ..frameworks import _patched_call, WeightsFileHandler, _Empty
from ..import_bind import PostImportHookPatching
from ...config import running_remotely
from ...model import Framework
class PatchPyTorchModelIO(object):
__main_task = None
__patched = None
@staticmethod
def update_current_task(task, **_):
PatchPyTorchModelIO.__main_task = task
PatchPyTorchModelIO._patch_model_io()
PostImportHookPatching.add_on_import('torch', PatchPyTorchModelIO._patch_model_io)
@staticmethod
def _patch_model_io():
if PatchPyTorchModelIO.__patched:
return
if 'torch' not in sys.modules:
return
PatchPyTorchModelIO.__patched = True
# noinspection PyBroadException
try:
# hack: make sure torch.__init__ is called
import torch
torch.save = _patched_call(torch.save, PatchPyTorchModelIO._save)
torch.load = _patched_call(torch.load, PatchPyTorchModelIO._load)
except ImportError:
pass
except Exception:
pass # print('Failed patching pytorch')
@staticmethod
def _save(original_fn, obj, f, *args, **kwargs):
ret = original_fn(obj, f, *args, **kwargs)
if not PatchPyTorchModelIO.__main_task:
return ret
if isinstance(f, six.string_types):
filename = f
elif hasattr(f, 'name'):
filename = f.name
# noinspection PyBroadException
try:
f.flush()
except Exception:
pass
else:
filename = None
# give the model a descriptive name based on the file name
# noinspection PyBroadException
try:
model_name = Path(filename).stem
except Exception:
model_name = None
WeightsFileHandler.create_output_model(obj, filename, Framework.pytorch, PatchPyTorchModelIO.__main_task,
singlefile=True, model_name=model_name)
return ret
@staticmethod
def _load(original_fn, f, *args, **kwargs):
if isinstance(f, six.string_types):
filename = f
elif hasattr(f, 'name'):
filename = f.name
else:
filename = None
if not PatchPyTorchModelIO.__main_task:
return original_fn(f, *args, **kwargs)
# register input model
empty = _Empty()
if running_remotely():
filename = WeightsFileHandler.restore_weights_file(empty, filename, Framework.pytorch,
PatchPyTorchModelIO.__main_task)
model = original_fn(filename or f, *args, **kwargs)
else:
# try to load model before registering, in case we fail
model = original_fn(filename or f, *args, **kwargs)
WeightsFileHandler.restore_weights_file(empty, filename, Framework.pytorch,
PatchPyTorchModelIO.__main_task)
if empty.trains_in_model:
# noinspection PyBroadException
try:
model.trains_in_model = empty.trains_in_model
except Exception:
pass
return model
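A rough usage sketch (not part of the diff) of what this binding enables once a task wires it up, assuming the usual Task.init() entry point installs the patch as the task.py import changes below suggest; paths and model names are illustrative.

import torch
import torch.nn as nn
from trains import Task

# Task.init is assumed to call PatchPyTorchModelIO.update_current_task(...)
task = Task.init(project_name='examples', task_name='pytorch checkpoint logging')

model = nn.Linear(10, 2)
# the patched torch.save also registers the file as an OutputModel on the task,
# named after the file stem ('model' here)
torch.save(model.state_dict(), '/tmp/model.pt')

# the patched torch.load also registers the weights as an InputModel; when
# running remotely, the weights may first be fetched from the model registry
state_dict = torch.load('/tmp/model.pt')
model.load_state_dict(state_dict)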

View File

@@ -1,256 +1,25 @@
import base64
import sys
import threading
import weakref
from collections import defaultdict
from logging import ERROR, WARNING, getLogger
from pathlib2 import Path
from typing import Any
import cv2
import numpy as np
import six
from pathlib2 import Path
from ..config import running_remotely
from ..frameworks import _patched_call, WeightsFileHandler, _Empty, TrainsFrameworkAdapter
from ..model import InputModel, OutputModel, Framework
from ..import_bind import PostImportHookPatching
from ...config import running_remotely
from ...model import InputModel, OutputModel, Framework
try:
from google.protobuf.json_format import MessageToDict
except ImportError:
MessageToDict = None
if six.PY2:
# python2.x
import __builtin__ as builtins
else:
# python3.x
import builtins
TrainsFrameworkAdapter = 'TrainsFrameworkAdapter'
_recursion_guard = {}
class _Empty(object):
def __init__(self):
self.trains_in_model = None
class PostImportHookPatching(object):
_patched = False
_post_import_hooks = defaultdict(list)
@staticmethod
def _init_hook():
if PostImportHookPatching._patched:
return
PostImportHookPatching._patched = True
if six.PY2:
# python2.x
builtins.__org_import__ = builtins.__import__
builtins.__import__ = PostImportHookPatching._patched_import2
else:
# python3.x
builtins.__org_import__ = builtins.__import__
builtins.__import__ = PostImportHookPatching._patched_import3
@staticmethod
def _patched_import2(name, globals={}, locals={}, fromlist=[], level=-1):
already_imported = name in sys.modules
mod = builtins.__org_import__(
name,
globals=globals,
locals=locals,
fromlist=fromlist,
level=level)
if not already_imported and name in PostImportHookPatching._post_import_hooks:
for hook in PostImportHookPatching._post_import_hooks[name]:
hook()
return mod
@staticmethod
def _patched_import3(name, globals=None, locals=None, fromlist=(), level=0):
already_imported = name in sys.modules
mod = builtins.__org_import__(
name,
globals=globals,
locals=locals,
fromlist=fromlist,
level=level)
if not already_imported and name in PostImportHookPatching._post_import_hooks:
for hook in PostImportHookPatching._post_import_hooks[name]:
hook()
return mod
@staticmethod
def add_on_import(name, func):
PostImportHookPatching._init_hook()
if not name in PostImportHookPatching._post_import_hooks or \
func not in PostImportHookPatching._post_import_hooks[name]:
PostImportHookPatching._post_import_hooks[name].append(func)
@staticmethod
def remove_on_import(name, func):
if name in PostImportHookPatching._post_import_hooks and func in PostImportHookPatching._post_import_hooks[name]:
PostImportHookPatching._post_import_hooks[name].remove(func)
def _patched_call(original_fn, patched_fn):
def _inner_patch(*args, **kwargs):
ident = threading._get_ident() if six.PY2 else threading.get_ident()
if ident in _recursion_guard:
return original_fn(*args, **kwargs)
_recursion_guard[ident] = 1
ret = None
try:
ret = patched_fn(original_fn, *args, **kwargs)
except Exception as ex:
raise ex
finally:
try:
_recursion_guard.pop(ident)
except KeyError:
pass
return ret
return _inner_patch
class WeightsFileHandler(object):
_model_out_store_lookup = {}
_model_in_store_lookup = {}
_model_store_lookup_lock = threading.Lock()
@staticmethod
def restore_weights_file(model, filepath, framework, task):
if task is None:
return filepath
if not filepath:
getLogger(TrainsFrameworkAdapter).warning("Could retrieve model location, model not restored")
return filepath
try:
WeightsFileHandler._model_store_lookup_lock.acquire()
# check if object already has InputModel
trains_in_model, ref_model = WeightsFileHandler._model_in_store_lookup.get(id(model), (None, None))
if ref_model is not None and model != ref_model():
# old id pop it - it was probably reused because the object is dead
WeightsFileHandler._model_in_store_lookup.pop(id(model))
trains_in_model, ref_model = None, None
# check if object already has InputModel
model_name_id = getattr(model, 'name', '')
try:
config_text = None
config_dict = trains_in_model.config_dict if trains_in_model else None
except Exception:
config_dict = None
try:
config_text = trains_in_model.config_text if trains_in_model else None
except Exception:
config_text = None
trains_in_model = InputModel.import_model(
weights_url=filepath,
config_dict=config_dict,
config_text=config_text,
name=task.name + ' ' + model_name_id,
label_enumeration=task.get_labels_enumeration(),
framework=framework,
create_as_published=False,
)
try:
ref_model = weakref.ref(model)
except Exception:
ref_model = None
WeightsFileHandler._model_in_store_lookup[id(model)] = (trains_in_model, ref_model)
# todo: support multiple models for the same task
task.connect(trains_in_model)
# if we are running remotely we should deserialize the object
# because someone might have changed the config_dict
if running_remotely():
# reload the model
model_config = trains_in_model.config_dict
# verify that this is the same model so we are not deserializing a diff model
if (config_dict and config_dict.get('config') and model_config and model_config.get('config') and
config_dict.get('config').get('name') == model_config.get('config').get('name')) or \
(not config_dict and not model_config):
filepath = trains_in_model.get_weights()
# update filepath to point to downloaded weights file
# actual model weights loading will be done outside the try/exception block
except Exception as ex:
getLogger(TrainsFrameworkAdapter).warning(str(ex))
finally:
WeightsFileHandler._model_store_lookup_lock.release()
return filepath
@staticmethod
def create_output_model(model, saved_path, framework, task, singlefile=False, model_name=None):
if task is None:
return saved_path
try:
WeightsFileHandler._model_store_lookup_lock.acquire()
# check if object already has InputModel
trains_out_model, ref_model = WeightsFileHandler._model_out_store_lookup.get(id(model), (None, None))
if ref_model is not None and model != ref_model():
# old id pop it - it was probably reused because the object is dead
WeightsFileHandler._model_out_store_lookup.pop(id(model))
trains_out_model, ref_model = None, None
# check if object already has InputModel
if trains_out_model is None:
trains_out_model = OutputModel(
task=task,
# config_dict=config,
name=(task.name + ' - ' + model_name) if model_name else None,
label_enumeration=task.get_labels_enumeration(),
framework=framework,)
try:
ref_model = weakref.ref(model)
except Exception:
ref_model = None
WeightsFileHandler._model_out_store_lookup[id(model)] = (trains_out_model, ref_model)
if not saved_path:
getLogger(TrainsFrameworkAdapter).warning("Could retrieve model location, stored as unknown ")
return saved_path
# check if we have output storage, and generate list of files to upload
if trains_out_model.upload_storage_uri:
if Path(saved_path).is_dir():
files = [str(f) for f in Path(saved_path).rglob('*') if f.is_file()]
elif singlefile:
files = [str(Path(saved_path).absolute())]
else:
files = [str(f) for f in Path(saved_path).parent.glob(str(Path(saved_path).name)+'.*')]
else:
files = None
# upload files if we found them, or just register the original path
if files:
if len(files) > 1:
try:
target_filename = Path(saved_path).stem
except Exception:
target_filename = None
trains_out_model.update_weights_package(weights_filenames=files, auto_delete_file=False,
target_filename=target_filename)
else:
trains_out_model.update_weights(weights_filename=files[0], auto_delete_file=False)
else:
trains_out_model.update_weights(weights_filename=None, register_uri=saved_path)
except Exception as ex:
getLogger(TrainsFrameworkAdapter).warning(str(ex))
finally:
WeightsFileHandler._model_store_lookup_lock.release()
return saved_path
class EventTrainsWriter(object):
"""
@@ -271,7 +40,7 @@ class EventTrainsWriter(object):
def tag_splitter(tag, num_split_parts, split_char='/', join_char='_', default_title='variant'):
"""
Split a tf.summary tag line to variant and metric.
Variant is the first part of the splitted tag, metric is the second.
Variant is the first part of the split tag, metric is the second.
:param str tag:
:param int num_split_parts:
:param str split_char: a character to split the tag on
@@ -313,6 +82,7 @@ class EventTrainsWriter(object):
self._max_step = 0
def _decode_image(self, img_str, width, height, color_channels):
# noinspection PyBroadException
try:
image_string = np.asarray(bytearray(base64.b64decode(img_str)), dtype=np.uint8)
image = cv2.imdecode(image_string, cv2.IMREAD_COLOR)
@@ -495,6 +265,7 @@ class EventTrainsWriter(object):
camera=(-0.1, +1.3, 1.4))
def _add_plot(self, tag, step, values, vdict):
# noinspection PyBroadException
try:
plot_values = np.frombuffer(base64.b64decode(values['tensorContent'].encode('utf-8')),
dtype=np.float32)
@@ -749,7 +520,8 @@ class PatchSummaryToEventTransformer(object):
# only patch once
if PatchSummaryToEventTransformer.__original_getattributeX is None:
from tensorboardX.writer import SummaryToEventTransformer as SummaryToEventTransformerX
PatchSummaryToEventTransformer.__original_getattributeX = SummaryToEventTransformerX.__getattribute__
PatchSummaryToEventTransformer.__original_getattributeX = \
SummaryToEventTransformerX.__getattribute__
SummaryToEventTransformerX.__getattribute__ = PatchSummaryToEventTransformer._patched_getattributeX
setattr(SummaryToEventTransformerX, 'trains',
property(PatchSummaryToEventTransformer.trains_object))
@@ -780,6 +552,7 @@ class PatchSummaryToEventTransformer(object):
if not self.trains:
self.trains = EventTrainsWriter(PatchSummaryToEventTransformer.__main_task.get_logger(),
**PatchSummaryToEventTransformer.defaults_dict)
# noinspection PyBroadException
try:
self.trains.add_event(*args, **kwargs)
except Exception:
@@ -793,6 +566,7 @@ class PatchSummaryToEventTransformer(object):
if not self.trains:
self.trains = EventTrainsWriter(PatchSummaryToEventTransformer.__main_task.get_logger(),
**PatchSummaryToEventTransformer.defaults_dict)
# noinspection PyBroadException
try:
self.trains.add_event(*args, **kwargs)
except Exception:
@@ -1313,6 +1087,7 @@ class PatchKerasModelIO(object):
WeightsFileHandler.restore_weights_file(empty, filepath, Framework.keras, PatchKerasModelIO.__main_task)
# update the input model object
if empty.trains_in_model:
# noinspection PyBroadException
try:
model.trains_in_model = empty.trains_in_model
except Exception:
@@ -1340,15 +1115,17 @@ class PatchTensorflowModelIO(object):
return
PatchTensorflowModelIO.__patched = True
# noinspection PyBroadException
try:
# hack: make sure tensorflow.__init__ is called
import tensorflow
from tensorflow.python.training.saver import Saver
# noinspection PyBroadException
try:
Saver.save = _patched_call(Saver.save, PatchTensorflowModelIO._save)
except Exception:
pass
# noinspection PyBroadException
try:
Saver.restore = _patched_call(Saver.restore, PatchTensorflowModelIO._restore)
except Exception:
@@ -1358,6 +1135,7 @@ class PatchTensorflowModelIO(object):
except Exception:
pass # print('Failed patching tensorflow')
# noinspection PyBroadException
try:
# make sure we import the correct version of save
import tensorflow
@@ -1365,6 +1143,7 @@ class PatchTensorflowModelIO(object):
# actual import
import tensorflow.saved_model.experimental as saved_model
except ImportError:
# noinspection PyBroadException
try:
# make sure we import the correct version of save
import tensorflow
@@ -1383,6 +1162,7 @@ class PatchTensorflowModelIO(object):
if saved_model is not None:
saved_model.save = _patched_call(saved_model.save, PatchTensorflowModelIO._save_model)
# noinspection PyBroadException
try:
# make sure we import the correct version of save
import tensorflow
@@ -1395,6 +1175,7 @@ class PatchTensorflowModelIO(object):
except Exception:
pass # print('Failed patching tensorflow')
# noinspection PyBroadException
try:
# make sure we import the correct version of save
import tensorflow
@@ -1406,6 +1187,7 @@ class PatchTensorflowModelIO(object):
except Exception:
pass # print('Failed patching tensorflow')
# noinspection PyBroadException
try:
# make sure we import the correct version of save
import tensorflow
@@ -1417,17 +1199,21 @@ class PatchTensorflowModelIO(object):
except Exception:
pass # print('Failed patching tensorflow')
# noinspection PyBroadException
try:
import tensorflow
from tensorflow.train import Checkpoint
# noinspection PyBroadException
try:
Checkpoint.save = _patched_call(Checkpoint.save, PatchTensorflowModelIO._ckpt_save)
except Exception:
pass
# noinspection PyBroadException
try:
Checkpoint.restore = _patched_call(Checkpoint.restore, PatchTensorflowModelIO._ckpt_restore)
except Exception:
pass
# noinspection PyBroadException
try:
Checkpoint.write = _patched_call(Checkpoint.write, PatchTensorflowModelIO._ckpt_write)
except Exception:
@@ -1490,6 +1276,7 @@ class PatchTensorflowModelIO(object):
PatchTensorflowModelIO.__main_task)
if empty.trains_in_model:
# noinspection PyBroadException
try:
model.trains_in_model = empty.trains_in_model
except Exception:
@@ -1532,6 +1319,7 @@ class PatchTensorflowModelIO(object):
PatchTensorflowModelIO.__main_task)
if empty.trains_in_model:
# noinspection PyBroadException
try:
model.trains_in_model = empty.trains_in_model
except Exception:
@@ -1558,7 +1346,7 @@ class PatchPyTorchModelIO(object):
return
PatchPyTorchModelIO.__patched = True
# noinspection PyBroadException
try:
# hack: make sure tensorflow.__init__ is called
import torch
@@ -1579,6 +1367,7 @@ class PatchPyTorchModelIO(object):
filename = f
elif hasattr(f, 'name'):
filename = f.name
# noinspection PyBroadException
try:
f.flush()
except Exception:
@@ -1586,7 +1375,8 @@ class PatchPyTorchModelIO(object):
else:
filename = None
# if the model a screptive name based on the file name
# give the model a descriptive name based on the file name
# noinspection PyBroadException
try:
model_name = Path(filename).stem
except Exception:
@@ -1620,6 +1410,7 @@ class PatchPyTorchModelIO(object):
PatchPyTorchModelIO.__main_task)
if empty.trains_in_model:
# noinspection PyBroadException
try:
model.trains_in_model = empty.trains_in_model
except Exception:
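As an aside, a standalone sketch (hypothetical names, not from this commit) of the id() plus weakref cache-validation pattern that the WeightsFileHandler lookups above rely on: id() values can be recycled after an object is garbage collected, so the stored weak reference is used to detect stale cache entries.

import weakref

_cache = {}

class Model(object):
    pass

def set_cached_entry(model, entry):
    # key by id() but remember a weak reference for later validation
    _cache[id(model)] = (entry, weakref.ref(model))

def get_cached_entry(model):
    entry, ref = _cache.get(id(model), (None, None))
    if ref is not None and ref() is not model:
        # the original object died and its id() was reused by a new object
        _cache.pop(id(model))
        return None
    return entry

m = Model()
set_cached_entry(m, 'output-model-for-m')
print(get_cached_entry(m))  # -> 'output-model-for-m'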

View File

@@ -0,0 +1,73 @@
import sys
from collections import defaultdict
import six
if six.PY2:
# python2.x
import __builtin__ as builtins
else:
# python3.x
import builtins
class PostImportHookPatching(object):
_patched = False
_post_import_hooks = defaultdict(list)
@staticmethod
def _init_hook():
if PostImportHookPatching._patched:
return
PostImportHookPatching._patched = True
if six.PY2:
# python2.x
builtins.__org_import__ = builtins.__import__
builtins.__import__ = PostImportHookPatching._patched_import2
else:
# python3.x
builtins.__org_import__ = builtins.__import__
builtins.__import__ = PostImportHookPatching._patched_import3
@staticmethod
def _patched_import2(name, globals={}, locals={}, fromlist=[], level=-1):
already_imported = name in sys.modules
mod = builtins.__org_import__(
name,
globals=globals,
locals=locals,
fromlist=fromlist,
level=level)
if not already_imported and name in PostImportHookPatching._post_import_hooks:
for hook in PostImportHookPatching._post_import_hooks[name]:
hook()
return mod
@staticmethod
def _patched_import3(name, globals=None, locals=None, fromlist=(), level=0):
already_imported = name in sys.modules
mod = builtins.__org_import__(
name,
globals=globals,
locals=locals,
fromlist=fromlist,
level=level)
if not already_imported and name in PostImportHookPatching._post_import_hooks:
for hook in PostImportHookPatching._post_import_hooks[name]:
hook()
return mod
@staticmethod
def add_on_import(name, func):
PostImportHookPatching._init_hook()
if not name in PostImportHookPatching._post_import_hooks or \
func not in PostImportHookPatching._post_import_hooks[name]:
PostImportHookPatching._post_import_hooks[name].append(func)
@staticmethod
def remove_on_import(name, func):
if name in PostImportHookPatching._post_import_hooks and func in PostImportHookPatching._post_import_hooks[name]:
PostImportHookPatching._post_import_hooks[name].remove(func)
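A minimal usage sketch of the post-import hook above; the callback name and the choice of the csv module are illustrative, not part of the commit.

def _on_csv_imported():
    print('csv was just imported somewhere in the process')

# patches builtins.__import__ on first use, then remembers the callback
PostImportHookPatching.add_on_import('csv', _on_csv_imported)

import csv  # the hook fires here, provided csv was not already imported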

View File

@@ -12,6 +12,8 @@ from ..config import running_remotely
class PatchedMatplotlib:
_patched_original_plot = None
__patched_original_imshow = None
__patched_original_draw_all = None
__patched_draw_all_recursion_guard = False
_global_plot_counter = -1
_global_image_counter = -1
_current_task = None
@@ -45,18 +47,18 @@ class PatchedMatplotlib:
if running_remotely():
# disable GUI backend - make headless
sys.modules['matplotlib'].rcParams['backend'] = 'agg'
matplotlib.rcParams['backend'] = 'agg'
import matplotlib.pyplot
sys.modules['matplotlib'].pyplot.switch_backend('agg')
matplotlib.pyplot.switch_backend('agg')
import matplotlib.pyplot as plt
from matplotlib import _pylab_helpers
if six.PY2:
PatchedMatplotlib._patched_original_plot = staticmethod(sys.modules['matplotlib'].pyplot.show)
PatchedMatplotlib._patched_original_plot = staticmethod(plt.show)
PatchedMatplotlib._patched_original_imshow = staticmethod(sys.modules['matplotlib'].pyplot.imshow)
PatchedMatplotlib._patched_original_imshow = staticmethod(plt.imshow)
else:
PatchedMatplotlib._patched_original_plot = sys.modules['matplotlib'].pyplot.show
PatchedMatplotlib._patched_original_plot = plt.show
PatchedMatplotlib._patched_original_imshow = sys.modules['matplotlib'].pyplot.imshow
PatchedMatplotlib._patched_original_imshow = plt.imshow
sys.modules['matplotlib'].pyplot.show = PatchedMatplotlib.patched_show
plt.show = PatchedMatplotlib.patched_show
# sys.modules['matplotlib'].pyplot.imshow = PatchedMatplotlib.patched_imshow
# patch plotly so we know it failed us.
from plotly.matplotlylib import renderer
@@ -71,7 +73,11 @@ class PatchedMatplotlib:
from IPython import get_ipython
ip = get_ipython()
if ip and matplotlib.is_interactive():
ip.events.register('post_execute', PatchedMatplotlib.ipython_post_execute_hook)
# instead of hooking ipython, we should hook the matplotlib
import matplotlib.pyplot as plt
PatchedMatplotlib.__patched_original_draw_all = plt.draw_all
plt.draw_all = PatchedMatplotlib.__patched_draw_all
# ip.events.register('post_execute', PatchedMatplotlib.ipython_post_execute_hook)
except Exception:
pass
@@ -188,6 +194,19 @@ class PatchedMatplotlib:
return
@staticmethod
def __patched_draw_all(*args, **kwargs):
recursion_guard = PatchedMatplotlib.__patched_draw_all_recursion_guard
if not recursion_guard:
PatchedMatplotlib.__patched_draw_all_recursion_guard = True
ret = PatchedMatplotlib.__patched_original_draw_all(*args, **kwargs)
if not recursion_guard:
PatchedMatplotlib.ipython_post_execute_hook()
PatchedMatplotlib.__patched_draw_all_recursion_guard = False
return ret
@staticmethod
def ipython_post_execute_hook():
# noinspection PyBroadException
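For context, a standalone sketch of the recursion-guarded wrapper that __patched_draw_all implements above: in an interactive (Jupyter) session draw_all is invoked whenever figures are rendered, so wrapping it lets the report hook run once per outermost call without re-triggering itself. The names below are hypothetical.

_guard = False

def _original_draw_all(*args, **kwargs):
    print('matplotlib renders all stale figures')

def patched_draw_all(*args, **kwargs):
    global _guard
    outermost = not _guard
    if outermost:
        _guard = True
    ret = _original_draw_all(*args, **kwargs)
    if outermost:
        # capture/report figures exactly once, even if reporting
        # happens to trigger another draw_all() call
        print('report the current figures to the task')
        _guard = False
    return ret

patched_draw_all()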

View File

@@ -27,12 +27,13 @@ from .errors import UsageError
from .logger import Logger
from .model import InputModel, OutputModel, ARCHIVED_TAG
from .task_parameters import TaskParameters
from .utilities.absl_bind import PatchAbsl
from .binding.absl_bind import PatchAbsl
from .utilities.args import argparser_parseargs_called, get_argparser_last_args, \
argparser_update_currenttask
from .utilities.frameworks import PatchSummaryToEventTransformer, PatchTensorFlowEager, PatchKerasModelIO, \
from .binding.frameworks.pytorch_bind import PatchPyTorchModelIO
PatchTensorflowModelIO, PatchPyTorchModelIO
from .binding.frameworks.tensorflow_bind import PatchSummaryToEventTransformer, PatchTensorFlowEager, \
from .utilities.matplotlib_bind import PatchedMatplotlib
PatchKerasModelIO, PatchTensorflowModelIO
from .binding.matplotlib_bind import PatchedMatplotlib
from .utilities.seed import make_deterministic
NotSet = object()