# clearml/trains/utilities/frameworks.py
import base64
import sys
import threading
import weakref
from collections import defaultdict
from logging import ERROR, WARNING, getLogger
from pathlib import Path
import cv2
import numpy as np
import six
from ..config import running_remotely
from ..model import InputModel, OutputModel, Framework
try:
from google.protobuf.json_format import MessageToDict
except ImportError:
MessageToDict = None
if six.PY2:
# python2.x
import __builtin__ as builtins
else:
# python3.x
import builtins
TrainsFrameworkAdapter = 'TrainsFrameworkAdapter'
_recursion_guard = {}
class _Empty(object):
def __init__(self):
self.trains_in_model = None
class PostImportHookPatching(object):
_patched = False
_post_import_hooks = defaultdict(list)
@staticmethod
def _init_hook():
if PostImportHookPatching._patched:
return
PostImportHookPatching._patched = True
if six.PY2:
# python2.x
builtins.__org_import__ = builtins.__import__
builtins.__import__ = PostImportHookPatching._patched_import2
else:
# python3.x
builtins.__org_import__ = builtins.__import__
builtins.__import__ = PostImportHookPatching._patched_import3
@staticmethod
def _patched_import2(name, globals={}, locals={}, fromlist=[], level=-1):
already_imported = name in sys.modules
mod = builtins.__org_import__(
name,
globals=globals,
locals=locals,
fromlist=fromlist,
level=level)
if not already_imported and name in PostImportHookPatching._post_import_hooks:
for hook in PostImportHookPatching._post_import_hooks[name]:
hook()
return mod
@staticmethod
def _patched_import3(name, globals=None, locals=None, fromlist=(), level=0):
already_imported = name in sys.modules
mod = builtins.__org_import__(
name,
globals=globals,
locals=locals,
fromlist=fromlist,
level=level)
if not already_imported and name in PostImportHookPatching._post_import_hooks:
for hook in PostImportHookPatching._post_import_hooks[name]:
hook()
return mod
@staticmethod
def add_on_import(name, func):
PostImportHookPatching._init_hook()
        if name not in PostImportHookPatching._post_import_hooks or \
                func not in PostImportHookPatching._post_import_hooks[name]:
PostImportHookPatching._post_import_hooks[name].append(func)
@staticmethod
def remove_on_import(name, func):
if name in PostImportHookPatching._post_import_hooks and func in PostImportHookPatching._post_import_hooks[name]:
PostImportHookPatching._post_import_hooks[name].remove(func)
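# Illustrative sketch (not executed here): registering a post-import callback.
# The callback name below is made up for the example.
#
#   def _on_torch_import():
#       print('torch was just imported, applying patches')
#
#   PostImportHookPatching.add_on_import('torch', _on_torch_import)
#
# Note the hook only fires on the first real import; later imports are served
# from sys.modules and skip the callbacks.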
def _patched_call(original_fn, patched_fn):
    def _inner_patch(*args, **kwargs):
        # use the thread id as a re-entrancy guard: if a patched function ends up
        # calling itself (or another patched function) on the same thread,
        # fall through to the original implementation.
        # threading.get_ident() does not exist on Python 2, use the internal alias there
        ident = threading._get_ident() if six.PY2 else threading.get_ident()
        if ident in _recursion_guard:
            return original_fn(*args, **kwargs)
        _recursion_guard[ident] = 1
        try:
            ret = patched_fn(original_fn, *args, **kwargs)
        finally:
            _recursion_guard.pop(ident, None)
        return ret
    return _inner_patch
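# Illustrative sketch (not executed here) of the _patched_call convention:
# the patch receives the original callable as its first argument, so it can run
# code before and after the call. `some_module.fn` is hypothetical.
#
#   def _my_patch(original_fn, *args, **kwargs):
#       # do something before the call
#       ret = original_fn(*args, **kwargs)
#       # do something after the call
#       return ret
#
#   some_module.fn = _patched_call(some_module.fn, _my_patch)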
class WeightsFileHandler(object):
_model_out_store_lookup = {}
_model_in_store_lookup = {}
_model_store_lookup_lock = threading.Lock()
@staticmethod
def restore_weights_file(model, filepath, framework, task):
if task is None:
return filepath
if not filepath:
getLogger(TrainsFrameworkAdapter).warning("Could retrieve model location, model not restored")
return filepath
try:
WeightsFileHandler._model_store_lookup_lock.acquire()
# check if object already has InputModel
trains_in_model, ref_model = WeightsFileHandler._model_in_store_lookup.get(id(model), (None, None))
if ref_model is not None and model != ref_model():
# old id pop it - it was probably reused because the object is dead
WeightsFileHandler._model_in_store_lookup.pop(id(model))
trains_in_model, ref_model = None, None
            # extract the model name and any existing configuration
            model_name_id = getattr(model, 'name', '')
            try:
                config_dict = trains_in_model.config_dict if trains_in_model else None
            except Exception:
                config_dict = None
            try:
                config_text = trains_in_model.config_text if trains_in_model else None
            except Exception:
                config_text = None
trains_in_model = InputModel.import_model(
weights_url=filepath,
config_dict=config_dict,
config_text=config_text,
name=task.name + ' ' + model_name_id,
label_enumeration=task.get_labels_enumeration(),
framework=framework,
create_as_published=False,
)
try:
ref_model = weakref.ref(model)
except Exception:
ref_model = None
WeightsFileHandler._model_in_store_lookup[id(model)] = (trains_in_model, ref_model)
# todo: support multiple models for the same task
task.connect(trains_in_model)
# if we are running remotely we should deserialize the object
# because someone might have changed the config_dict
if running_remotely():
# reload the model
model_config = trains_in_model.config_dict
# verify that this is the same model so we are not deserializing a diff model
if (config_dict and config_dict.get('config') and model_config and model_config.get('config') and
config_dict.get('config').get('name') == model_config.get('config').get('name')) or \
(not config_dict and not model_config):
filepath = trains_in_model.get_weights()
# update filepath to point to downloaded weights file
# actual model weights loading will be done outside the try/exception block
except Exception as ex:
getLogger(TrainsFrameworkAdapter).warning(str(ex))
finally:
WeightsFileHandler._model_store_lookup_lock.release()
return filepath
@staticmethod
def create_output_model(model, saved_path, framework, task, singlefile=False, model_name=None):
if task is None:
return saved_path
try:
WeightsFileHandler._model_store_lookup_lock.acquire()
            # check if object already has an OutputModel
trains_out_model, ref_model = WeightsFileHandler._model_out_store_lookup.get(id(model), (None, None))
if ref_model is not None and model != ref_model():
# old id pop it - it was probably reused because the object is dead
WeightsFileHandler._model_out_store_lookup.pop(id(model))
trains_out_model, ref_model = None, None
            # create a new OutputModel if we do not have one yet
if trains_out_model is None:
trains_out_model = OutputModel(
task=task,
# config_dict=config,
name=(task.name + ' - ' + model_name) if model_name else None,
label_enumeration=task.get_labels_enumeration(),
framework=framework,)
try:
ref_model = weakref.ref(model)
except Exception:
ref_model = None
WeightsFileHandler._model_out_store_lookup[id(model)] = (trains_out_model, ref_model)
if not saved_path:
getLogger(TrainsFrameworkAdapter).warning("Could retrieve model location, stored as unknown ")
return saved_path
# check if we have output storage, and generate list of files to upload
if trains_out_model.upload_storage_uri:
if Path(saved_path).is_dir():
files = [str(f) for f in Path(saved_path).rglob('*') if f.is_file()]
elif singlefile:
files = [str(Path(saved_path).absolute())]
else:
files = [str(f) for f in Path(saved_path).parent.glob(str(Path(saved_path).name)+'.*')]
else:
files = None
# upload files if we found them, or just register the original path
if files:
if len(files) > 1:
try:
target_filename = Path(saved_path).stem
except Exception:
target_filename = None
trains_out_model.update_weights_package(weights_filenames=files, auto_delete_file=False,
target_filename=target_filename)
else:
trains_out_model.update_weights(weights_filename=files[0], auto_delete_file=False)
else:
trains_out_model.update_weights(weights_filename=None, register_uri=saved_path)
except Exception as ex:
getLogger(TrainsFrameworkAdapter).warning(str(ex))
finally:
WeightsFileHandler._model_store_lookup_lock.release()
return saved_path
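# Illustrative sketch (not executed here): how the framework patches below use
# this handler. The file name is made up for the example.
#
#   # on load: returns the (possibly re-routed) local weights path
#   path = WeightsFileHandler.restore_weights_file(model_obj, 'weights.h5',
#                                                  Framework.keras, task)
#   # on save: uploads/registers the stored weights as an OutputModel
#   WeightsFileHandler.create_output_model(model_obj, 'weights.h5',
#                                          Framework.keras, task)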
class EventTrainsWriter(object):
"""
    TF SummaryWriter implementation that converts TensorBoard summaries into
    Trains events, and reports those events (metrics) to a Trains task (logger).
"""
_add_lock = threading.Lock()
_series_name_lookup = {}
@property
def variants(self):
return self._variants
def prepare_report(self):
return self.variants.copy()
@staticmethod
def tag_splitter(tag, num_split_parts, split_char='/', join_char='_', default_title='variant'):
"""
        Split a tf.summary tag into a title and a series (variant).
        The leading tag parts (joined) become the title; the trailing num_split_parts parts (joined) become the series.
        :param str tag: the tag to split
        :param int num_split_parts: number of trailing tag parts to use as the series
        :param str split_char: a character to split the tag on
        :param str join_char: a character to join the splits
        :param str default_title: title to use in case no title can be inferred automatically
        :return: (str, str) title and series
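        Example (made-up tag, for illustration):
            tag_splitter('train/xent/loss', num_split_parts=1) returns ('train_xent', 'loss')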
"""
        tag_parts = tag.split(split_char)
        series = join_char.join(tag_parts[-num_split_parts:])
        title = join_char.join(tag_parts[:-num_split_parts]) or default_title
return title, series
def __init__(self, logger, report_freq=100, image_report_freq=None, histogram_update_freq_multiplier=10,
histogram_granularity=50, max_keep_images=None):
"""
        Create a Trains-compatible backend for the TensorFlow SummaryToEventTransformer.
        Everything is serialized directly to the Trains backend, instead of to the standard TF FileWriter.
        :param logger: The task.logger to use for sending the metrics (default: task.get_logger())
        :param report_freq: How often to update the statistics values
        :param image_report_freq: How often to upload images (step % image_report_freq == 0)
        :param histogram_update_freq_multiplier: How often to upload histograms
            ((step // report_freq) % histogram_update_freq_multiplier == 0)
        :param histogram_granularity: How many histograms (lines) to display in the 3d histogram plot
        :param max_keep_images: Maximum number of images to save before starting to reuse files (per title/series pair)
"""
# We are the events_writer, so that's what we'll pass
self.max_keep_images = max_keep_images
self.report_freq = report_freq
self.image_report_freq = image_report_freq if image_report_freq else report_freq
self.histogram_granularity = histogram_granularity
self.histogram_update_freq_multiplier = histogram_update_freq_multiplier
self._logger = logger
self._visualization_mode = 'BGR'
self._variants = defaultdict(lambda: ())
self._scalar_report_cache = {}
self._hist_report_cache = {}
self._hist_x_granularity = 50
self._max_step = 0
def _decode_image(self, img_str, width, height, color_channels):
try:
image_string = np.asarray(bytearray(base64.b64decode(img_str)), dtype=np.uint8)
image = cv2.imdecode(image_string, cv2.IMREAD_COLOR)
val = image.reshape(height, width, -1).astype(np.uint8)
            if val.ndim == 3 and val.shape[2] == 3:
                if self._visualization_mode == 'BGR':
                    val = val[:, :, [2, 1, 0]]
            elif (val.ndim == 2) or (val.ndim == 3 and val.shape[2] == 1):
                val = np.tile(np.atleast_3d(val), (1, 1, 3))
            elif val.ndim == 3 and val.shape[2] == 4:
                if self._visualization_mode == 'BGR':
                    val = val[:, :, [2, 1, 0]]
                else:
                    val = val[:, :, [0, 1, 2]]
except Exception:
self._logger.warning('Failed decoding debug image [%d, %d, %d]' % (width, height, color_channels))
val = None
return val
def _add_image_numpy(self, tag, step, img_data_np, max_keep_images=None):
# only report images every specific interval
if step % self.image_report_freq != 0:
return None
if img_data_np is None:
return
title, series = self.tag_splitter(tag, num_split_parts=3, default_title='Images')
if img_data_np.dtype != np.uint8:
# assume scale 0-1
img_data_np = (img_data_np*255).astype(np.uint8)
# if 3d, pack into one big image
if img_data_np.ndim == 4:
dims = img_data_np.shape
stack_dim = int(np.sqrt(dims[0]))
res = img_data_np.reshape(stack_dim, stack_dim, *dims[1:]).transpose((0, 2, 1, 3, 4))
tile_size = res.shape[0] * res.shape[1]
img_data_np = res.reshape(tile_size, tile_size, -1)
self._logger.report_image_and_upload(
title=title,
series=series,
iteration=step,
matrix=img_data_np,
max_image_history=self.max_keep_images if max_keep_images is None else max_keep_images,
)
def _add_image(self, tag, step, img_data):
# only report images every specific interval
if step % self.image_report_freq != 0:
return None
width = img_data['width']
height = img_data['height']
colorspace = img_data['colorspace']
img_str = img_data['encodedImageString']
matrix = self._decode_image(img_str, width=width, height=height, color_channels=colorspace)
if matrix is None:
return
return self._add_image_numpy(tag=tag, step=step, img_data_np=matrix)
def _add_scalar(self, tag, step, scalar_data):
title, series = self.tag_splitter(tag, num_split_parts=1, default_title='Scalars')
# update scalar cache
num, value = self._scalar_report_cache.get((title, series), (0, 0))
self._scalar_report_cache[(title, series)] = (num + 1, value + scalar_data)
        # only report scalars every specific interval
if step % self.report_freq != 0:
return None
# calculate mean and zero cache
num, value = self._scalar_report_cache.get((title, series), (0, 0))
scalar_data = value / num
self._scalar_report_cache[(title, series)] = (0, 0)
self._logger.report_scalar(
title=title,
series=series,
iteration=step,
value=scalar_data,
)
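    # Worked example of the averaging above, assuming report_freq=2:
    #   step 1: cache becomes (1, v1) -> 1 % 2 != 0, nothing reported
    #   step 2: cache becomes (2, v1+v2) -> reports (v1+v2)/2, cache reset to (0, 0)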
def _add_histogram(self, tag, step, histo_data):
def _sample_histograms(_hist_iters, _histogram_granularity):
# resample history based on distribution of samples across time (steps)
ratio = ((_hist_iters[-1] - _hist_iters[_histogram_granularity]) /
(_hist_iters[_histogram_granularity - 1] - _hist_iters[0])) if \
_hist_iters.size > _histogram_granularity else 0.
cur_idx_below = np.arange(0, min(_hist_iters.size, _histogram_granularity - 1))
np.random.shuffle(cur_idx_below)
cur_idx_below = cur_idx_below[:int(_histogram_granularity * (1.0 - ratio / (1 + ratio)) + 0.5)]
if ratio > 0.0:
cur_idx_above = np.arange(_histogram_granularity - 1, _hist_iters.size)
np.random.shuffle(cur_idx_above)
cur_idx_above = cur_idx_above[:int(_histogram_granularity * ratio / (1 + ratio))]
else:
cur_idx_above = np.array([])
            _cur_idx = np.unique(np.sort(np.concatenate((cur_idx_below, cur_idx_above)).astype(int)))
return _cur_idx
# only collect histogram every specific interval
if step % self.report_freq != 0 or step < self.report_freq - 1:
return None
# generate forward matrix of the histograms
# Y-axis (rows) is iteration (from 0 to current Step)
# X-axis averaged bins (conformed sample 'bucketLimit')
# Z-axis actual value (interpolated 'bucket')
title, series = self.tag_splitter(tag, num_split_parts=1, default_title='Histograms')
# get histograms from cache
hist_list, hist_iters, minmax = self._hist_report_cache.get((title, series), ([], np.array([]), None))
        # resample the data so we are always constrained in the number of histograms we keep
if hist_iters.size >= self.histogram_granularity**2:
idx = _sample_histograms(hist_iters, self.histogram_granularity)
hist_iters = hist_iters[idx]
hist_list = [hist_list[i] for i in idx]
        # check if the current sample is already here (actually happens sometimes)
if step in hist_iters:
return None
# add current sample, if not already here
hist_iters = np.append(hist_iters, step)
hist = np.array(list(zip(histo_data['bucketLimit'], histo_data['bucket'])), dtype=np.float32)
hist = hist[~np.isinf(hist[:, 0]), :]
hist_list.append(hist)
# keep track of min/max values of histograms (for later re-binning)
if minmax is None:
minmax = hist[:, 0].min(), hist[:, 0].max()
else:
minmax = min(minmax[0], hist[:, 0].min()), max(minmax[1], hist[:, 0].max())
# update the cache
self._hist_report_cache[(title, series)] = hist_list, hist_iters, minmax
# only report histogram every specific interval, but do report the first few, so you know there are histograms
if hist_iters.size < 1 or (hist_iters.size >= self.histogram_update_freq_multiplier and
hist_iters.size % self.histogram_update_freq_multiplier != 0):
return None
# resample histograms on a unified bin axis
_minmax = minmax[0] - 1, minmax[1] + 1
prev_xedge = np.arange(start=_minmax[0],
step=(_minmax[1]-_minmax[0])/(self._hist_x_granularity-2), stop=_minmax[1])
# uniformly select histograms and the last one
cur_idx = _sample_histograms(hist_iters, self.histogram_granularity)
report_hist = np.zeros(shape=(len(cur_idx), prev_xedge.size), dtype=np.float32)
for i, n in enumerate(cur_idx):
h = hist_list[n]
report_hist[i, :] = np.interp(prev_xedge, h[:, 0], h[:, 1], right=0, left=0)
yedges = hist_iters[cur_idx]
xedges = prev_xedge
        # if we only have a single line, add another zero line so the scatter plot can draw
if report_hist.shape[0] < 2:
report_hist = np.vstack((np.zeros_like(report_hist), report_hist))
# create 3d line (scatter) of histograms
skipx = max(1, int(xedges.size / 10))
skipy = max(1, int(yedges.size / 10))
xlabels = ['%.2f' % v if i % skipx == 0 else '' for i, v in enumerate(xedges[:-1])]
ylabels = [str(int(v)) if i % skipy == 0 else '' for i, v in enumerate(yedges)]
self._logger.report_surface(
title=title,
series=series,
iteration=0,
xtitle=' ',
ytitle='iteration',
xlabels=xlabels,
ylabels=ylabels,
matrix=report_hist,
camera=(-0.1, +1.3, 1.4))
def _add_plot(self, tag, step, values, vdict):
try:
plot_values = np.frombuffer(base64.b64decode(values['tensorContent'].encode('utf-8')),
dtype=np.float32)
plot_values = plot_values.reshape((int(values['tensorShape']['dim'][0]['size']),
int(values['tensorShape']['dim'][1]['size'])))
if 'metadata' in vdict:
if tag not in self._series_name_lookup:
self._series_name_lookup[tag] = [(tag, vdict['metadata']['displayName'],
vdict['metadata']['pluginData']['pluginName'])]
else:
                    # this should not happen; it's probably another run, so increase the counter
                    self._series_name_lookup[tag] += [(tag + '_%d' % (len(self._series_name_lookup[tag]) + 1),
                                                       vdict['metadata']['displayName'],
                                                       vdict['metadata']['pluginData']['pluginName'])]
tag, series, plugin_name = self._series_name_lookup.get(tag, [(tag, tag, '')])[-1]
if 'pr_curve' in plugin_name:
                # our thresholds are evenly distributed, in that
                #   width = 1.0 / (num_thresholds - 1)
                #   thresholds = [0.0, 1*width, 2*width, ..., 1.0]
                # np.linspace gives exactly num_thresholds values, including the 1.0
                # endpoint, matching the number of columns in plot_values
                num_thresholds = plot_values.shape[1]
                thresholds = np.linspace(0.0, 1.0, num_thresholds, dtype=plot_values.dtype)
data_points = ['TP ', 'FP ', 'TN ', 'FN ', 'Precision ', ' Recall']
series = [{'name': series, 'data': np.vstack((thresholds, plot_values[-2])).T,
'labels': [''.join(data_points) + '<br> ' +
' '.join(['%-3.2f' % v for v in plot_values[:, j]]) for j in
range(plot_values.shape[1])]}]
reverse_xaxis = True
else:
reverse_xaxis = False
series = [{'name': series, 'data': plot_values}]
self._logger.report_line_plot(title=tag, series=series, xaxis='', yaxis='',
iteration=step, reverse_xaxis=reverse_xaxis)
except Exception:
pass
def add_event(self, event, step=None, walltime=None, **kwargs):
supported_metrics = {
'simpleValue', 'image', 'histo', 'tensor'
}
def get_data(value_dict, metric_search_order):
data = None
metric_type = 'Unsupported'
for variant in metric_search_order:
data = value_dict.get(variant)
if data is not None:
metric_type = variant
break
return metric_type, data
# Support multiple threads accessing this instance (i.e. let TF/Keras do what they need)
with self._add_lock:
# TODO: add report frequency threshold (i.e. if we are sending too much data, increase the report_freq)
# we should measure reports per second and throttle back the reporting details accordingly
msg_dict = MessageToDict(event)
summary = msg_dict.get('summary')
if summary is None:
msg_dict.pop('step', None)
msg_dict.pop('wallTime', None)
keys_list = [key for key in msg_dict.keys() if len(key) > 0]
keys_list = ', '.join(keys_list)
self._logger.debug('event summary not found, message type unsupported: %s' % keys_list)
return
value_dicts = summary.get('value')
            walltime = walltime or msg_dict.get('wallTime')
step = step or msg_dict.get('step')
if step is None:
# when we start a new epoch there is no step in the msg_dict,
# we have to extract it manually
if hasattr(event, 'step'):
step = int(event.step)
else:
step = 0
                    self._logger.debug('Received event without step, assuming step = {}'.format(step), WARNING)
else:
step = int(step)
self._max_step = max(self._max_step, step)
if value_dicts is None:
self._logger.debug("Summary with arrived without 'value'", ERROR)
return
for vdict in value_dicts:
tag = vdict.pop('tag', None)
if tag is None:
# we should not get here
self._logger.debug('No tag for \'value\' existing keys %s' % ', '.join(vdict.keys()))
continue
metric, values = get_data(vdict, supported_metrics)
if metric == 'simpleValue':
self._add_scalar(tag=tag, step=step, scalar_data=values)
elif metric == 'histo':
self._add_histogram(tag=tag, step=step, histo_data=values)
elif metric == 'image':
self._add_image(tag=tag, step=step, img_data=values)
elif metric == 'tensor' and values.get('dtype') == 'DT_STRING':
# text, just print to console
text = base64.b64decode('\n'.join(values['stringVal'])).decode('utf-8')
self._logger.report_text(msg='SUMMARY LOG: {} {}'.format(tag, text), print_console=False)
elif metric == 'tensor' and values.get('dtype') == 'DT_FLOAT':
self._add_plot(tag, step, values, vdict)
else:
                    self._logger.debug('Event unsupported. tag = %s, vdict keys [%s]' % (tag, ', '.join(vdict.keys())))
continue
def get_logdir(self):
""" Returns a temporary directory name for compatibility with FileWriter. This directory is not actually used.
:return: '.'
"""
return '.'
def flush(self):
"""Flushes the event file to disk.
Call this method to make sure that all pending events have been written to
disk.
"""
self._logger.flush()
def close(self):
"""Flushes the event file to disk and close the file.
Call this method when you do not need the summary writer anymore.
"""
self._logger.flush()
def reopen(self):
"""Reopens the EventFileWriter.
Can be called after `close()` to add more events in the same directory.
The events will go into a new events file.
Does nothing if the EventFileWriter was not closed.
"""
pass
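# Illustrative sketch (not executed here): the writer consumes TF Event
# protobufs directly, so a minimal manual wiring would look like
#
#   writer = EventTrainsWriter(task.get_logger(), report_freq=10)
#   writer.add_event(event)   # `event` is a tensorflow Event protobuf
#   writer.flush()
#
# In practice the patch classes below construct it automatically.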
class ProxyEventsWriter(object):
def __init__(self, events):
self._events = events
def _get_sentinel_event(self):
ret = None
for ev in self._events:
if hasattr(ev, '_get_sentinel_event'):
ret = ev._get_sentinel_event()
return ret
def get_logdir(self):
ret = None
for ev in self._events:
if hasattr(ev, 'get_logdir'):
ret = ev.get_logdir()
return ret
def reopen(self):
ret = None
for ev in self._events:
if hasattr(ev, 'reopen'):
ret = ev.reopen()
return ret
def add_event(self, *args, **kwargs):
ret = None
for ev in self._events:
if hasattr(ev, 'add_event'):
ret = ev.add_event(*args, **kwargs)
return ret
def flush(self):
ret = None
for ev in self._events:
if hasattr(ev, 'flush'):
ret = ev.flush()
return ret
def close(self):
ret = None
for ev in self._events:
if hasattr(ev, 'close'):
ret = ev.close()
return ret
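# Illustrative sketch (not executed here): fanning events out to both the
# Trains writer and the original TF file writer.
#
#   proxy = ProxyEventsWriter([trains_writer, original_file_writer])
#   proxy.add_event(event)  # forwarded to every writer implementing add_event
#
# Each proxied call returns the value from the LAST writer in the list, which
# is why ordering matters when the proxy is installed (see below).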
class PatchSummaryToEventTransformer(object):
__main_task = None
__original_getattribute = None
__original_getattributeX = None
_original_add_event = None
_original_add_eventT = None
_original_add_eventX = None
defaults_dict = dict(
report_freq=1, image_report_freq=1, histogram_update_freq_multiplier=5,
histogram_granularity=50)
@staticmethod
def trains_object(self):
if isinstance(self.event_writer, ProxyEventsWriter):
trains_writer = [e for e in self.event_writer._events if isinstance(e, EventTrainsWriter)]
return trains_writer[0] if trains_writer else None
elif isinstance(self.event_writer, EventTrainsWriter):
return self.event_writer
if not self.__dict__.get('_trains_defaults'):
self.__dict__['_trains_defaults'] = {}
return self.__dict__['_trains_defaults']
@staticmethod
def update_current_task(task, **kwargs):
PatchSummaryToEventTransformer.defaults_dict.update(kwargs)
PatchSummaryToEventTransformer.__main_task = task
# make sure we patched the SummaryToEventTransformer
PatchSummaryToEventTransformer._patch_summary_to_event_transformer()
PostImportHookPatching.add_on_import('tensorflow',
PatchSummaryToEventTransformer._patch_summary_to_event_transformer)
PostImportHookPatching.add_on_import('torch',
PatchSummaryToEventTransformer._patch_summary_to_event_transformer)
PostImportHookPatching.add_on_import('tensorboardX',
PatchSummaryToEventTransformer._patch_summary_to_event_transformer)
@staticmethod
def _patch_summary_to_event_transformer():
if 'tensorflow' in sys.modules:
try:
from tensorflow.python.summary.writer.writer import SummaryToEventTransformer
# only patch once
if PatchSummaryToEventTransformer.__original_getattribute is None:
PatchSummaryToEventTransformer.__original_getattribute = SummaryToEventTransformer.__getattribute__
SummaryToEventTransformer.__getattribute__ = PatchSummaryToEventTransformer._patched_getattribute
setattr(SummaryToEventTransformer, 'trains',
property(PatchSummaryToEventTransformer.trains_object))
except Exception as ex:
getLogger(TrainsFrameworkAdapter).warning(str(ex))
if 'torch' in sys.modules:
try:
# only patch once
if PatchSummaryToEventTransformer._original_add_eventT is None:
from torch.utils.tensorboard.writer import FileWriter as FileWriterT
PatchSummaryToEventTransformer._original_add_eventT = FileWriterT.add_event
FileWriterT.add_event = PatchSummaryToEventTransformer._patched_add_eventT
setattr(FileWriterT, 'trains', None)
except ImportError:
                # torch.utils.tensorboard is not available in this PyTorch version
pass
except Exception as ex:
getLogger(TrainsFrameworkAdapter).warning(str(ex))
if 'tensorboardX' in sys.modules:
try:
# only patch once
if PatchSummaryToEventTransformer.__original_getattributeX is None:
from tensorboardX.writer import SummaryToEventTransformer as SummaryToEventTransformerX
PatchSummaryToEventTransformer.__original_getattributeX = SummaryToEventTransformerX.__getattribute__
SummaryToEventTransformerX.__getattribute__ = PatchSummaryToEventTransformer._patched_getattributeX
setattr(SummaryToEventTransformerX, 'trains',
property(PatchSummaryToEventTransformer.trains_object))
except ImportError:
                # this is a new version of tensorboardX
pass
except Exception as ex:
getLogger(TrainsFrameworkAdapter).warning(str(ex))
if PatchSummaryToEventTransformer.__original_getattributeX is None:
try:
# only patch once
if PatchSummaryToEventTransformer._original_add_eventX is None:
from tensorboardX.writer import FileWriter as FileWriterX
PatchSummaryToEventTransformer._original_add_eventX = FileWriterX.add_event
FileWriterX.add_event = PatchSummaryToEventTransformer._patched_add_eventX
setattr(FileWriterX, 'trains', None)
except ImportError:
                # this is a new version of tensorboardX
pass
except Exception as ex:
getLogger(TrainsFrameworkAdapter).warning(str(ex))
@staticmethod
def _patched_add_eventT(self, *args, **kwargs):
if not hasattr(self, 'trains') or not PatchSummaryToEventTransformer.__main_task:
return PatchSummaryToEventTransformer._original_add_eventT(self, *args, **kwargs)
if not self.trains:
self.trains = EventTrainsWriter(PatchSummaryToEventTransformer.__main_task.get_logger(),
**PatchSummaryToEventTransformer.defaults_dict)
try:
self.trains.add_event(*args, **kwargs)
except Exception:
pass
return PatchSummaryToEventTransformer._original_add_eventT(self, *args, **kwargs)
@staticmethod
def _patched_add_eventX(self, *args, **kwargs):
if not hasattr(self, 'trains') or not PatchSummaryToEventTransformer.__main_task:
return PatchSummaryToEventTransformer._original_add_eventX(self, *args, **kwargs)
if not self.trains:
self.trains = EventTrainsWriter(PatchSummaryToEventTransformer.__main_task.get_logger(),
**PatchSummaryToEventTransformer.defaults_dict)
try:
self.trains.add_event(*args, **kwargs)
except Exception:
pass
return PatchSummaryToEventTransformer._original_add_eventX(self, *args, **kwargs)
@staticmethod
def _patched_getattribute(self, attr):
get_base = PatchSummaryToEventTransformer.__original_getattribute
return PatchSummaryToEventTransformer._patched_getattribute_(self, attr, get_base)
@staticmethod
def _patched_getattributeX(self, attr):
get_base = PatchSummaryToEventTransformer.__original_getattributeX
return PatchSummaryToEventTransformer._patched_getattribute_(self, attr, get_base)
@staticmethod
def _patched_getattribute_(self, attr, get_base):
        # no main task, zero chance we have a Trains event logger
if PatchSummaryToEventTransformer.__main_task is None:
return get_base(self, attr)
        # check if we already have a Trains event logger
__dict__ = get_base(self, '__dict__')
if 'event_writer' not in __dict__ or \
isinstance(__dict__['event_writer'], (ProxyEventsWriter, EventTrainsWriter)):
return get_base(self, attr)
# patch the events writer field, and add a double Event Logger (Trains and original)
base_eventwriter = __dict__['event_writer']
defaults_dict = __dict__.get('_trains_defaults') or PatchSummaryToEventTransformer.defaults_dict
trains_event = EventTrainsWriter(PatchSummaryToEventTransformer.__main_task.get_logger(), **defaults_dict)
# order is important, the return value of ProxyEventsWriter is the last object in the list
__dict__['event_writer'] = ProxyEventsWriter([trains_event, base_eventwriter])
return get_base(self, attr)
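# Illustrative sketch (not executed here): the Task typically wires this patch
# up roughly as
#
#   PatchSummaryToEventTransformer.update_current_task(task, report_freq=10)
#
# after which TF/tensorboardX summary writers transparently report their
# events to the Trains backend in addition to the original event file.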
class _ModelAdapter(object):
""" Model adapter which extends the save and save_weights methods of a Keras Model instance """
_model = None # type: Any
_output_model = None # type: OutputModel
def __init__(self, model, output_model):
super(_ModelAdapter, self).__init__()
super(_ModelAdapter, self).__setattr__('_model', model)
super(_ModelAdapter, self).__setattr__('_output_model', output_model)
super(_ModelAdapter, self).__setattr__('_logger', getLogger('TrainsModelAdapter'))
def __getattr__(self, attr):
return getattr(self._model, attr)
def __setattr__(self, key, value):
return setattr(self._model, key, value)
def save(self, filepath, overwrite=True, include_optimizer=True):
self._model.save(filepath=filepath, overwrite=overwrite, include_optimizer=include_optimizer)
        # TODO: auto generate new objects if the filename changes
try:
self._output_model.update_weights(weights_filename=filepath, auto_delete_file=True)
except Exception as ex:
self._logger.error(str(ex))
def save_weights(self, filepath, overwrite=True):
self._model.save_weights(filepath=filepath, overwrite=overwrite)
        # TODO: auto generate new objects if the filename changes
try:
self._output_model.update_weights(weights_filename=filepath, auto_delete_file=True)
except Exception as ex:
self._logger.error(str(ex))
class PatchModelCheckPointCallback(object):
__main_task = None
__original_getattribute = None
defaults_dict = dict(
config_text=None,
config_dict=None,
label_enumeration=None,
name=None,
comment=None)
@staticmethod
def trains_object(self):
if isinstance(self.model, _ModelAdapter):
return self.model._output_model
if not self.__dict__.get('_trains_defaults'):
self.__dict__['_trains_defaults'] = {}
return self.__dict__['_trains_defaults']
@staticmethod
def update_current_task(task, **kwargs):
PatchModelCheckPointCallback.defaults_dict.update(kwargs)
PatchModelCheckPointCallback.__main_task = task
        # make sure we patched the keras ModelCheckpoint callback
PatchModelCheckPointCallback._patch_model_checkpoint()
PostImportHookPatching.add_on_import('keras', PatchModelCheckPointCallback._patch_model_checkpoint)
PostImportHookPatching.add_on_import('tensorflow', PatchModelCheckPointCallback._patch_model_checkpoint)
@staticmethod
def _patch_model_checkpoint():
is_keras = 'keras' in sys.modules
is_tf_keras = 'tensorflow' in sys.modules
callbacks = None
if is_keras:
try:
import keras.callbacks as callbacks
except ImportError:
is_keras = False
if not is_keras and is_tf_keras:
try:
# hack: make sure tensorflow.__init__ is called
import tensorflow
import tensorflow.python.keras.callbacks as callbacks
except ImportError:
is_tf_keras = False
callbacks = None
# we have nothing, quit
if not is_keras and not is_tf_keras:
return
try:
# only patch once
if PatchModelCheckPointCallback.__original_getattribute is None and callbacks is not None:
PatchModelCheckPointCallback.__original_getattribute = callbacks.ModelCheckpoint.__getattribute__
callbacks.ModelCheckpoint.__getattribute__ = PatchModelCheckPointCallback._patched_getattribute
setattr(callbacks.ModelCheckpoint, 'trains',
property(PatchModelCheckPointCallback.trains_object))
except Exception as ex:
getLogger(TrainsFrameworkAdapter).warning(str(ex))
@staticmethod
def _patched_getattribute(self, attr):
get_base = PatchModelCheckPointCallback.__original_getattribute
        # no main task, zero chance we have a Trains event logger
if PatchModelCheckPointCallback.__main_task is None:
return get_base(self, attr)
        # check if the model was already wrapped with our adapter
__dict__ = get_base(self, '__dict__')
if 'model' not in __dict__ or \
isinstance(__dict__['model'], _ModelAdapter):
return get_base(self, attr)
        # wrap the model with a _ModelAdapter, so checkpoint saves also update the Trains OutputModel
base_model = __dict__['model']
defaults_dict = __dict__.get('_trains_defaults') or PatchModelCheckPointCallback.defaults_dict
output_model = OutputModel(
PatchModelCheckPointCallback.__main_task,
config_text=defaults_dict.get('config_text'),
config_dict=defaults_dict.get('config_dict'),
name=defaults_dict.get('name'),
comment=defaults_dict.get('comment'),
label_enumeration=defaults_dict.get('label_enumeration') or
PatchModelCheckPointCallback.__main_task.get_labels_enumeration(),
framework=Framework.keras,
)
output_model.set_upload_destination(
PatchModelCheckPointCallback.__main_task.get_output_destination(raise_on_error=False))
trains_model = _ModelAdapter(base_model, output_model)
        # replace the callback's model with the wrapped adapter
__dict__['model'] = trains_model
return get_base(self, attr)
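# Illustrative sketch (not executed here):
#
#   PatchModelCheckPointCallback.update_current_task(task, name='best model')
#
# Once active, accessing a keras ModelCheckpoint callback wraps its model with
# _ModelAdapter, so every checkpoint save also updates an OutputModel.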
class PatchTensorFlowEager(object):
__main_task = None
__original_fn_scalar = None
__original_fn_hist = None
__original_fn_image = None
__trains_event_writer = None
defaults_dict = dict(
report_freq=1, image_report_freq=1, histogram_update_freq_multiplier=5,
histogram_granularity=50)
@staticmethod
def update_current_task(task, **kwargs):
PatchTensorFlowEager.defaults_dict.update(kwargs)
PatchTensorFlowEager.__main_task = task
        # make sure we patched the eager-mode summary ops
PatchTensorFlowEager._patch_model_checkpoint()
PostImportHookPatching.add_on_import('tensorflow', PatchTensorFlowEager._patch_model_checkpoint)
@staticmethod
def _patch_model_checkpoint():
if PatchTensorFlowEager.__original_fn_scalar is not None:
return
if 'tensorflow' in sys.modules:
try:
# hack: make sure tensorflow.__init__ is called
import tensorflow
from tensorflow.python.ops import gen_summary_ops
PatchTensorFlowEager.__original_fn_scalar = gen_summary_ops.write_scalar_summary
gen_summary_ops.write_scalar_summary = PatchTensorFlowEager._write_scalar_summary
PatchTensorFlowEager.__original_fn_image = gen_summary_ops.write_image_summary
gen_summary_ops.write_image_summary = PatchTensorFlowEager._write_image_summary
PatchTensorFlowEager.__original_fn_hist = gen_summary_ops.write_histogram_summary
gen_summary_ops.write_histogram_summary = PatchTensorFlowEager._write_hist_summary
except ImportError:
pass
except Exception as ex:
getLogger(TrainsFrameworkAdapter).warning(str(ex))
@staticmethod
def _get_event_writer():
if not PatchTensorFlowEager.__main_task:
return None
if PatchTensorFlowEager.__trains_event_writer is None:
PatchTensorFlowEager.__trains_event_writer = EventTrainsWriter(
logger=PatchTensorFlowEager.__main_task.get_logger(), **PatchTensorFlowEager.defaults_dict)
return PatchTensorFlowEager.__trains_event_writer
@staticmethod
def trains_object(self):
return PatchTensorFlowEager.__trains_event_writer
@staticmethod
def _write_scalar_summary(writer, step, tag, value, name=None, **kwargs):
event_writer = PatchTensorFlowEager._get_event_writer()
if event_writer:
try:
event_writer._add_scalar(tag=str(tag), step=int(step.numpy()), scalar_data=value.numpy())
except Exception as ex:
getLogger(TrainsFrameworkAdapter).warning(str(ex))
return PatchTensorFlowEager.__original_fn_scalar(writer, step, tag, value, name, **kwargs)
@staticmethod
def _write_hist_summary(writer, step, tag, values, name, **kwargs):
event_writer = PatchTensorFlowEager._get_event_writer()
if event_writer:
try:
event_writer._add_histogram(tag=str(tag), step=int(step.numpy()), histo_data=values.numpy())
except Exception as ex:
getLogger(TrainsFrameworkAdapter).warning(str(ex))
return PatchTensorFlowEager.__original_fn_hist(writer, step, tag, values, name, **kwargs)
@staticmethod
def _write_image_summary(writer, step, tag, tensor, bad_color, max_images, name, **kwargs):
event_writer = PatchTensorFlowEager._get_event_writer()
if event_writer:
try:
event_writer._add_image_numpy(tag=str(tag), step=int(step.numpy()), img_data_np=tensor.numpy(),
max_keep_images=max_images)
except Exception as ex:
getLogger(TrainsFrameworkAdapter).warning(str(ex))
return PatchTensorFlowEager.__original_fn_image(writer, step, tag, tensor, bad_color, max_images, name,
**kwargs)
class PatchKerasModelIO(object):
__main_task = None
__patched = None
@staticmethod
def update_current_task(task, **kwargs):
PatchKerasModelIO.__main_task = task
PatchKerasModelIO._patch_model_checkpoint()
PostImportHookPatching.add_on_import('tensorflow', PatchKerasModelIO._patch_model_checkpoint)
PostImportHookPatching.add_on_import('keras', PatchKerasModelIO._patch_model_checkpoint)
@staticmethod
def _patch_model_checkpoint():
if 'keras' in sys.modules:
try:
from keras.engine.network import Network
except ImportError:
Network = None
try:
from keras.engine.sequential import Sequential
except ImportError:
Sequential = None
try:
from keras import models as keras_saving
except ImportError:
keras_saving = None
PatchKerasModelIO._patch_io_calls(Network, Sequential, keras_saving)
if 'tensorflow' in sys.modules:
try:
# hack: make sure tensorflow.__init__ is called
import tensorflow
from tensorflow.python.keras.engine.network import Network
except ImportError:
Network = None
try:
# hack: make sure tensorflow.__init__ is called
import tensorflow
from tensorflow.python.keras.engine.sequential import Sequential
except ImportError:
Sequential = None
try:
# hack: make sure tensorflow.__init__ is called
import tensorflow
from tensorflow.python.keras import models as keras_saving
except ImportError:
keras_saving = None
PatchKerasModelIO._patch_io_calls(Network, Sequential, keras_saving)
@staticmethod
def _patch_io_calls(Network, Sequential, keras_saving):
try:
# only patch once
if not PatchKerasModelIO.__patched:
PatchKerasModelIO.__patched = True
if Sequential is not None:
Sequential._updated_config = _patched_call(Sequential._updated_config,
PatchKerasModelIO._updated_config)
Sequential.from_config = _patched_call(Sequential.from_config, PatchKerasModelIO._from_config)
if Network is not None:
Network._updated_config = _patched_call(Network._updated_config, PatchKerasModelIO._updated_config)
Network.from_config = _patched_call(Network.from_config, PatchKerasModelIO._from_config)
Network.save = _patched_call(Network.save, PatchKerasModelIO._save)
Network.save_weights = _patched_call(Network.save_weights, PatchKerasModelIO._save_weights)
Network.load_weights = _patched_call(Network.load_weights, PatchKerasModelIO._load_weights)
if keras_saving is not None:
keras_saving.save_model = _patched_call(keras_saving.save_model, PatchKerasModelIO._save_model)
keras_saving.load_model = _patched_call(keras_saving.load_model, PatchKerasModelIO._load_model)
except Exception as ex:
getLogger(TrainsFrameworkAdapter).warning(str(ex))
@staticmethod
def _updated_config(original_fn, self):
config = original_fn(self)
# check if we have main task
if PatchKerasModelIO.__main_task is None:
return config
try:
            # check if the object already has an OutputModel
            if not hasattr(self, 'trains_out_model'):
                self.trains_out_model = None
            # get the model name from the config or the object itself
            model_name_id = config.get('name', getattr(self, 'name', 'unknown'))
if self.trains_out_model is not None:
self.trains_out_model.config_dict = config
else:
# todo: support multiple models for the same task
self.trains_out_model = OutputModel(
task=PatchKerasModelIO.__main_task,
config_dict=config,
name=PatchKerasModelIO.__main_task.name + ' ' + model_name_id,
label_enumeration=PatchKerasModelIO.__main_task.get_labels_enumeration(),
framework=Framework.keras,
)
except Exception as ex:
getLogger(TrainsFrameworkAdapter).warning(str(ex))
return config
@staticmethod
def _from_config(original_fn, *args, **kwargs):
try:
self = original_fn(*args, **kwargs)
except Exception as ex:
if not running_remotely():
                raise
self = _Empty()
# check if we have main task
if PatchKerasModelIO.__main_task is None:
return self
try:
# check if object already has InputModel
if not hasattr(self, 'trains_in_model'):
self.trains_in_model = None
# get config
config_dict = kwargs['config'] if 'config' in kwargs else args[0]
            # create a new InputModel from the configuration
self.trains_in_model = InputModel.empty(
config_dict=config_dict,
label_enumeration=PatchKerasModelIO.__main_task.get_labels_enumeration(),
)
# todo: support multiple models for the same task
PatchKerasModelIO.__main_task.connect(self.trains_in_model)
# if we are running remotely we should deserialize the object
# because someone might have changed the configuration
if running_remotely():
# reload the model
model_config = self.trains_in_model.config_dict
# verify that this is the same model so we are not deserializing a diff model
if (config_dict and config_dict.get('config') and model_config and model_config.get('config') and
config_dict.get('config').get('name') == model_config.get('config').get('name')) or \
(not config_dict and not model_config):
if 'config' in kwargs:
kwargs['config'] = model_config
else:
args = (model_config,) + args[1:]
model = original_fn(*args, **kwargs)
model.trains_in_model = self.trains_in_model
return model
except Exception as ex:
getLogger(TrainsFrameworkAdapter).warning(str(ex))
return self
@staticmethod
def _load_weights(original_fn, self, *args, **kwargs):
# check if we have main task
if PatchKerasModelIO.__main_task is None:
return original_fn(self, *args, **kwargs)
# get filepath
filepath = kwargs['filepath'] if 'filepath' in kwargs else args[0]
if running_remotely():
# register/load model weights
filepath = WeightsFileHandler.restore_weights_file(self, filepath, Framework.keras,
PatchKerasModelIO.__main_task)
if 'filepath' in kwargs:
kwargs['filepath'] = filepath
else:
args = (filepath,) + args[1:]
# load model
return original_fn(self, *args, **kwargs)
        # try to load the files; if anything goes wrong, the exception is raised before we register the file
model = original_fn(self, *args, **kwargs)
# register/load model weights
WeightsFileHandler.restore_weights_file(self, filepath, Framework.keras, PatchKerasModelIO.__main_task)
return model
@staticmethod
def _save(original_fn, self, *args, **kwargs):
if hasattr(self, 'trains_out_model'):
self.trains_out_model._processed = False
original_fn(self, *args, **kwargs)
# no need to specially call, because the original save uses "save_model" which we overload
if not hasattr(self, 'trains_out_model') or not self.trains_out_model._processed:
PatchKerasModelIO._update_outputmodel(self, *args, **kwargs)
@staticmethod
def _save_weights(original_fn, self, *args, **kwargs):
original_fn(self, *args, **kwargs)
PatchKerasModelIO._update_outputmodel(self, *args, **kwargs)
@staticmethod
def _update_outputmodel(self, *args, **kwargs):
# check if we have main task
if PatchKerasModelIO.__main_task is None:
return
try:
# get filepath
filepath = kwargs['filepath'] if 'filepath' in kwargs else args[0]
# this will already generate an output model
config = self._updated_config()
            # check if the object already has an OutputModel
            if not hasattr(self, 'trains_out_model'):
                self.trains_out_model = None
            # update the existing OutputModel, or create a new one
if self.trains_out_model is not None:
self.trains_out_model.config_dict = config
else:
model_name_id = getattr(self, 'name', 'unknown')
# todo: support multiple models for the same task
self.trains_out_model = OutputModel(
task=PatchKerasModelIO.__main_task,
config_dict=config,
name=PatchKerasModelIO.__main_task.name + ' ' + model_name_id,
label_enumeration=PatchKerasModelIO.__main_task.get_labels_enumeration(),
framework=Framework.keras,
)
# check if we have output storage
if self.trains_out_model.upload_storage_uri:
self.trains_out_model.update_weights(weights_filename=filepath, auto_delete_file=False)
else:
self.trains_out_model.update_weights(weights_filename=None, register_uri=filepath)
# if anyone asks, we were here
self.trains_out_model._processed = True
except Exception as ex:
getLogger(TrainsFrameworkAdapter).warning(str(ex))
@staticmethod
def _save_model(original_fn, model, filepath, *args, **kwargs):
original_fn(model, filepath, *args, **kwargs)
if PatchKerasModelIO.__main_task:
PatchKerasModelIO._update_outputmodel(model, filepath)
@staticmethod
def _load_model(original_fn, filepath, *args, **kwargs):
if not PatchKerasModelIO.__main_task:
return original_fn(filepath, *args, **kwargs)
empty = _Empty()
if running_remotely():
# register/load model weights
filepath = WeightsFileHandler.restore_weights_file(empty, filepath, Framework.keras,
PatchKerasModelIO.__main_task)
model = original_fn(filepath, *args, **kwargs)
else:
model = original_fn(filepath, *args, **kwargs)
# register/load model weights
WeightsFileHandler.restore_weights_file(empty, filepath, Framework.keras, PatchKerasModelIO.__main_task)
# update the input model object
if empty.trains_in_model:
try:
model.trains_in_model = empty.trains_in_model
except Exception:
pass
return model
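# Illustrative sketch (not executed here): with PatchKerasModelIO active, plain
# user code is tracked automatically. The file name is made up for the example.
#
#   model.save('my_model.h5')                 # also registers an OutputModel
#   keras.models.load_model('my_model.h5')    # also registers an InputModel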
class PatchTensorflowModelIO(object):
__main_task = None
__patched = None
@staticmethod
def update_current_task(task, **kwargs):
PatchTensorflowModelIO.__main_task = task
PatchTensorflowModelIO._patch_model_checkpoint()
PostImportHookPatching.add_on_import('tensorflow', PatchTensorflowModelIO._patch_model_checkpoint)
@staticmethod
def _patch_model_checkpoint():
if PatchTensorflowModelIO.__patched:
return
if 'tensorflow' not in sys.modules:
return
PatchTensorflowModelIO.__patched = True
try:
# hack: make sure tensorflow.__init__ is called
import tensorflow
from tensorflow.python.training.saver import Saver
try:
Saver.save = _patched_call(Saver.save, PatchTensorflowModelIO._save)
except Exception:
pass
try:
Saver.restore = _patched_call(Saver.restore, PatchTensorflowModelIO._restore)
except Exception:
pass
except ImportError:
pass
except Exception:
pass # print('Failed patching tensorflow')
try:
# make sure we import the correct version of save
import tensorflow
from tensorflow.saved_model.experimental import save
# actual import
import tensorflow.saved_model.experimental as saved_model
except ImportError:
try:
# make sure we import the correct version of save
import tensorflow
from tensorflow.saved_model import save
# actual import
                import tensorflow.saved_model as saved_model
except ImportError:
saved_model = None
except Exception:
saved_model = None
pass # print('Failed patching tensorflow')
except Exception:
saved_model = None
pass # print('Failed patching tensorflow')
if saved_model is not None:
saved_model.save = _patched_call(saved_model.save, PatchTensorflowModelIO._save_model)
try:
# make sure we import the correct version of save
import tensorflow
# actual import
from tensorflow.saved_model import load
import tensorflow.saved_model as saved_model_load
saved_model_load.load = _patched_call(saved_model_load.load, PatchTensorflowModelIO._load)
except ImportError:
pass
except Exception:
pass # print('Failed patching tensorflow')
try:
# make sure we import the correct version of save
import tensorflow
# actual import
from tensorflow.saved_model import loader as loader1
loader1.load = _patched_call(loader1.load, PatchTensorflowModelIO._load)
except ImportError:
pass
except Exception:
pass # print('Failed patching tensorflow')
try:
# make sure we import the correct version of save
import tensorflow
# actual import
from tensorflow.compat.v1.saved_model import loader as loader2
loader2.load = _patched_call(loader2.load, PatchTensorflowModelIO._load)
except ImportError:
pass
except Exception:
pass # print('Failed patching tensorflow')
try:
import tensorflow
from tensorflow.train import Checkpoint
try:
Checkpoint.save = _patched_call(Checkpoint.save, PatchTensorflowModelIO._ckpt_save)
except Exception:
pass
try:
Checkpoint.restore = _patched_call(Checkpoint.restore, PatchTensorflowModelIO._ckpt_restore)
except Exception:
pass
try:
Checkpoint.write = _patched_call(Checkpoint.write, PatchTensorflowModelIO._ckpt_write)
except Exception:
pass
except ImportError:
pass
except Exception:
pass # print('Failed patching tensorflow')
@staticmethod
def _save(original_fn, self, sess, save_path, *args, **kwargs):
saved_path = original_fn(self, sess, save_path, *args, **kwargs)
if not saved_path:
return saved_path
# store output Model
return WeightsFileHandler.create_output_model(self, saved_path, Framework.tensorflow,
PatchTensorflowModelIO.__main_task)
@staticmethod
def _save_model(original_fn, obj, export_dir, *args, **kwargs):
original_fn(obj, export_dir, *args, **kwargs)
# store output Model
WeightsFileHandler.create_output_model(obj, export_dir, Framework.tensorflow,
PatchTensorflowModelIO.__main_task)
@staticmethod
def _restore(original_fn, self, sess, save_path, *args, **kwargs):
if PatchTensorflowModelIO.__main_task is None:
return original_fn(self, sess, save_path, *args, **kwargs)
if running_remotely():
# register/load model weights
save_path = WeightsFileHandler.restore_weights_file(self, save_path, Framework.tensorflow,
PatchTensorflowModelIO.__main_task)
# load model
return original_fn(self, sess, save_path, *args, **kwargs)
# load model, if something is wrong, exception will be raised before we register the input model
model = original_fn(self, sess, save_path, *args, **kwargs)
# register/load model weights
WeightsFileHandler.restore_weights_file(self, save_path, Framework.tensorflow,
PatchTensorflowModelIO.__main_task)
return model
@staticmethod
def _load(original_fn, sess, tags, export_dir, *args, **saver_kwargs):
if PatchTensorflowModelIO.__main_task is None:
return original_fn(sess, tags, export_dir, *args, **saver_kwargs)
# register input model
empty = _Empty()
if running_remotely():
export_dir = WeightsFileHandler.restore_weights_file(empty, export_dir, Framework.tensorflow,
PatchTensorflowModelIO.__main_task)
model = original_fn(sess, tags, export_dir, *args, **saver_kwargs)
else:
# try to load model before registering, it might fail
model = original_fn(sess, tags, export_dir, *args, **saver_kwargs)
WeightsFileHandler.restore_weights_file(empty, export_dir, Framework.tensorflow,
PatchTensorflowModelIO.__main_task)
if empty.trains_in_model:
try:
model.trains_in_model = empty.trains_in_model
except Exception:
pass
return model
@staticmethod
def _ckpt_save(original_fn, self, file_prefix, *args, **kwargs):
checkpoint_path = original_fn(self, file_prefix, *args, **kwargs)
if PatchTensorflowModelIO.__main_task is None:
return checkpoint_path
WeightsFileHandler.create_output_model(self, checkpoint_path, Framework.tensorflow,
PatchTensorflowModelIO.__main_task)
return checkpoint_path
@staticmethod
def _ckpt_write(original_fn, self, file_prefix, *args, **kwargs):
checkpoint_path = original_fn(self, file_prefix, *args, **kwargs)
if PatchTensorflowModelIO.__main_task is None:
return checkpoint_path
WeightsFileHandler.create_output_model(self, checkpoint_path, Framework.tensorflow,
PatchTensorflowModelIO.__main_task)
return checkpoint_path
@staticmethod
def _ckpt_restore(original_fn, self, save_path, *args, **kwargs):
if PatchTensorflowModelIO.__main_task is None:
return original_fn(self, save_path, *args, **kwargs)
# register input model
empty = _Empty()
if running_remotely():
save_path = WeightsFileHandler.restore_weights_file(empty, save_path, Framework.tensorflow,
PatchTensorflowModelIO.__main_task)
model = original_fn(self, save_path, *args, **kwargs)
else:
# try to load model before registering it, in case it fails.
model = original_fn(self, save_path, *args, **kwargs)
WeightsFileHandler.restore_weights_file(empty, save_path, Framework.tensorflow,
PatchTensorflowModelIO.__main_task)
if empty.trains_in_model:
try:
model.trains_in_model = empty.trains_in_model
except Exception:
pass
return model
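# Illustrative sketch (not executed here): once patched, standard TF
# checkpointing is tracked automatically. Paths are made up for the example.
#
#   saver = tf.train.Saver()
#   saver.save(sess, '/tmp/model.ckpt')       # creates/updates an OutputModel
#   saver.restore(sess, '/tmp/model.ckpt')    # registers an InputModel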
class PatchPyTorchModelIO(object):
__main_task = None
__patched = None
@staticmethod
def update_current_task(task, **kwargs):
PatchPyTorchModelIO.__main_task = task
PatchPyTorchModelIO._patch_model_io()
PostImportHookPatching.add_on_import('torch', PatchPyTorchModelIO._patch_model_io)
@staticmethod
def _patch_model_io():
if PatchPyTorchModelIO.__patched:
return
if 'torch' not in sys.modules:
return
PatchPyTorchModelIO.__patched = True
try:
            # hack: make sure torch.__init__ is called
import torch
torch.save = _patched_call(torch.save, PatchPyTorchModelIO._save)
torch.load = _patched_call(torch.load, PatchPyTorchModelIO._load)
except ImportError:
pass
except Exception:
pass # print('Failed patching pytorch')
@staticmethod
def _save(original_fn, obj, f, *args, **kwargs):
ret = original_fn(obj, f, *args, **kwargs)
if not PatchPyTorchModelIO.__main_task:
return ret
if isinstance(f, six.string_types):
filename = f
elif hasattr(f, 'name'):
filename = f.name
try:
f.flush()
except Exception:
pass
else:
filename = None
        # give the model a descriptive name based on the file name
try:
model_name = Path(filename).stem
except Exception:
model_name = None
WeightsFileHandler.create_output_model(obj, filename, Framework.pytorch, PatchPyTorchModelIO.__main_task,
singlefile=True, model_name=model_name)
return ret
@staticmethod
def _load(original_fn, f, *args, **kwargs):
if isinstance(f, six.string_types):
filename = f
elif hasattr(f, 'name'):
filename = f.name
else:
filename = None
if not PatchPyTorchModelIO.__main_task:
return original_fn(f, *args, **kwargs)
# register input model
empty = _Empty()
if running_remotely():
filename = WeightsFileHandler.restore_weights_file(empty, filename, Framework.pytorch,
PatchPyTorchModelIO.__main_task)
model = original_fn(filename or f, *args, **kwargs)
else:
# try to load model before registering, in case we fail
model = original_fn(filename or f, *args, **kwargs)
WeightsFileHandler.restore_weights_file(empty, filename, Framework.pytorch,
PatchPyTorchModelIO.__main_task)
if empty.trains_in_model:
try:
model.trains_in_model = empty.trains_in_model
except Exception:
pass
return model
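# Illustrative sketch (not executed here): with PatchPyTorchModelIO active, the
# usual torch calls are tracked automatically. The path is made up.
#
#   torch.save(model.state_dict(), 'model.pt')   # registers an OutputModel
#   torch.load('model.pt')                       # also registers an InputModel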