Refactored binding, better support for matplotlib jupyter binding

This commit is contained in:
allegroai 2019-06-20 01:50:40 +03:00
parent ff8652f39f
commit a77b470500
7 changed files with 428 additions and 265 deletions

View File

@ -1,5 +1,5 @@
""" absl-py FLAGS binding utility functions """
from trains.backend_interface.task.args import _Arguments
from ..backend_interface.task.args import _Arguments
from ..config import running_remotely

View File

@ -0,0 +1,178 @@
import threading
import weakref
from logging import getLogger
import six
from pathlib2 import Path
from ...config import running_remotely
from ...model import InputModel, OutputModel
TrainsFrameworkAdapter = 'TrainsFrameworkAdapter'
_recursion_guard = {}
def _patched_call(original_fn, patched_fn):
def _inner_patch(*args, **kwargs):
ident = threading._get_ident() if six.PY2 else threading.get_ident()
if ident in _recursion_guard:
return original_fn(*args, **kwargs)
_recursion_guard[ident] = 1
ret = None
try:
ret = patched_fn(original_fn, *args, **kwargs)
except Exception as ex:
raise ex
finally:
try:
_recursion_guard.pop(ident)
except KeyError:
pass
return ret
return _inner_patch
class _Empty(object):
def __init__(self):
self.trains_in_model = None
class WeightsFileHandler(object):
_model_out_store_lookup = {}
_model_in_store_lookup = {}
_model_store_lookup_lock = threading.Lock()
@staticmethod
def restore_weights_file(model, filepath, framework, task):
if task is None:
return filepath
if not filepath:
getLogger(TrainsFrameworkAdapter).warning("Could retrieve model location, model not restored")
return filepath
try:
WeightsFileHandler._model_store_lookup_lock.acquire()
# check if object already has InputModel
trains_in_model, ref_model = WeightsFileHandler._model_in_store_lookup.get(id(model), (None, None))
if ref_model is not None and model != ref_model():
# old id pop it - it was probably reused because the object is dead
WeightsFileHandler._model_in_store_lookup.pop(id(model))
trains_in_model, ref_model = None, None
# check if object already has InputModel
model_name_id = getattr(model, 'name', '')
# noinspection PyBroadException
try:
config_text = None
config_dict = trains_in_model.config_dict if trains_in_model else None
except Exception:
config_dict = None
# noinspection PyBroadException
try:
config_text = trains_in_model.config_text if trains_in_model else None
except Exception:
config_text = None
trains_in_model = InputModel.import_model(
weights_url=filepath,
config_dict=config_dict,
config_text=config_text,
name=task.name + ' ' + model_name_id,
label_enumeration=task.get_labels_enumeration(),
framework=framework,
create_as_published=False,
)
# noinspection PyBroadException
try:
ref_model = weakref.ref(model)
except Exception:
ref_model = None
WeightsFileHandler._model_in_store_lookup[id(model)] = (trains_in_model, ref_model)
# todo: support multiple models for the same task
task.connect(trains_in_model)
# if we are running remotely we should deserialize the object
# because someone might have changed the config_dict
if running_remotely():
# reload the model
model_config = trains_in_model.config_dict
# verify that this is the same model so we are not deserializing a diff model
if (config_dict and config_dict.get('config') and model_config and model_config.get('config') and
config_dict.get('config').get('name') == model_config.get('config').get('name')) or \
(not config_dict and not model_config):
filepath = trains_in_model.get_weights()
# update filepath to point to downloaded weights file
# actual model weights loading will be done outside the try/exception block
except Exception as ex:
getLogger(TrainsFrameworkAdapter).warning(str(ex))
finally:
WeightsFileHandler._model_store_lookup_lock.release()
return filepath
@staticmethod
def create_output_model(model, saved_path, framework, task, singlefile=False, model_name=None):
if task is None:
return saved_path
try:
WeightsFileHandler._model_store_lookup_lock.acquire()
# check if object already has InputModel
trains_out_model, ref_model = WeightsFileHandler._model_out_store_lookup.get(id(model), (None, None))
if ref_model is not None and model != ref_model():
# old id pop it - it was probably reused because the object is dead
WeightsFileHandler._model_out_store_lookup.pop(id(model))
trains_out_model, ref_model = None, None
# check if object already has InputModel
if trains_out_model is None:
trains_out_model = OutputModel(
task=task,
# config_dict=config,
name=(task.name + ' - ' + model_name) if model_name else None,
label_enumeration=task.get_labels_enumeration(),
framework=framework, )
# noinspection PyBroadException
try:
ref_model = weakref.ref(model)
except Exception:
ref_model = None
WeightsFileHandler._model_out_store_lookup[id(model)] = (trains_out_model, ref_model)
if not saved_path:
getLogger(TrainsFrameworkAdapter).warning("Could retrieve model location, stored as unknown ")
return saved_path
# check if we have output storage, and generate list of files to upload
if trains_out_model.upload_storage_uri:
if Path(saved_path).is_dir():
files = [str(f) for f in Path(saved_path).rglob('*') if f.is_file()]
elif singlefile:
files = [str(Path(saved_path).absolute())]
else:
files = [str(f) for f in Path(saved_path).parent.glob(str(Path(saved_path).name) + '.*')]
else:
files = None
# upload files if we found them, or just register the original path
if files:
if len(files) > 1:
# noinspection PyBroadException
try:
target_filename = Path(saved_path).stem
except Exception:
target_filename = None
trains_out_model.update_weights_package(weights_filenames=files, auto_delete_file=False,
target_filename=target_filename)
else:
trains_out_model.update_weights(weights_filename=files[0], auto_delete_file=False)
else:
trains_out_model.update_weights(weights_filename=None, register_uri=saved_path)
except Exception as ex:
getLogger(TrainsFrameworkAdapter).warning(str(ex))
finally:
WeightsFileHandler._model_store_lookup_lock.release()
return saved_path

View File

@ -0,0 +1,101 @@
import sys
import six
from pathlib2 import Path
from ..frameworks import _patched_call, WeightsFileHandler, _Empty
from ..import_bind import PostImportHookPatching
from ...config import running_remotely
from ...model import Framework
class PatchPyTorchModelIO(object):
__main_task = None
__patched = None
@staticmethod
def update_current_task(task, **_):
PatchPyTorchModelIO.__main_task = task
PatchPyTorchModelIO._patch_model_io()
PostImportHookPatching.add_on_import('torch', PatchPyTorchModelIO._patch_model_io)
@staticmethod
def _patch_model_io():
if PatchPyTorchModelIO.__patched:
return
if 'torch' not in sys.modules:
return
PatchPyTorchModelIO.__patched = True
# noinspection PyBroadException
try:
# hack: make sure tensorflow.__init__ is called
import torch
torch.save = _patched_call(torch.save, PatchPyTorchModelIO._save)
torch.load = _patched_call(torch.load, PatchPyTorchModelIO._load)
except ImportError:
pass
except Exception:
pass # print('Failed patching pytorch')
@staticmethod
def _save(original_fn, obj, f, *args, **kwargs):
ret = original_fn(obj, f, *args, **kwargs)
if not PatchPyTorchModelIO.__main_task:
return ret
if isinstance(f, six.string_types):
filename = f
elif hasattr(f, 'name'):
filename = f.name
# noinspection PyBroadException
try:
f.flush()
except Exception:
pass
else:
filename = None
# give the model a descriptive name based on the file name
# noinspection PyBroadException
try:
model_name = Path(filename).stem
except Exception:
model_name = None
WeightsFileHandler.create_output_model(obj, filename, Framework.pytorch, PatchPyTorchModelIO.__main_task,
singlefile=True, model_name=model_name)
return ret
@staticmethod
def _load(original_fn, f, *args, **kwargs):
if isinstance(f, six.string_types):
filename = f
elif hasattr(f, 'name'):
filename = f.name
else:
filename = None
if not PatchPyTorchModelIO.__main_task:
return original_fn(f, *args, **kwargs)
# register input model
empty = _Empty()
if running_remotely():
filename = WeightsFileHandler.restore_weights_file(empty, filename, Framework.pytorch,
PatchPyTorchModelIO.__main_task)
model = original_fn(filename or f, *args, **kwargs)
else:
# try to load model before registering, in case we fail
model = original_fn(filename or f, *args, **kwargs)
WeightsFileHandler.restore_weights_file(empty, filename, Framework.pytorch,
PatchPyTorchModelIO.__main_task)
if empty.trains_in_model:
# noinspection PyBroadException
try:
model.trains_in_model = empty.trains_in_model
except Exception:
pass
return model

View File

@ -1,256 +1,25 @@
import base64
import sys
import threading
import weakref
from collections import defaultdict
from logging import ERROR, WARNING, getLogger
from pathlib2 import Path
from typing import Any
import cv2
import numpy as np
import six
from pathlib2 import Path
from ..config import running_remotely
from ..model import InputModel, OutputModel, Framework
from ..frameworks import _patched_call, WeightsFileHandler, _Empty, TrainsFrameworkAdapter
from ..import_bind import PostImportHookPatching
from ...config import running_remotely
from ...model import InputModel, OutputModel, Framework
try:
from google.protobuf.json_format import MessageToDict
except ImportError:
MessageToDict = None
if six.PY2:
# python2.x
import __builtin__ as builtins
else:
# python3.x
import builtins
TrainsFrameworkAdapter = 'TrainsFrameworkAdapter'
_recursion_guard = {}
class _Empty(object):
def __init__(self):
self.trains_in_model = None
class PostImportHookPatching(object):
_patched = False
_post_import_hooks = defaultdict(list)
@staticmethod
def _init_hook():
if PostImportHookPatching._patched:
return
PostImportHookPatching._patched = True
if six.PY2:
# python2.x
builtins.__org_import__ = builtins.__import__
builtins.__import__ = PostImportHookPatching._patched_import2
else:
# python3.x
builtins.__org_import__ = builtins.__import__
builtins.__import__ = PostImportHookPatching._patched_import3
@staticmethod
def _patched_import2(name, globals={}, locals={}, fromlist=[], level=-1):
already_imported = name in sys.modules
mod = builtins.__org_import__(
name,
globals=globals,
locals=locals,
fromlist=fromlist,
level=level)
if not already_imported and name in PostImportHookPatching._post_import_hooks:
for hook in PostImportHookPatching._post_import_hooks[name]:
hook()
return mod
@staticmethod
def _patched_import3(name, globals=None, locals=None, fromlist=(), level=0):
already_imported = name in sys.modules
mod = builtins.__org_import__(
name,
globals=globals,
locals=locals,
fromlist=fromlist,
level=level)
if not already_imported and name in PostImportHookPatching._post_import_hooks:
for hook in PostImportHookPatching._post_import_hooks[name]:
hook()
return mod
@staticmethod
def add_on_import(name, func):
PostImportHookPatching._init_hook()
if not name in PostImportHookPatching._post_import_hooks or \
func not in PostImportHookPatching._post_import_hooks[name]:
PostImportHookPatching._post_import_hooks[name].append(func)
@staticmethod
def remove_on_import(name, func):
if name in PostImportHookPatching._post_import_hooks and func in PostImportHookPatching._post_import_hooks[name]:
PostImportHookPatching._post_import_hooks[name].remove(func)
def _patched_call(original_fn, patched_fn):
def _inner_patch(*args, **kwargs):
ident = threading._get_ident() if six.PY2 else threading.get_ident()
if ident in _recursion_guard:
return original_fn(*args, **kwargs)
_recursion_guard[ident] = 1
ret = None
try:
ret = patched_fn(original_fn, *args, **kwargs)
except Exception as ex:
raise ex
finally:
try:
_recursion_guard.pop(ident)
except KeyError:
pass
return ret
return _inner_patch
class WeightsFileHandler(object):
_model_out_store_lookup = {}
_model_in_store_lookup = {}
_model_store_lookup_lock = threading.Lock()
@staticmethod
def restore_weights_file(model, filepath, framework, task):
if task is None:
return filepath
if not filepath:
getLogger(TrainsFrameworkAdapter).warning("Could retrieve model location, model not restored")
return filepath
try:
WeightsFileHandler._model_store_lookup_lock.acquire()
# check if object already has InputModel
trains_in_model, ref_model = WeightsFileHandler._model_in_store_lookup.get(id(model), (None, None))
if ref_model is not None and model != ref_model():
# old id pop it - it was probably reused because the object is dead
WeightsFileHandler._model_in_store_lookup.pop(id(model))
trains_in_model, ref_model = None, None
# check if object already has InputModel
model_name_id = getattr(model, 'name', '')
try:
config_text = None
config_dict = trains_in_model.config_dict if trains_in_model else None
except Exception:
config_dict = None
try:
config_text = trains_in_model.config_text if trains_in_model else None
except Exception:
config_text = None
trains_in_model = InputModel.import_model(
weights_url=filepath,
config_dict=config_dict,
config_text=config_text,
name=task.name + ' ' + model_name_id,
label_enumeration=task.get_labels_enumeration(),
framework=framework,
create_as_published=False,
)
try:
ref_model = weakref.ref(model)
except Exception:
ref_model = None
WeightsFileHandler._model_in_store_lookup[id(model)] = (trains_in_model, ref_model)
# todo: support multiple models for the same task
task.connect(trains_in_model)
# if we are running remotely we should deserialize the object
# because someone might have changed the config_dict
if running_remotely():
# reload the model
model_config = trains_in_model.config_dict
# verify that this is the same model so we are not deserializing a diff model
if (config_dict and config_dict.get('config') and model_config and model_config.get('config') and
config_dict.get('config').get('name') == model_config.get('config').get('name')) or \
(not config_dict and not model_config):
filepath = trains_in_model.get_weights()
# update filepath to point to downloaded weights file
# actual model weights loading will be done outside the try/exception block
except Exception as ex:
getLogger(TrainsFrameworkAdapter).warning(str(ex))
finally:
WeightsFileHandler._model_store_lookup_lock.release()
return filepath
@staticmethod
def create_output_model(model, saved_path, framework, task, singlefile=False, model_name=None):
if task is None:
return saved_path
try:
WeightsFileHandler._model_store_lookup_lock.acquire()
# check if object already has InputModel
trains_out_model, ref_model = WeightsFileHandler._model_out_store_lookup.get(id(model), (None, None))
if ref_model is not None and model != ref_model():
# old id pop it - it was probably reused because the object is dead
WeightsFileHandler._model_out_store_lookup.pop(id(model))
trains_out_model, ref_model = None, None
# check if object already has InputModel
if trains_out_model is None:
trains_out_model = OutputModel(
task=task,
# config_dict=config,
name=(task.name + ' - ' + model_name) if model_name else None,
label_enumeration=task.get_labels_enumeration(),
framework=framework,)
try:
ref_model = weakref.ref(model)
except Exception:
ref_model = None
WeightsFileHandler._model_out_store_lookup[id(model)] = (trains_out_model, ref_model)
if not saved_path:
getLogger(TrainsFrameworkAdapter).warning("Could retrieve model location, stored as unknown ")
return saved_path
# check if we have output storage, and generate list of files to upload
if trains_out_model.upload_storage_uri:
if Path(saved_path).is_dir():
files = [str(f) for f in Path(saved_path).rglob('*') if f.is_file()]
elif singlefile:
files = [str(Path(saved_path).absolute())]
else:
files = [str(f) for f in Path(saved_path).parent.glob(str(Path(saved_path).name)+'.*')]
else:
files = None
# upload files if we found them, or just register the original path
if files:
if len(files) > 1:
try:
target_filename = Path(saved_path).stem
except Exception:
target_filename = None
trains_out_model.update_weights_package(weights_filenames=files, auto_delete_file=False,
target_filename=target_filename)
else:
trains_out_model.update_weights(weights_filename=files[0], auto_delete_file=False)
else:
trains_out_model.update_weights(weights_filename=None, register_uri=saved_path)
except Exception as ex:
getLogger(TrainsFrameworkAdapter).warning(str(ex))
finally:
WeightsFileHandler._model_store_lookup_lock.release()
return saved_path
class EventTrainsWriter(object):
"""
@ -271,7 +40,7 @@ class EventTrainsWriter(object):
def tag_splitter(tag, num_split_parts, split_char='/', join_char='_', default_title='variant'):
"""
Split a tf.summary tag line to variant and metric.
Variant is the first part of the splitted tag, metric is the second.
Variant is the first part of the split tag, metric is the second.
:param str tag:
:param int num_split_parts:
:param str split_char: a character to split the tag on
@ -313,6 +82,7 @@ class EventTrainsWriter(object):
self._max_step = 0
def _decode_image(self, img_str, width, height, color_channels):
# noinspection PyBroadException
try:
image_string = np.asarray(bytearray(base64.b64decode(img_str)), dtype=np.uint8)
image = cv2.imdecode(image_string, cv2.IMREAD_COLOR)
@ -345,7 +115,7 @@ class EventTrainsWriter(object):
title, series = self.tag_splitter(tag, num_split_parts=3, default_title='Images')
if img_data_np.dtype != np.uint8:
# assume scale 0-1
img_data_np = (img_data_np*255).astype(np.uint8)
img_data_np = (img_data_np * 255).astype(np.uint8)
# if 3d, pack into one big image
if img_data_np.ndim == 4:
@ -433,7 +203,7 @@ class EventTrainsWriter(object):
hist_list, hist_iters, minmax = self._hist_report_cache.get((title, series), ([], np.array([]), None))
# resample data so we are always constrained in number of histogram we keep
if hist_iters.size >= self.histogram_granularity**2:
if hist_iters.size >= self.histogram_granularity ** 2:
idx = _sample_histograms(hist_iters, self.histogram_granularity)
hist_iters = hist_iters[idx]
hist_list = [hist_list[i] for i in idx]
@ -464,7 +234,7 @@ class EventTrainsWriter(object):
# resample histograms on a unified bin axis
_minmax = minmax[0] - 1, minmax[1] + 1
prev_xedge = np.arange(start=_minmax[0],
step=(_minmax[1]-_minmax[0])/(self._hist_x_granularity-2), stop=_minmax[1])
step=(_minmax[1] - _minmax[0]) / (self._hist_x_granularity - 2), stop=_minmax[1])
# uniformly select histograms and the last one
cur_idx = _sample_histograms(hist_iters, self.histogram_granularity)
report_hist = np.zeros(shape=(len(cur_idx), prev_xedge.size), dtype=np.float32)
@ -495,6 +265,7 @@ class EventTrainsWriter(object):
camera=(-0.1, +1.3, 1.4))
def _add_plot(self, tag, step, values, vdict):
# noinspection PyBroadException
try:
plot_values = np.frombuffer(base64.b64decode(values['tensorContent'].encode('utf-8')),
dtype=np.float32)
@ -506,7 +277,7 @@ class EventTrainsWriter(object):
vdict['metadata']['pluginData']['pluginName'])]
else:
# this should not happen, maybe it's another run, let increase the value
self._series_name_lookup[tag] += [(tag+'_%d' % len(self._series_name_lookup[tag])+1,
self._series_name_lookup[tag] += [(tag + '_%d' % len(self._series_name_lookup[tag]) + 1,
vdict['metadata']['displayName'],
vdict['metadata']['pluginData']['pluginName'])]
@ -749,7 +520,8 @@ class PatchSummaryToEventTransformer(object):
# only patch once
if PatchSummaryToEventTransformer.__original_getattributeX is None:
from tensorboardX.writer import SummaryToEventTransformer as SummaryToEventTransformerX
PatchSummaryToEventTransformer.__original_getattributeX = SummaryToEventTransformerX.__getattribute__
PatchSummaryToEventTransformer.__original_getattributeX = \
SummaryToEventTransformerX.__getattribute__
SummaryToEventTransformerX.__getattribute__ = PatchSummaryToEventTransformer._patched_getattributeX
setattr(SummaryToEventTransformerX, 'trains',
property(PatchSummaryToEventTransformer.trains_object))
@ -779,7 +551,8 @@ class PatchSummaryToEventTransformer(object):
return PatchSummaryToEventTransformer._original_add_eventT(self, *args, **kwargs)
if not self.trains:
self.trains = EventTrainsWriter(PatchSummaryToEventTransformer.__main_task.get_logger(),
**PatchSummaryToEventTransformer.defaults_dict)
**PatchSummaryToEventTransformer.defaults_dict)
# noinspection PyBroadException
try:
self.trains.add_event(*args, **kwargs)
except Exception:
@ -792,7 +565,8 @@ class PatchSummaryToEventTransformer(object):
return PatchSummaryToEventTransformer._original_add_eventX(self, *args, **kwargs)
if not self.trains:
self.trains = EventTrainsWriter(PatchSummaryToEventTransformer.__main_task.get_logger(),
**PatchSummaryToEventTransformer.defaults_dict)
**PatchSummaryToEventTransformer.defaults_dict)
# noinspection PyBroadException
try:
self.trains.add_event(*args, **kwargs)
except Exception:
@ -1077,7 +851,7 @@ class PatchKerasModelIO(object):
PatchKerasModelIO.__patched_keras = [
Network if PatchKerasModelIO.__patched_tensorflow[0] != Network else None,
Sequential if PatchKerasModelIO.__patched_tensorflow[1] != Sequential else None,
keras_saving if PatchKerasModelIO.__patched_tensorflow[2] != keras_saving else None,]
keras_saving if PatchKerasModelIO.__patched_tensorflow[2] != keras_saving else None, ]
else:
PatchKerasModelIO.__patched_keras = [Network, Sequential, keras_saving]
PatchKerasModelIO._patch_io_calls(*PatchKerasModelIO.__patched_keras)
@ -1106,7 +880,7 @@ class PatchKerasModelIO(object):
PatchKerasModelIO.__patched_tensorflow = [
Network if PatchKerasModelIO.__patched_keras[0] != Network else None,
Sequential if PatchKerasModelIO.__patched_keras[1] != Sequential else None,
keras_saving if PatchKerasModelIO.__patched_keras[2] != keras_saving else None,]
keras_saving if PatchKerasModelIO.__patched_keras[2] != keras_saving else None, ]
else:
PatchKerasModelIO.__patched_tensorflow = [Network, Sequential, keras_saving]
PatchKerasModelIO._patch_io_calls(*PatchKerasModelIO.__patched_tensorflow)
@ -1313,6 +1087,7 @@ class PatchKerasModelIO(object):
WeightsFileHandler.restore_weights_file(empty, filepath, Framework.keras, PatchKerasModelIO.__main_task)
# update the input model object
if empty.trains_in_model:
# noinspection PyBroadException
try:
model.trains_in_model = empty.trains_in_model
except Exception:
@ -1340,15 +1115,17 @@ class PatchTensorflowModelIO(object):
return
PatchTensorflowModelIO.__patched = True
# noinspection PyBroadException
try:
# hack: make sure tensorflow.__init__ is called
import tensorflow
from tensorflow.python.training.saver import Saver
# noinspection PyBroadException
try:
Saver.save = _patched_call(Saver.save, PatchTensorflowModelIO._save)
except Exception:
pass
# noinspection PyBroadException
try:
Saver.restore = _patched_call(Saver.restore, PatchTensorflowModelIO._restore)
except Exception:
@ -1358,6 +1135,7 @@ class PatchTensorflowModelIO(object):
except Exception:
pass # print('Failed patching tensorflow')
# noinspection PyBroadException
try:
# make sure we import the correct version of save
import tensorflow
@ -1365,6 +1143,7 @@ class PatchTensorflowModelIO(object):
# actual import
import tensorflow.saved_model.experimental as saved_model
except ImportError:
# noinspection PyBroadException
try:
# make sure we import the correct version of save
import tensorflow
@ -1383,6 +1162,7 @@ class PatchTensorflowModelIO(object):
if saved_model is not None:
saved_model.save = _patched_call(saved_model.save, PatchTensorflowModelIO._save_model)
# noinspection PyBroadException
try:
# make sure we import the correct version of save
import tensorflow
@ -1395,6 +1175,7 @@ class PatchTensorflowModelIO(object):
except Exception:
pass # print('Failed patching tensorflow')
# noinspection PyBroadException
try:
# make sure we import the correct version of save
import tensorflow
@ -1406,6 +1187,7 @@ class PatchTensorflowModelIO(object):
except Exception:
pass # print('Failed patching tensorflow')
# noinspection PyBroadException
try:
# make sure we import the correct version of save
import tensorflow
@ -1417,17 +1199,21 @@ class PatchTensorflowModelIO(object):
except Exception:
pass # print('Failed patching tensorflow')
# noinspection PyBroadException
try:
import tensorflow
from tensorflow.train import Checkpoint
# noinspection PyBroadException
try:
Checkpoint.save = _patched_call(Checkpoint.save, PatchTensorflowModelIO._ckpt_save)
except Exception:
pass
# noinspection PyBroadException
try:
Checkpoint.restore = _patched_call(Checkpoint.restore, PatchTensorflowModelIO._ckpt_restore)
except Exception:
pass
# noinspection PyBroadException
try:
Checkpoint.write = _patched_call(Checkpoint.write, PatchTensorflowModelIO._ckpt_write)
except Exception:
@ -1447,8 +1233,8 @@ class PatchTensorflowModelIO(object):
PatchTensorflowModelIO.__main_task)
@staticmethod
def _save_model(original_fn, obj, export_dir, *args, **kwargs):
original_fn(obj, export_dir, *args, **kwargs)
def _save_model(original_fn, obj, export_dir, *args, **kwargs):
original_fn(obj, export_dir, *args, **kwargs)
# store output Model
WeightsFileHandler.create_output_model(obj, export_dir, Framework.tensorflow,
PatchTensorflowModelIO.__main_task)
@ -1490,6 +1276,7 @@ class PatchTensorflowModelIO(object):
PatchTensorflowModelIO.__main_task)
if empty.trains_in_model:
# noinspection PyBroadException
try:
model.trains_in_model = empty.trains_in_model
except Exception:
@ -1532,6 +1319,7 @@ class PatchTensorflowModelIO(object):
PatchTensorflowModelIO.__main_task)
if empty.trains_in_model:
# noinspection PyBroadException
try:
model.trains_in_model = empty.trains_in_model
except Exception:
@ -1558,7 +1346,7 @@ class PatchPyTorchModelIO(object):
return
PatchPyTorchModelIO.__patched = True
# noinspection PyBroadException
try:
# hack: make sure tensorflow.__init__ is called
import torch
@ -1579,6 +1367,7 @@ class PatchPyTorchModelIO(object):
filename = f
elif hasattr(f, 'name'):
filename = f.name
# noinspection PyBroadException
try:
f.flush()
except Exception:
@ -1586,7 +1375,8 @@ class PatchPyTorchModelIO(object):
else:
filename = None
# if the model a screptive name based on the file name
# give the model a descriptive name based on the file name
# noinspection PyBroadException
try:
model_name = Path(filename).stem
except Exception:
@ -1620,6 +1410,7 @@ class PatchPyTorchModelIO(object):
PatchPyTorchModelIO.__main_task)
if empty.trains_in_model:
# noinspection PyBroadException
try:
model.trains_in_model = empty.trains_in_model
except Exception:

View File

@ -0,0 +1,73 @@
import sys
from collections import defaultdict
import six
if six.PY2:
# python2.x
import __builtin__ as builtins
else:
# python3.x
import builtins
class PostImportHookPatching(object):
_patched = False
_post_import_hooks = defaultdict(list)
@staticmethod
def _init_hook():
if PostImportHookPatching._patched:
return
PostImportHookPatching._patched = True
if six.PY2:
# python2.x
builtins.__org_import__ = builtins.__import__
builtins.__import__ = PostImportHookPatching._patched_import2
else:
# python3.x
builtins.__org_import__ = builtins.__import__
builtins.__import__ = PostImportHookPatching._patched_import3
@staticmethod
def _patched_import2(name, globals={}, locals={}, fromlist=[], level=-1):
already_imported = name in sys.modules
mod = builtins.__org_import__(
name,
globals=globals,
locals=locals,
fromlist=fromlist,
level=level)
if not already_imported and name in PostImportHookPatching._post_import_hooks:
for hook in PostImportHookPatching._post_import_hooks[name]:
hook()
return mod
@staticmethod
def _patched_import3(name, globals=None, locals=None, fromlist=(), level=0):
already_imported = name in sys.modules
mod = builtins.__org_import__(
name,
globals=globals,
locals=locals,
fromlist=fromlist,
level=level)
if not already_imported and name in PostImportHookPatching._post_import_hooks:
for hook in PostImportHookPatching._post_import_hooks[name]:
hook()
return mod
@staticmethod
def add_on_import(name, func):
PostImportHookPatching._init_hook()
if not name in PostImportHookPatching._post_import_hooks or \
func not in PostImportHookPatching._post_import_hooks[name]:
PostImportHookPatching._post_import_hooks[name].append(func)
@staticmethod
def remove_on_import(name, func):
if name in PostImportHookPatching._post_import_hooks and func in PostImportHookPatching._post_import_hooks[name]:
PostImportHookPatching._post_import_hooks[name].remove(func)

View File

@ -12,6 +12,8 @@ from ..config import running_remotely
class PatchedMatplotlib:
_patched_original_plot = None
__patched_original_imshow = None
__patched_original_draw_all = None
__patched_draw_all_recursion_guard = False
_global_plot_counter = -1
_global_image_counter = -1
_current_task = None
@ -45,18 +47,18 @@ class PatchedMatplotlib:
if running_remotely():
# disable GUI backend - make headless
sys.modules['matplotlib'].rcParams['backend'] = 'agg'
matplotlib.rcParams['backend'] = 'agg'
import matplotlib.pyplot
sys.modules['matplotlib'].pyplot.switch_backend('agg')
matplotlib.pyplot.switch_backend('agg')
import matplotlib.pyplot as plt
from matplotlib import _pylab_helpers
if six.PY2:
PatchedMatplotlib._patched_original_plot = staticmethod(sys.modules['matplotlib'].pyplot.show)
PatchedMatplotlib._patched_original_imshow = staticmethod(sys.modules['matplotlib'].pyplot.imshow)
PatchedMatplotlib._patched_original_plot = staticmethod(plt.show)
PatchedMatplotlib._patched_original_imshow = staticmethod(plt.imshow)
else:
PatchedMatplotlib._patched_original_plot = sys.modules['matplotlib'].pyplot.show
PatchedMatplotlib._patched_original_imshow = sys.modules['matplotlib'].pyplot.imshow
sys.modules['matplotlib'].pyplot.show = PatchedMatplotlib.patched_show
PatchedMatplotlib._patched_original_plot = plt.show
PatchedMatplotlib._patched_original_imshow = plt.imshow
plt.show = PatchedMatplotlib.patched_show
# sys.modules['matplotlib'].pyplot.imshow = PatchedMatplotlib.patched_imshow
# patch plotly so we know it failed us.
from plotly.matplotlylib import renderer
@ -71,7 +73,11 @@ class PatchedMatplotlib:
from IPython import get_ipython
ip = get_ipython()
if ip and matplotlib.is_interactive():
ip.events.register('post_execute', PatchedMatplotlib.ipython_post_execute_hook)
# instead of hooking ipython, we should hook the matplotlib
import matplotlib.pyplot as plt
PatchedMatplotlib.__patched_original_draw_all = plt.draw_all
plt.draw_all = PatchedMatplotlib.__patched_draw_all
# ip.events.register('post_execute', PatchedMatplotlib.ipython_post_execute_hook)
except Exception:
pass
@ -188,6 +194,19 @@ class PatchedMatplotlib:
return
@staticmethod
def __patched_draw_all(*args, **kwargs):
recursion_guard = PatchedMatplotlib.__patched_draw_all_recursion_guard
if not recursion_guard:
PatchedMatplotlib.__patched_draw_all_recursion_guard = True
ret = PatchedMatplotlib.__patched_original_draw_all(*args, **kwargs)
if not recursion_guard:
PatchedMatplotlib.ipython_post_execute_hook()
PatchedMatplotlib.__patched_draw_all_recursion_guard = False
return ret
@staticmethod
def ipython_post_execute_hook():
# noinspection PyBroadException

View File

@ -27,12 +27,13 @@ from .errors import UsageError
from .logger import Logger
from .model import InputModel, OutputModel, ARCHIVED_TAG
from .task_parameters import TaskParameters
from .utilities.absl_bind import PatchAbsl
from .binding.absl_bind import PatchAbsl
from .utilities.args import argparser_parseargs_called, get_argparser_last_args, \
argparser_update_currenttask
from .utilities.frameworks import PatchSummaryToEventTransformer, PatchTensorFlowEager, PatchKerasModelIO, \
PatchTensorflowModelIO, PatchPyTorchModelIO
from .utilities.matplotlib_bind import PatchedMatplotlib
from .binding.frameworks.pytorch_bind import PatchPyTorchModelIO
from .binding.frameworks.tensorflow_bind import PatchSummaryToEventTransformer, PatchTensorFlowEager, \
PatchKerasModelIO, PatchTensorflowModelIO
from .binding.matplotlib_bind import PatchedMatplotlib
from .utilities.seed import make_deterministic
NotSet = object()