mirror of
https://github.com/clearml/clearml
synced 2025-06-23 01:55:38 +00:00
Refactored binding, better support for matplotlib jupyter binding
This commit is contained in:
parent
ff8652f39f
commit
a77b470500
@ -1,5 +1,5 @@
|
||||
""" absl-py FLAGS binding utility functions """
|
||||
from trains.backend_interface.task.args import _Arguments
|
||||
from ..backend_interface.task.args import _Arguments
|
||||
from ..config import running_remotely
|
||||
|
||||
|
178
trains/binding/frameworks/__init__.py
Normal file
178
trains/binding/frameworks/__init__.py
Normal file
@ -0,0 +1,178 @@
|
||||
import threading
|
||||
import weakref
|
||||
from logging import getLogger
|
||||
|
||||
import six
|
||||
from pathlib2 import Path
|
||||
|
||||
from ...config import running_remotely
|
||||
from ...model import InputModel, OutputModel
|
||||
|
||||
TrainsFrameworkAdapter = 'TrainsFrameworkAdapter'
|
||||
_recursion_guard = {}
|
||||
|
||||
|
||||
def _patched_call(original_fn, patched_fn):
|
||||
def _inner_patch(*args, **kwargs):
|
||||
ident = threading._get_ident() if six.PY2 else threading.get_ident()
|
||||
if ident in _recursion_guard:
|
||||
return original_fn(*args, **kwargs)
|
||||
_recursion_guard[ident] = 1
|
||||
ret = None
|
||||
try:
|
||||
ret = patched_fn(original_fn, *args, **kwargs)
|
||||
except Exception as ex:
|
||||
raise ex
|
||||
finally:
|
||||
try:
|
||||
_recursion_guard.pop(ident)
|
||||
except KeyError:
|
||||
pass
|
||||
return ret
|
||||
|
||||
return _inner_patch
|
||||
|
||||
|
||||
class _Empty(object):
|
||||
def __init__(self):
|
||||
self.trains_in_model = None
|
||||
|
||||
|
||||
class WeightsFileHandler(object):
|
||||
_model_out_store_lookup = {}
|
||||
_model_in_store_lookup = {}
|
||||
_model_store_lookup_lock = threading.Lock()
|
||||
|
||||
@staticmethod
|
||||
def restore_weights_file(model, filepath, framework, task):
|
||||
if task is None:
|
||||
return filepath
|
||||
|
||||
if not filepath:
|
||||
getLogger(TrainsFrameworkAdapter).warning("Could retrieve model location, model not restored")
|
||||
return filepath
|
||||
|
||||
try:
|
||||
WeightsFileHandler._model_store_lookup_lock.acquire()
|
||||
|
||||
# check if object already has InputModel
|
||||
trains_in_model, ref_model = WeightsFileHandler._model_in_store_lookup.get(id(model), (None, None))
|
||||
if ref_model is not None and model != ref_model():
|
||||
# old id pop it - it was probably reused because the object is dead
|
||||
WeightsFileHandler._model_in_store_lookup.pop(id(model))
|
||||
trains_in_model, ref_model = None, None
|
||||
|
||||
# check if object already has InputModel
|
||||
model_name_id = getattr(model, 'name', '')
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
config_text = None
|
||||
config_dict = trains_in_model.config_dict if trains_in_model else None
|
||||
except Exception:
|
||||
config_dict = None
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
config_text = trains_in_model.config_text if trains_in_model else None
|
||||
except Exception:
|
||||
config_text = None
|
||||
trains_in_model = InputModel.import_model(
|
||||
weights_url=filepath,
|
||||
config_dict=config_dict,
|
||||
config_text=config_text,
|
||||
name=task.name + ' ' + model_name_id,
|
||||
label_enumeration=task.get_labels_enumeration(),
|
||||
framework=framework,
|
||||
create_as_published=False,
|
||||
)
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
ref_model = weakref.ref(model)
|
||||
except Exception:
|
||||
ref_model = None
|
||||
WeightsFileHandler._model_in_store_lookup[id(model)] = (trains_in_model, ref_model)
|
||||
# todo: support multiple models for the same task
|
||||
task.connect(trains_in_model)
|
||||
# if we are running remotely we should deserialize the object
|
||||
# because someone might have changed the config_dict
|
||||
if running_remotely():
|
||||
# reload the model
|
||||
model_config = trains_in_model.config_dict
|
||||
# verify that this is the same model so we are not deserializing a diff model
|
||||
if (config_dict and config_dict.get('config') and model_config and model_config.get('config') and
|
||||
config_dict.get('config').get('name') == model_config.get('config').get('name')) or \
|
||||
(not config_dict and not model_config):
|
||||
filepath = trains_in_model.get_weights()
|
||||
# update filepath to point to downloaded weights file
|
||||
# actual model weights loading will be done outside the try/exception block
|
||||
except Exception as ex:
|
||||
getLogger(TrainsFrameworkAdapter).warning(str(ex))
|
||||
finally:
|
||||
WeightsFileHandler._model_store_lookup_lock.release()
|
||||
|
||||
return filepath
|
||||
|
||||
@staticmethod
|
||||
def create_output_model(model, saved_path, framework, task, singlefile=False, model_name=None):
|
||||
if task is None:
|
||||
return saved_path
|
||||
|
||||
try:
|
||||
WeightsFileHandler._model_store_lookup_lock.acquire()
|
||||
|
||||
# check if object already has InputModel
|
||||
trains_out_model, ref_model = WeightsFileHandler._model_out_store_lookup.get(id(model), (None, None))
|
||||
if ref_model is not None and model != ref_model():
|
||||
# old id pop it - it was probably reused because the object is dead
|
||||
WeightsFileHandler._model_out_store_lookup.pop(id(model))
|
||||
trains_out_model, ref_model = None, None
|
||||
|
||||
# check if object already has InputModel
|
||||
if trains_out_model is None:
|
||||
trains_out_model = OutputModel(
|
||||
task=task,
|
||||
# config_dict=config,
|
||||
name=(task.name + ' - ' + model_name) if model_name else None,
|
||||
label_enumeration=task.get_labels_enumeration(),
|
||||
framework=framework, )
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
ref_model = weakref.ref(model)
|
||||
except Exception:
|
||||
ref_model = None
|
||||
WeightsFileHandler._model_out_store_lookup[id(model)] = (trains_out_model, ref_model)
|
||||
|
||||
if not saved_path:
|
||||
getLogger(TrainsFrameworkAdapter).warning("Could retrieve model location, stored as unknown ")
|
||||
return saved_path
|
||||
|
||||
# check if we have output storage, and generate list of files to upload
|
||||
if trains_out_model.upload_storage_uri:
|
||||
if Path(saved_path).is_dir():
|
||||
files = [str(f) for f in Path(saved_path).rglob('*') if f.is_file()]
|
||||
elif singlefile:
|
||||
files = [str(Path(saved_path).absolute())]
|
||||
else:
|
||||
files = [str(f) for f in Path(saved_path).parent.glob(str(Path(saved_path).name) + '.*')]
|
||||
else:
|
||||
files = None
|
||||
|
||||
# upload files if we found them, or just register the original path
|
||||
if files:
|
||||
if len(files) > 1:
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
target_filename = Path(saved_path).stem
|
||||
except Exception:
|
||||
target_filename = None
|
||||
trains_out_model.update_weights_package(weights_filenames=files, auto_delete_file=False,
|
||||
target_filename=target_filename)
|
||||
else:
|
||||
trains_out_model.update_weights(weights_filename=files[0], auto_delete_file=False)
|
||||
else:
|
||||
trains_out_model.update_weights(weights_filename=None, register_uri=saved_path)
|
||||
except Exception as ex:
|
||||
getLogger(TrainsFrameworkAdapter).warning(str(ex))
|
||||
finally:
|
||||
WeightsFileHandler._model_store_lookup_lock.release()
|
||||
|
||||
return saved_path
|
101
trains/binding/frameworks/pytorch_bind.py
Normal file
101
trains/binding/frameworks/pytorch_bind.py
Normal file
@ -0,0 +1,101 @@
|
||||
import sys
|
||||
|
||||
import six
|
||||
from pathlib2 import Path
|
||||
|
||||
from ..frameworks import _patched_call, WeightsFileHandler, _Empty
|
||||
from ..import_bind import PostImportHookPatching
|
||||
from ...config import running_remotely
|
||||
from ...model import Framework
|
||||
|
||||
|
||||
class PatchPyTorchModelIO(object):
|
||||
__main_task = None
|
||||
__patched = None
|
||||
|
||||
@staticmethod
|
||||
def update_current_task(task, **_):
|
||||
PatchPyTorchModelIO.__main_task = task
|
||||
PatchPyTorchModelIO._patch_model_io()
|
||||
PostImportHookPatching.add_on_import('torch', PatchPyTorchModelIO._patch_model_io)
|
||||
|
||||
@staticmethod
|
||||
def _patch_model_io():
|
||||
if PatchPyTorchModelIO.__patched:
|
||||
return
|
||||
|
||||
if 'torch' not in sys.modules:
|
||||
return
|
||||
|
||||
PatchPyTorchModelIO.__patched = True
|
||||
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
# hack: make sure tensorflow.__init__ is called
|
||||
import torch
|
||||
torch.save = _patched_call(torch.save, PatchPyTorchModelIO._save)
|
||||
torch.load = _patched_call(torch.load, PatchPyTorchModelIO._load)
|
||||
except ImportError:
|
||||
pass
|
||||
except Exception:
|
||||
pass # print('Failed patching pytorch')
|
||||
|
||||
@staticmethod
|
||||
def _save(original_fn, obj, f, *args, **kwargs):
|
||||
ret = original_fn(obj, f, *args, **kwargs)
|
||||
if not PatchPyTorchModelIO.__main_task:
|
||||
return ret
|
||||
|
||||
if isinstance(f, six.string_types):
|
||||
filename = f
|
||||
elif hasattr(f, 'name'):
|
||||
filename = f.name
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
f.flush()
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
filename = None
|
||||
|
||||
# give the model a descriptive name based on the file name
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
model_name = Path(filename).stem
|
||||
except Exception:
|
||||
model_name = None
|
||||
WeightsFileHandler.create_output_model(obj, filename, Framework.pytorch, PatchPyTorchModelIO.__main_task,
|
||||
singlefile=True, model_name=model_name)
|
||||
return ret
|
||||
|
||||
@staticmethod
|
||||
def _load(original_fn, f, *args, **kwargs):
|
||||
if isinstance(f, six.string_types):
|
||||
filename = f
|
||||
elif hasattr(f, 'name'):
|
||||
filename = f.name
|
||||
else:
|
||||
filename = None
|
||||
|
||||
if not PatchPyTorchModelIO.__main_task:
|
||||
return original_fn(f, *args, **kwargs)
|
||||
|
||||
# register input model
|
||||
empty = _Empty()
|
||||
if running_remotely():
|
||||
filename = WeightsFileHandler.restore_weights_file(empty, filename, Framework.pytorch,
|
||||
PatchPyTorchModelIO.__main_task)
|
||||
model = original_fn(filename or f, *args, **kwargs)
|
||||
else:
|
||||
# try to load model before registering, in case we fail
|
||||
model = original_fn(filename or f, *args, **kwargs)
|
||||
WeightsFileHandler.restore_weights_file(empty, filename, Framework.pytorch,
|
||||
PatchPyTorchModelIO.__main_task)
|
||||
|
||||
if empty.trains_in_model:
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
model.trains_in_model = empty.trains_in_model
|
||||
except Exception:
|
||||
pass
|
||||
return model
|
@ -1,256 +1,25 @@
|
||||
import base64
|
||||
import sys
|
||||
import threading
|
||||
import weakref
|
||||
from collections import defaultdict
|
||||
from logging import ERROR, WARNING, getLogger
|
||||
from pathlib2 import Path
|
||||
from typing import Any
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import six
|
||||
from pathlib2 import Path
|
||||
|
||||
from ..config import running_remotely
|
||||
from ..model import InputModel, OutputModel, Framework
|
||||
from ..frameworks import _patched_call, WeightsFileHandler, _Empty, TrainsFrameworkAdapter
|
||||
from ..import_bind import PostImportHookPatching
|
||||
from ...config import running_remotely
|
||||
from ...model import InputModel, OutputModel, Framework
|
||||
|
||||
try:
|
||||
from google.protobuf.json_format import MessageToDict
|
||||
except ImportError:
|
||||
MessageToDict = None
|
||||
|
||||
if six.PY2:
|
||||
# python2.x
|
||||
import __builtin__ as builtins
|
||||
else:
|
||||
# python3.x
|
||||
import builtins
|
||||
|
||||
|
||||
TrainsFrameworkAdapter = 'TrainsFrameworkAdapter'
|
||||
_recursion_guard = {}
|
||||
|
||||
|
||||
class _Empty(object):
|
||||
def __init__(self):
|
||||
self.trains_in_model = None
|
||||
|
||||
|
||||
class PostImportHookPatching(object):
|
||||
_patched = False
|
||||
_post_import_hooks = defaultdict(list)
|
||||
|
||||
@staticmethod
|
||||
def _init_hook():
|
||||
if PostImportHookPatching._patched:
|
||||
return
|
||||
PostImportHookPatching._patched = True
|
||||
|
||||
if six.PY2:
|
||||
# python2.x
|
||||
builtins.__org_import__ = builtins.__import__
|
||||
builtins.__import__ = PostImportHookPatching._patched_import2
|
||||
else:
|
||||
# python3.x
|
||||
builtins.__org_import__ = builtins.__import__
|
||||
builtins.__import__ = PostImportHookPatching._patched_import3
|
||||
|
||||
@staticmethod
|
||||
def _patched_import2(name, globals={}, locals={}, fromlist=[], level=-1):
|
||||
already_imported = name in sys.modules
|
||||
mod = builtins.__org_import__(
|
||||
name,
|
||||
globals=globals,
|
||||
locals=locals,
|
||||
fromlist=fromlist,
|
||||
level=level)
|
||||
|
||||
if not already_imported and name in PostImportHookPatching._post_import_hooks:
|
||||
for hook in PostImportHookPatching._post_import_hooks[name]:
|
||||
hook()
|
||||
return mod
|
||||
|
||||
@staticmethod
|
||||
def _patched_import3(name, globals=None, locals=None, fromlist=(), level=0):
|
||||
already_imported = name in sys.modules
|
||||
mod = builtins.__org_import__(
|
||||
name,
|
||||
globals=globals,
|
||||
locals=locals,
|
||||
fromlist=fromlist,
|
||||
level=level)
|
||||
|
||||
if not already_imported and name in PostImportHookPatching._post_import_hooks:
|
||||
for hook in PostImportHookPatching._post_import_hooks[name]:
|
||||
hook()
|
||||
return mod
|
||||
|
||||
@staticmethod
|
||||
def add_on_import(name, func):
|
||||
PostImportHookPatching._init_hook()
|
||||
if not name in PostImportHookPatching._post_import_hooks or \
|
||||
func not in PostImportHookPatching._post_import_hooks[name]:
|
||||
PostImportHookPatching._post_import_hooks[name].append(func)
|
||||
|
||||
@staticmethod
|
||||
def remove_on_import(name, func):
|
||||
if name in PostImportHookPatching._post_import_hooks and func in PostImportHookPatching._post_import_hooks[name]:
|
||||
PostImportHookPatching._post_import_hooks[name].remove(func)
|
||||
|
||||
|
||||
def _patched_call(original_fn, patched_fn):
|
||||
def _inner_patch(*args, **kwargs):
|
||||
ident = threading._get_ident() if six.PY2 else threading.get_ident()
|
||||
if ident in _recursion_guard:
|
||||
return original_fn(*args, **kwargs)
|
||||
_recursion_guard[ident] = 1
|
||||
ret = None
|
||||
try:
|
||||
ret = patched_fn(original_fn, *args, **kwargs)
|
||||
except Exception as ex:
|
||||
raise ex
|
||||
finally:
|
||||
try:
|
||||
_recursion_guard.pop(ident)
|
||||
except KeyError:
|
||||
pass
|
||||
return ret
|
||||
return _inner_patch
|
||||
|
||||
|
||||
class WeightsFileHandler(object):
|
||||
_model_out_store_lookup = {}
|
||||
_model_in_store_lookup = {}
|
||||
_model_store_lookup_lock = threading.Lock()
|
||||
|
||||
@staticmethod
|
||||
def restore_weights_file(model, filepath, framework, task):
|
||||
if task is None:
|
||||
return filepath
|
||||
|
||||
if not filepath:
|
||||
getLogger(TrainsFrameworkAdapter).warning("Could retrieve model location, model not restored")
|
||||
return filepath
|
||||
|
||||
try:
|
||||
WeightsFileHandler._model_store_lookup_lock.acquire()
|
||||
|
||||
# check if object already has InputModel
|
||||
trains_in_model, ref_model = WeightsFileHandler._model_in_store_lookup.get(id(model), (None, None))
|
||||
if ref_model is not None and model != ref_model():
|
||||
# old id pop it - it was probably reused because the object is dead
|
||||
WeightsFileHandler._model_in_store_lookup.pop(id(model))
|
||||
trains_in_model, ref_model = None, None
|
||||
|
||||
# check if object already has InputModel
|
||||
model_name_id = getattr(model, 'name', '')
|
||||
try:
|
||||
config_text = None
|
||||
config_dict = trains_in_model.config_dict if trains_in_model else None
|
||||
except Exception:
|
||||
config_dict = None
|
||||
try:
|
||||
config_text = trains_in_model.config_text if trains_in_model else None
|
||||
except Exception:
|
||||
config_text = None
|
||||
trains_in_model = InputModel.import_model(
|
||||
weights_url=filepath,
|
||||
config_dict=config_dict,
|
||||
config_text=config_text,
|
||||
name=task.name + ' ' + model_name_id,
|
||||
label_enumeration=task.get_labels_enumeration(),
|
||||
framework=framework,
|
||||
create_as_published=False,
|
||||
)
|
||||
try:
|
||||
ref_model = weakref.ref(model)
|
||||
except Exception:
|
||||
ref_model = None
|
||||
WeightsFileHandler._model_in_store_lookup[id(model)] = (trains_in_model, ref_model)
|
||||
# todo: support multiple models for the same task
|
||||
task.connect(trains_in_model)
|
||||
# if we are running remotely we should deserialize the object
|
||||
# because someone might have changed the config_dict
|
||||
if running_remotely():
|
||||
# reload the model
|
||||
model_config = trains_in_model.config_dict
|
||||
# verify that this is the same model so we are not deserializing a diff model
|
||||
if (config_dict and config_dict.get('config') and model_config and model_config.get('config') and
|
||||
config_dict.get('config').get('name') == model_config.get('config').get('name')) or \
|
||||
(not config_dict and not model_config):
|
||||
filepath = trains_in_model.get_weights()
|
||||
# update filepath to point to downloaded weights file
|
||||
# actual model weights loading will be done outside the try/exception block
|
||||
except Exception as ex:
|
||||
getLogger(TrainsFrameworkAdapter).warning(str(ex))
|
||||
finally:
|
||||
WeightsFileHandler._model_store_lookup_lock.release()
|
||||
|
||||
return filepath
|
||||
|
||||
@staticmethod
|
||||
def create_output_model(model, saved_path, framework, task, singlefile=False, model_name=None):
|
||||
if task is None:
|
||||
return saved_path
|
||||
|
||||
try:
|
||||
WeightsFileHandler._model_store_lookup_lock.acquire()
|
||||
|
||||
# check if object already has InputModel
|
||||
trains_out_model, ref_model = WeightsFileHandler._model_out_store_lookup.get(id(model), (None, None))
|
||||
if ref_model is not None and model != ref_model():
|
||||
# old id pop it - it was probably reused because the object is dead
|
||||
WeightsFileHandler._model_out_store_lookup.pop(id(model))
|
||||
trains_out_model, ref_model = None, None
|
||||
|
||||
# check if object already has InputModel
|
||||
if trains_out_model is None:
|
||||
trains_out_model = OutputModel(
|
||||
task=task,
|
||||
# config_dict=config,
|
||||
name=(task.name + ' - ' + model_name) if model_name else None,
|
||||
label_enumeration=task.get_labels_enumeration(),
|
||||
framework=framework,)
|
||||
try:
|
||||
ref_model = weakref.ref(model)
|
||||
except Exception:
|
||||
ref_model = None
|
||||
WeightsFileHandler._model_out_store_lookup[id(model)] = (trains_out_model, ref_model)
|
||||
|
||||
if not saved_path:
|
||||
getLogger(TrainsFrameworkAdapter).warning("Could retrieve model location, stored as unknown ")
|
||||
return saved_path
|
||||
|
||||
# check if we have output storage, and generate list of files to upload
|
||||
if trains_out_model.upload_storage_uri:
|
||||
if Path(saved_path).is_dir():
|
||||
files = [str(f) for f in Path(saved_path).rglob('*') if f.is_file()]
|
||||
elif singlefile:
|
||||
files = [str(Path(saved_path).absolute())]
|
||||
else:
|
||||
files = [str(f) for f in Path(saved_path).parent.glob(str(Path(saved_path).name)+'.*')]
|
||||
else:
|
||||
files = None
|
||||
|
||||
# upload files if we found them, or just register the original path
|
||||
if files:
|
||||
if len(files) > 1:
|
||||
try:
|
||||
target_filename = Path(saved_path).stem
|
||||
except Exception:
|
||||
target_filename = None
|
||||
trains_out_model.update_weights_package(weights_filenames=files, auto_delete_file=False,
|
||||
target_filename=target_filename)
|
||||
else:
|
||||
trains_out_model.update_weights(weights_filename=files[0], auto_delete_file=False)
|
||||
else:
|
||||
trains_out_model.update_weights(weights_filename=None, register_uri=saved_path)
|
||||
except Exception as ex:
|
||||
getLogger(TrainsFrameworkAdapter).warning(str(ex))
|
||||
finally:
|
||||
WeightsFileHandler._model_store_lookup_lock.release()
|
||||
|
||||
return saved_path
|
||||
|
||||
|
||||
class EventTrainsWriter(object):
|
||||
"""
|
||||
@ -271,7 +40,7 @@ class EventTrainsWriter(object):
|
||||
def tag_splitter(tag, num_split_parts, split_char='/', join_char='_', default_title='variant'):
|
||||
"""
|
||||
Split a tf.summary tag line to variant and metric.
|
||||
Variant is the first part of the splitted tag, metric is the second.
|
||||
Variant is the first part of the split tag, metric is the second.
|
||||
:param str tag:
|
||||
:param int num_split_parts:
|
||||
:param str split_char: a character to split the tag on
|
||||
@ -313,6 +82,7 @@ class EventTrainsWriter(object):
|
||||
self._max_step = 0
|
||||
|
||||
def _decode_image(self, img_str, width, height, color_channels):
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
image_string = np.asarray(bytearray(base64.b64decode(img_str)), dtype=np.uint8)
|
||||
image = cv2.imdecode(image_string, cv2.IMREAD_COLOR)
|
||||
@ -345,7 +115,7 @@ class EventTrainsWriter(object):
|
||||
title, series = self.tag_splitter(tag, num_split_parts=3, default_title='Images')
|
||||
if img_data_np.dtype != np.uint8:
|
||||
# assume scale 0-1
|
||||
img_data_np = (img_data_np*255).astype(np.uint8)
|
||||
img_data_np = (img_data_np * 255).astype(np.uint8)
|
||||
|
||||
# if 3d, pack into one big image
|
||||
if img_data_np.ndim == 4:
|
||||
@ -433,7 +203,7 @@ class EventTrainsWriter(object):
|
||||
hist_list, hist_iters, minmax = self._hist_report_cache.get((title, series), ([], np.array([]), None))
|
||||
|
||||
# resample data so we are always constrained in number of histogram we keep
|
||||
if hist_iters.size >= self.histogram_granularity**2:
|
||||
if hist_iters.size >= self.histogram_granularity ** 2:
|
||||
idx = _sample_histograms(hist_iters, self.histogram_granularity)
|
||||
hist_iters = hist_iters[idx]
|
||||
hist_list = [hist_list[i] for i in idx]
|
||||
@ -464,7 +234,7 @@ class EventTrainsWriter(object):
|
||||
# resample histograms on a unified bin axis
|
||||
_minmax = minmax[0] - 1, minmax[1] + 1
|
||||
prev_xedge = np.arange(start=_minmax[0],
|
||||
step=(_minmax[1]-_minmax[0])/(self._hist_x_granularity-2), stop=_minmax[1])
|
||||
step=(_minmax[1] - _minmax[0]) / (self._hist_x_granularity - 2), stop=_minmax[1])
|
||||
# uniformly select histograms and the last one
|
||||
cur_idx = _sample_histograms(hist_iters, self.histogram_granularity)
|
||||
report_hist = np.zeros(shape=(len(cur_idx), prev_xedge.size), dtype=np.float32)
|
||||
@ -495,6 +265,7 @@ class EventTrainsWriter(object):
|
||||
camera=(-0.1, +1.3, 1.4))
|
||||
|
||||
def _add_plot(self, tag, step, values, vdict):
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
plot_values = np.frombuffer(base64.b64decode(values['tensorContent'].encode('utf-8')),
|
||||
dtype=np.float32)
|
||||
@ -506,7 +277,7 @@ class EventTrainsWriter(object):
|
||||
vdict['metadata']['pluginData']['pluginName'])]
|
||||
else:
|
||||
# this should not happen, maybe it's another run, let increase the value
|
||||
self._series_name_lookup[tag] += [(tag+'_%d' % len(self._series_name_lookup[tag])+1,
|
||||
self._series_name_lookup[tag] += [(tag + '_%d' % len(self._series_name_lookup[tag]) + 1,
|
||||
vdict['metadata']['displayName'],
|
||||
vdict['metadata']['pluginData']['pluginName'])]
|
||||
|
||||
@ -749,7 +520,8 @@ class PatchSummaryToEventTransformer(object):
|
||||
# only patch once
|
||||
if PatchSummaryToEventTransformer.__original_getattributeX is None:
|
||||
from tensorboardX.writer import SummaryToEventTransformer as SummaryToEventTransformerX
|
||||
PatchSummaryToEventTransformer.__original_getattributeX = SummaryToEventTransformerX.__getattribute__
|
||||
PatchSummaryToEventTransformer.__original_getattributeX = \
|
||||
SummaryToEventTransformerX.__getattribute__
|
||||
SummaryToEventTransformerX.__getattribute__ = PatchSummaryToEventTransformer._patched_getattributeX
|
||||
setattr(SummaryToEventTransformerX, 'trains',
|
||||
property(PatchSummaryToEventTransformer.trains_object))
|
||||
@ -779,7 +551,8 @@ class PatchSummaryToEventTransformer(object):
|
||||
return PatchSummaryToEventTransformer._original_add_eventT(self, *args, **kwargs)
|
||||
if not self.trains:
|
||||
self.trains = EventTrainsWriter(PatchSummaryToEventTransformer.__main_task.get_logger(),
|
||||
**PatchSummaryToEventTransformer.defaults_dict)
|
||||
**PatchSummaryToEventTransformer.defaults_dict)
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
self.trains.add_event(*args, **kwargs)
|
||||
except Exception:
|
||||
@ -792,7 +565,8 @@ class PatchSummaryToEventTransformer(object):
|
||||
return PatchSummaryToEventTransformer._original_add_eventX(self, *args, **kwargs)
|
||||
if not self.trains:
|
||||
self.trains = EventTrainsWriter(PatchSummaryToEventTransformer.__main_task.get_logger(),
|
||||
**PatchSummaryToEventTransformer.defaults_dict)
|
||||
**PatchSummaryToEventTransformer.defaults_dict)
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
self.trains.add_event(*args, **kwargs)
|
||||
except Exception:
|
||||
@ -1077,7 +851,7 @@ class PatchKerasModelIO(object):
|
||||
PatchKerasModelIO.__patched_keras = [
|
||||
Network if PatchKerasModelIO.__patched_tensorflow[0] != Network else None,
|
||||
Sequential if PatchKerasModelIO.__patched_tensorflow[1] != Sequential else None,
|
||||
keras_saving if PatchKerasModelIO.__patched_tensorflow[2] != keras_saving else None,]
|
||||
keras_saving if PatchKerasModelIO.__patched_tensorflow[2] != keras_saving else None, ]
|
||||
else:
|
||||
PatchKerasModelIO.__patched_keras = [Network, Sequential, keras_saving]
|
||||
PatchKerasModelIO._patch_io_calls(*PatchKerasModelIO.__patched_keras)
|
||||
@ -1106,7 +880,7 @@ class PatchKerasModelIO(object):
|
||||
PatchKerasModelIO.__patched_tensorflow = [
|
||||
Network if PatchKerasModelIO.__patched_keras[0] != Network else None,
|
||||
Sequential if PatchKerasModelIO.__patched_keras[1] != Sequential else None,
|
||||
keras_saving if PatchKerasModelIO.__patched_keras[2] != keras_saving else None,]
|
||||
keras_saving if PatchKerasModelIO.__patched_keras[2] != keras_saving else None, ]
|
||||
else:
|
||||
PatchKerasModelIO.__patched_tensorflow = [Network, Sequential, keras_saving]
|
||||
PatchKerasModelIO._patch_io_calls(*PatchKerasModelIO.__patched_tensorflow)
|
||||
@ -1313,6 +1087,7 @@ class PatchKerasModelIO(object):
|
||||
WeightsFileHandler.restore_weights_file(empty, filepath, Framework.keras, PatchKerasModelIO.__main_task)
|
||||
# update the input model object
|
||||
if empty.trains_in_model:
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
model.trains_in_model = empty.trains_in_model
|
||||
except Exception:
|
||||
@ -1340,15 +1115,17 @@ class PatchTensorflowModelIO(object):
|
||||
return
|
||||
|
||||
PatchTensorflowModelIO.__patched = True
|
||||
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
# hack: make sure tensorflow.__init__ is called
|
||||
import tensorflow
|
||||
from tensorflow.python.training.saver import Saver
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
Saver.save = _patched_call(Saver.save, PatchTensorflowModelIO._save)
|
||||
except Exception:
|
||||
pass
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
Saver.restore = _patched_call(Saver.restore, PatchTensorflowModelIO._restore)
|
||||
except Exception:
|
||||
@ -1358,6 +1135,7 @@ class PatchTensorflowModelIO(object):
|
||||
except Exception:
|
||||
pass # print('Failed patching tensorflow')
|
||||
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
# make sure we import the correct version of save
|
||||
import tensorflow
|
||||
@ -1365,6 +1143,7 @@ class PatchTensorflowModelIO(object):
|
||||
# actual import
|
||||
import tensorflow.saved_model.experimental as saved_model
|
||||
except ImportError:
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
# make sure we import the correct version of save
|
||||
import tensorflow
|
||||
@ -1383,6 +1162,7 @@ class PatchTensorflowModelIO(object):
|
||||
if saved_model is not None:
|
||||
saved_model.save = _patched_call(saved_model.save, PatchTensorflowModelIO._save_model)
|
||||
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
# make sure we import the correct version of save
|
||||
import tensorflow
|
||||
@ -1395,6 +1175,7 @@ class PatchTensorflowModelIO(object):
|
||||
except Exception:
|
||||
pass # print('Failed patching tensorflow')
|
||||
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
# make sure we import the correct version of save
|
||||
import tensorflow
|
||||
@ -1406,6 +1187,7 @@ class PatchTensorflowModelIO(object):
|
||||
except Exception:
|
||||
pass # print('Failed patching tensorflow')
|
||||
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
# make sure we import the correct version of save
|
||||
import tensorflow
|
||||
@ -1417,17 +1199,21 @@ class PatchTensorflowModelIO(object):
|
||||
except Exception:
|
||||
pass # print('Failed patching tensorflow')
|
||||
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
import tensorflow
|
||||
from tensorflow.train import Checkpoint
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
Checkpoint.save = _patched_call(Checkpoint.save, PatchTensorflowModelIO._ckpt_save)
|
||||
except Exception:
|
||||
pass
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
Checkpoint.restore = _patched_call(Checkpoint.restore, PatchTensorflowModelIO._ckpt_restore)
|
||||
except Exception:
|
||||
pass
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
Checkpoint.write = _patched_call(Checkpoint.write, PatchTensorflowModelIO._ckpt_write)
|
||||
except Exception:
|
||||
@ -1447,8 +1233,8 @@ class PatchTensorflowModelIO(object):
|
||||
PatchTensorflowModelIO.__main_task)
|
||||
|
||||
@staticmethod
|
||||
def _save_model(original_fn, obj, export_dir, *args, **kwargs):
|
||||
original_fn(obj, export_dir, *args, **kwargs)
|
||||
def _save_model(original_fn, obj, export_dir, *args, **kwargs):
|
||||
original_fn(obj, export_dir, *args, **kwargs)
|
||||
# store output Model
|
||||
WeightsFileHandler.create_output_model(obj, export_dir, Framework.tensorflow,
|
||||
PatchTensorflowModelIO.__main_task)
|
||||
@ -1490,6 +1276,7 @@ class PatchTensorflowModelIO(object):
|
||||
PatchTensorflowModelIO.__main_task)
|
||||
|
||||
if empty.trains_in_model:
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
model.trains_in_model = empty.trains_in_model
|
||||
except Exception:
|
||||
@ -1532,6 +1319,7 @@ class PatchTensorflowModelIO(object):
|
||||
PatchTensorflowModelIO.__main_task)
|
||||
|
||||
if empty.trains_in_model:
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
model.trains_in_model = empty.trains_in_model
|
||||
except Exception:
|
||||
@ -1558,7 +1346,7 @@ class PatchPyTorchModelIO(object):
|
||||
return
|
||||
|
||||
PatchPyTorchModelIO.__patched = True
|
||||
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
# hack: make sure tensorflow.__init__ is called
|
||||
import torch
|
||||
@ -1579,6 +1367,7 @@ class PatchPyTorchModelIO(object):
|
||||
filename = f
|
||||
elif hasattr(f, 'name'):
|
||||
filename = f.name
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
f.flush()
|
||||
except Exception:
|
||||
@ -1586,7 +1375,8 @@ class PatchPyTorchModelIO(object):
|
||||
else:
|
||||
filename = None
|
||||
|
||||
# if the model a screptive name based on the file name
|
||||
# give the model a descriptive name based on the file name
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
model_name = Path(filename).stem
|
||||
except Exception:
|
||||
@ -1620,6 +1410,7 @@ class PatchPyTorchModelIO(object):
|
||||
PatchPyTorchModelIO.__main_task)
|
||||
|
||||
if empty.trains_in_model:
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
model.trains_in_model = empty.trains_in_model
|
||||
except Exception:
|
73
trains/binding/import_bind.py
Normal file
73
trains/binding/import_bind.py
Normal file
@ -0,0 +1,73 @@
|
||||
import sys
|
||||
from collections import defaultdict
|
||||
import six
|
||||
|
||||
if six.PY2:
|
||||
# python2.x
|
||||
import __builtin__ as builtins
|
||||
else:
|
||||
# python3.x
|
||||
import builtins
|
||||
|
||||
|
||||
class PostImportHookPatching(object):
|
||||
_patched = False
|
||||
_post_import_hooks = defaultdict(list)
|
||||
|
||||
@staticmethod
|
||||
def _init_hook():
|
||||
if PostImportHookPatching._patched:
|
||||
return
|
||||
PostImportHookPatching._patched = True
|
||||
|
||||
if six.PY2:
|
||||
# python2.x
|
||||
builtins.__org_import__ = builtins.__import__
|
||||
builtins.__import__ = PostImportHookPatching._patched_import2
|
||||
else:
|
||||
# python3.x
|
||||
builtins.__org_import__ = builtins.__import__
|
||||
builtins.__import__ = PostImportHookPatching._patched_import3
|
||||
|
||||
@staticmethod
|
||||
def _patched_import2(name, globals={}, locals={}, fromlist=[], level=-1):
|
||||
already_imported = name in sys.modules
|
||||
mod = builtins.__org_import__(
|
||||
name,
|
||||
globals=globals,
|
||||
locals=locals,
|
||||
fromlist=fromlist,
|
||||
level=level)
|
||||
|
||||
if not already_imported and name in PostImportHookPatching._post_import_hooks:
|
||||
for hook in PostImportHookPatching._post_import_hooks[name]:
|
||||
hook()
|
||||
return mod
|
||||
|
||||
@staticmethod
|
||||
def _patched_import3(name, globals=None, locals=None, fromlist=(), level=0):
|
||||
already_imported = name in sys.modules
|
||||
mod = builtins.__org_import__(
|
||||
name,
|
||||
globals=globals,
|
||||
locals=locals,
|
||||
fromlist=fromlist,
|
||||
level=level)
|
||||
|
||||
if not already_imported and name in PostImportHookPatching._post_import_hooks:
|
||||
for hook in PostImportHookPatching._post_import_hooks[name]:
|
||||
hook()
|
||||
return mod
|
||||
|
||||
@staticmethod
|
||||
def add_on_import(name, func):
|
||||
PostImportHookPatching._init_hook()
|
||||
if not name in PostImportHookPatching._post_import_hooks or \
|
||||
func not in PostImportHookPatching._post_import_hooks[name]:
|
||||
PostImportHookPatching._post_import_hooks[name].append(func)
|
||||
|
||||
@staticmethod
|
||||
def remove_on_import(name, func):
|
||||
if name in PostImportHookPatching._post_import_hooks and func in PostImportHookPatching._post_import_hooks[name]:
|
||||
PostImportHookPatching._post_import_hooks[name].remove(func)
|
||||
|
@ -12,6 +12,8 @@ from ..config import running_remotely
|
||||
class PatchedMatplotlib:
|
||||
_patched_original_plot = None
|
||||
__patched_original_imshow = None
|
||||
__patched_original_draw_all = None
|
||||
__patched_draw_all_recursion_guard = False
|
||||
_global_plot_counter = -1
|
||||
_global_image_counter = -1
|
||||
_current_task = None
|
||||
@ -45,18 +47,18 @@ class PatchedMatplotlib:
|
||||
|
||||
if running_remotely():
|
||||
# disable GUI backend - make headless
|
||||
sys.modules['matplotlib'].rcParams['backend'] = 'agg'
|
||||
matplotlib.rcParams['backend'] = 'agg'
|
||||
import matplotlib.pyplot
|
||||
sys.modules['matplotlib'].pyplot.switch_backend('agg')
|
||||
matplotlib.pyplot.switch_backend('agg')
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib import _pylab_helpers
|
||||
if six.PY2:
|
||||
PatchedMatplotlib._patched_original_plot = staticmethod(sys.modules['matplotlib'].pyplot.show)
|
||||
PatchedMatplotlib._patched_original_imshow = staticmethod(sys.modules['matplotlib'].pyplot.imshow)
|
||||
PatchedMatplotlib._patched_original_plot = staticmethod(plt.show)
|
||||
PatchedMatplotlib._patched_original_imshow = staticmethod(plt.imshow)
|
||||
else:
|
||||
PatchedMatplotlib._patched_original_plot = sys.modules['matplotlib'].pyplot.show
|
||||
PatchedMatplotlib._patched_original_imshow = sys.modules['matplotlib'].pyplot.imshow
|
||||
sys.modules['matplotlib'].pyplot.show = PatchedMatplotlib.patched_show
|
||||
PatchedMatplotlib._patched_original_plot = plt.show
|
||||
PatchedMatplotlib._patched_original_imshow = plt.imshow
|
||||
plt.show = PatchedMatplotlib.patched_show
|
||||
# sys.modules['matplotlib'].pyplot.imshow = PatchedMatplotlib.patched_imshow
|
||||
# patch plotly so we know it failed us.
|
||||
from plotly.matplotlylib import renderer
|
||||
@ -71,7 +73,11 @@ class PatchedMatplotlib:
|
||||
from IPython import get_ipython
|
||||
ip = get_ipython()
|
||||
if ip and matplotlib.is_interactive():
|
||||
ip.events.register('post_execute', PatchedMatplotlib.ipython_post_execute_hook)
|
||||
# instead of hooking ipython, we should hook the matplotlib
|
||||
import matplotlib.pyplot as plt
|
||||
PatchedMatplotlib.__patched_original_draw_all = plt.draw_all
|
||||
plt.draw_all = PatchedMatplotlib.__patched_draw_all
|
||||
# ip.events.register('post_execute', PatchedMatplotlib.ipython_post_execute_hook)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@ -188,6 +194,19 @@ class PatchedMatplotlib:
|
||||
|
||||
return
|
||||
|
||||
@staticmethod
|
||||
def __patched_draw_all(*args, **kwargs):
|
||||
recursion_guard = PatchedMatplotlib.__patched_draw_all_recursion_guard
|
||||
if not recursion_guard:
|
||||
PatchedMatplotlib.__patched_draw_all_recursion_guard = True
|
||||
|
||||
ret = PatchedMatplotlib.__patched_original_draw_all(*args, **kwargs)
|
||||
|
||||
if not recursion_guard:
|
||||
PatchedMatplotlib.ipython_post_execute_hook()
|
||||
PatchedMatplotlib.__patched_draw_all_recursion_guard = False
|
||||
return ret
|
||||
|
||||
@staticmethod
|
||||
def ipython_post_execute_hook():
|
||||
# noinspection PyBroadException
|
@ -27,12 +27,13 @@ from .errors import UsageError
|
||||
from .logger import Logger
|
||||
from .model import InputModel, OutputModel, ARCHIVED_TAG
|
||||
from .task_parameters import TaskParameters
|
||||
from .utilities.absl_bind import PatchAbsl
|
||||
from .binding.absl_bind import PatchAbsl
|
||||
from .utilities.args import argparser_parseargs_called, get_argparser_last_args, \
|
||||
argparser_update_currenttask
|
||||
from .utilities.frameworks import PatchSummaryToEventTransformer, PatchTensorFlowEager, PatchKerasModelIO, \
|
||||
PatchTensorflowModelIO, PatchPyTorchModelIO
|
||||
from .utilities.matplotlib_bind import PatchedMatplotlib
|
||||
from .binding.frameworks.pytorch_bind import PatchPyTorchModelIO
|
||||
from .binding.frameworks.tensorflow_bind import PatchSummaryToEventTransformer, PatchTensorFlowEager, \
|
||||
PatchKerasModelIO, PatchTensorflowModelIO
|
||||
from .binding.matplotlib_bind import PatchedMatplotlib
|
||||
from .utilities.seed import make_deterministic
|
||||
|
||||
NotSet = object()
|
||||
|
Loading…
Reference in New Issue
Block a user