Add initial Tensorflow v2 support (2.0.0rc1)

allegroai 2019-09-27 13:24:04 +03:00
parent c44638c8d9
commit a7eb8476ce
4 changed files with 336 additions and 42 deletions

View File

@@ -0,0 +1,130 @@
from __future__ import absolute_import, division, print_function, unicode_literals

import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten, Conv2D
from tensorflow.keras import Model
from trains import Task

task = Task.init(project_name='examples',
                 task_name='Tensorflow v2 mnist with summaries')

# Load and prepare the MNIST dataset.
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

# Add a channels dimension
x_train = x_train[..., tf.newaxis]
x_test = x_test[..., tf.newaxis]

# Use tf.data to batch and shuffle the dataset
train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(10000).batch(32)
test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(32)


# Build the tf.keras model using the Keras model subclassing API
class MyModel(Model):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv1 = Conv2D(32, 3, activation='relu')
        self.flatten = Flatten()
        self.d1 = Dense(128, activation='relu')
        self.d2 = Dense(10, activation='softmax')

    def call(self, x):
        x = self.conv1(x)
        x = self.flatten(x)
        x = self.d1(x)
        return self.d2(x)


# Create an instance of the model
model = MyModel()

# Choose an optimizer and loss function for training
loss_object = tf.keras.losses.SparseCategoricalCrossentropy()
optimizer = tf.keras.optimizers.Adam()

# Select metrics to measure the loss and the accuracy of the model.
# These metrics accumulate the values over epochs and then print the overall result.
train_loss = tf.keras.metrics.Mean(name='train_loss', dtype=tf.float32)
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')
test_loss = tf.keras.metrics.Mean(name='test_loss', dtype=tf.float32)
test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')


# Use tf.GradientTape to train the model
@tf.function
def train_step(images, labels):
    with tf.GradientTape() as tape:
        predictions = model(images)
        loss = loss_object(labels, predictions)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    train_loss(loss)
    train_accuracy(labels, predictions)


# Test the model
@tf.function
def test_step(images, labels):
    predictions = model(images)
    t_loss = loss_object(labels, predictions)
    test_loss(t_loss)
    test_accuracy(labels, predictions)


# Set up summary writers to write the summaries to disk in separate log directories
train_log_dir = '/tmp/logs/gradient_tape/train'
test_log_dir = '/tmp/logs/gradient_tape/test'
train_summary_writer = tf.summary.create_file_writer(train_log_dir)
test_summary_writer = tf.summary.create_file_writer(test_log_dir)

# Set up the checkpoint manager
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=optimizer, net=model)
manager = tf.train.CheckpointManager(ckpt, '/tmp/tf_ckpts', max_to_keep=3)
ckpt.restore(manager.latest_checkpoint)
if manager.latest_checkpoint:
    print("Restored from {}".format(manager.latest_checkpoint))
else:
    print("Initializing from scratch.")

# Start training
EPOCHS = 5
for epoch in range(EPOCHS):
    for images, labels in train_ds:
        train_step(images, labels)
    with train_summary_writer.as_default():
        tf.summary.scalar('loss', train_loss.result(), step=epoch)
        tf.summary.scalar('accuracy', train_accuracy.result(), step=epoch)

    ckpt.step.assign_add(1)
    # note: `% 1 == 0` is always true, so a checkpoint is saved after every epoch;
    # raise the modulus to save less frequently
    if int(ckpt.step) % 1 == 0:
        save_path = manager.save()
        print("Saved checkpoint for step {}: {}".format(int(ckpt.step), save_path))

    for test_images, test_labels in test_ds:
        test_step(test_images, test_labels)
    with test_summary_writer.as_default():
        tf.summary.scalar('loss', test_loss.result(), step=epoch)
        tf.summary.scalar('accuracy', test_accuracy.result(), step=epoch)

    template = 'Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, Test Accuracy: {}'
    print(template.format(epoch + 1,
                          train_loss.result(),
                          train_accuracy.result() * 100,
                          test_loss.result(),
                          test_accuracy.result() * 100))

    # Reset the metrics for the next epoch
    train_loss.reset_states()
    train_accuracy.reset_states()
    test_loss.reset_states()
    test_accuracy.reset_states()
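
The example writes standard TensorBoard event files under /tmp/logs/gradient_tape, so the very same run can also be inspected with stock TensorBoard while trains picks the reports up automatically through the patched summary ops:

    tensorboard --logdir /tmp/logs/gradient_tape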

View File

@@ -8,6 +8,7 @@ from io import BytesIO
 from typing import Any
 import numpy as np
+import six
 from PIL import Image
 from ...debugging.log import LoggerRoot
@@ -22,6 +23,16 @@ except ImportError:
     MessageToDict = None

+
+class TensorflowBinding(object):
+    @classmethod
+    def update_current_task(cls, task):
+        PatchSummaryToEventTransformer.update_current_task(task)
+        PatchTensorFlowEager.update_current_task(task)
+        PatchKerasModelIO.update_current_task(task)
+        PatchTensorflowModelIO.update_current_task(task)
+        PatchTensorflow2ModelIO.update_current_task(task)
+

 class IsTensorboardInit(object):
     _tensorboard_initialized = False
@@ -156,6 +167,7 @@ class EventTrainsWriter(object):
         self.image_report_freq = image_report_freq if image_report_freq else report_freq
         self.histogram_granularity = histogram_granularity
         self.histogram_update_freq_multiplier = histogram_update_freq_multiplier
+        self._histogram_update_call_counter = 0
         self._logger = logger
         self._visualization_mode = 'RGB'  # 'BGR'
         self._variants = defaultdict(lambda: ())
@@ -168,12 +180,18 @@
     def _decode_image(self, img_str, width, height, color_channels):
         # noinspection PyBroadException
         try:
-            imdata = base64.b64decode(img_str)
+            if isinstance(img_str, bytes):
+                imdata = img_str
+            else:
+                imdata = base64.b64decode(img_str)
             output = BytesIO(imdata)
             im = Image.open(output)
             image = np.asarray(im)
             output.close()
-            val = image.reshape(height, width, -1).astype(np.uint8)
+            if height > 0 and width > 0:
+                val = image.reshape(height, width, -1).astype(np.uint8)
+            else:
+                val = image.astype(np.uint8)
             if val.ndim == 3 and val.shape[2] == 3:
                 if self._visualization_mode == 'BGR':
                     val = val[:, :, [2, 1, 0]]
@@ -187,7 +205,7 @@ class EventTrainsWriter(object):
             else:
                 val = val[:, :, [0, 1, 2]]
         except Exception:
-            LoggerRoot.get_base_logger().warning('Failed decoding debug image [%d, %d, %d]'
+            LoggerRoot.get_base_logger(TensorflowBinding).warning('Failed decoding debug image [%d, %d, %d]'
                                                  % (width, height, color_channels))
             val = None
         return val
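
The reworked _decode_image now accepts either raw encoded bytes (the TF2 eager path) or a base64 string (the TF1 event path), and only reshapes when the caller supplies real dimensions. A minimal standalone sketch of the same logic (decode_image is an illustrative name, not part of trains):

    import base64
    from io import BytesIO
    import numpy as np
    from PIL import Image

    def decode_image(img_str, width=-1, height=-1):
        # raw bytes pass straight through; strings are assumed base64-encoded
        imdata = img_str if isinstance(img_str, bytes) else base64.b64decode(img_str)
        image = np.asarray(Image.open(BytesIO(imdata)))
        # reshape only when the caller knows the true dimensions
        if height > 0 and width > 0:
            return image.reshape(height, width, -1).astype(np.uint8)
        return image.astype(np.uint8)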
@@ -281,7 +299,9 @@ class EventTrainsWriter(object):
             return _cur_idx

         # only collect histogram every specific interval
-        if step % self.report_freq != 0 or step < self.report_freq - 1:
+        self._histogram_update_call_counter += 1
+        if self._histogram_update_call_counter % self.report_freq != 0 or \
+                self._histogram_update_call_counter < self.report_freq - 1:
             return None

         # generate forward matrix of the histograms
@@ -306,6 +326,8 @@ class EventTrainsWriter(object):
         # add current sample, if not already here
         hist_iters = np.append(hist_iters, step)
+        # histo_data['bucketLimit'] is the histogram bucket right side limit, meaning X axis
+        # histo_data['bucket'] is the histogram height, meaning the Y axis
         hist = np.array(list(zip(histo_data['bucketLimit'], histo_data['bucket'])), dtype=np.float32)
         hist = hist[~np.isinf(hist[:, 0]), :]
         hist_list.append(hist)
@@ -422,7 +444,7 @@ class EventTrainsWriter(object):
             msg_dict.pop('wallTime', None)
             keys_list = [key for key in msg_dict.keys() if len(key) > 0]
             keys_list = ', '.join(keys_list)
-            LoggerRoot.get_base_logger().debug('event summary not found, message type unsupported: %s' % keys_list)
+            LoggerRoot.get_base_logger(TensorflowBinding).debug('event summary not found, message type unsupported: %s' % keys_list)
             return
         value_dicts = summary.get('value')
         walltime = walltime or msg_dict.get('step')
@@ -434,19 +456,19 @@ class EventTrainsWriter(object):
                 step = int(event.step)
             else:
                 step = 0
-                LoggerRoot.get_base_logger().debug('Received event without step, assuming step = {}'.format(step))
+                LoggerRoot.get_base_logger(TensorflowBinding).debug('Received event without step, assuming step = {}'.format(step))
         else:
             step = int(step)
         self._max_step = max(self._max_step, step)

         if value_dicts is None:
-            LoggerRoot.get_base_logger().debug("Summary arrived without 'value'")
+            LoggerRoot.get_base_logger(TensorflowBinding).debug("Summary arrived without 'value'")
             return

         for vdict in value_dicts:
             tag = vdict.pop('tag', None)
             if tag is None:
                 # we should not get here
-                LoggerRoot.get_base_logger().debug('No tag for \'value\' existing keys %s'
-                                                   % ', '.join(vdict.keys()))
+                LoggerRoot.get_base_logger(TensorflowBinding).debug('No tag for \'value\' existing keys %s'
+                                                                    % ', '.join(vdict.keys()))
                 continue
             metric, values = get_data(vdict, supported_metrics)
@@ -463,7 +485,7 @@ class EventTrainsWriter(object):
             elif metric == 'tensor' and values.get('dtype') == 'DT_FLOAT':
                 self._add_plot(tag, step, values, vdict)
             else:
-                LoggerRoot.get_base_logger().debug('Event unsupported. tag = %s, vdict keys [%s]'
-                                                   % (tag, ', '.join(vdict.keys)))
+                LoggerRoot.get_base_logger(TensorflowBinding).debug('Event unsupported. tag = %s, vdict keys [%s]'
+                                                                    % (tag, ', '.join(vdict.keys())))
                 continue
@@ -594,7 +616,7 @@ class PatchSummaryToEventTransformer(object):
                 setattr(SummaryToEventTransformer, 'trains',
                         property(PatchSummaryToEventTransformer.trains_object))
             except Exception as ex:
-                LoggerRoot.get_base_logger().debug(str(ex))
+                LoggerRoot.get_base_logger(TensorflowBinding).debug(str(ex))

         if 'torch' in sys.modules:
             try:
@@ -608,7 +630,7 @@ class PatchSummaryToEventTransformer(object):
                 # this is a new version of TensorflowX
                 pass
             except Exception as ex:
-                LoggerRoot.get_base_logger().debug(str(ex))
+                LoggerRoot.get_base_logger(TensorflowBinding).debug(str(ex))

         if 'tensorboardX' in sys.modules:
             try:
@@ -624,7 +646,7 @@ class PatchSummaryToEventTransformer(object):
                 # this is a new version of TensorflowX
                 pass
             except Exception as ex:
-                LoggerRoot.get_base_logger().debug(str(ex))
+                LoggerRoot.get_base_logger(TensorflowBinding).debug(str(ex))

             if PatchSummaryToEventTransformer.__original_getattributeX is None:
                 try:
@@ -638,7 +660,7 @@ class PatchSummaryToEventTransformer(object):
                     # this is a new version of TensorflowX
                     pass
                 except Exception as ex:
-                    LoggerRoot.get_base_logger().debug(str(ex))
+                    LoggerRoot.get_base_logger(TensorflowBinding).debug(str(ex))

     @staticmethod
     def _patched_add_eventT(self, *args, **kwargs):
@@ -722,7 +744,7 @@ class _ModelAdapter(object):
         super(_ModelAdapter, self).__init__()
         super(_ModelAdapter, self).__setattr__('_model', model)
         super(_ModelAdapter, self).__setattr__('_output_model', output_model)
-        super(_ModelAdapter, self).__setattr__('_logger', LoggerRoot.get_base_logger())
+        super(_ModelAdapter, self).__setattr__('_logger', LoggerRoot.get_base_logger(TensorflowBinding))

     def __getattr__(self, attr):
         return getattr(self._model, attr)
@@ -805,7 +827,7 @@ class PatchModelCheckPointCallback(object):
                     property(PatchModelCheckPointCallback.trains_object))
         except Exception as ex:
-            LoggerRoot.get_base_logger().warning(str(ex))
+            LoggerRoot.get_base_logger(TensorflowBinding).warning(str(ex))

     @staticmethod
     def _patched_getattribute(self, attr):
@@ -876,6 +898,8 @@ class PatchTensorFlowEager(object):
                 gen_summary_ops.write_image_summary = PatchTensorFlowEager._write_image_summary
                 PatchTensorFlowEager.__original_fn_hist = gen_summary_ops.write_histogram_summary
                 gen_summary_ops.write_histogram_summary = PatchTensorFlowEager._write_hist_summary
+                PatchTensorFlowEager.__write_summary = gen_summary_ops.write_summary
+                gen_summary_ops.write_summary = PatchTensorFlowEager._write_summary
                 gen_summary_ops.create_summary_file_writer = partial(IsTensorboardInit._patched_tb__init__,
                                                                      gen_summary_ops.create_summary_file_writer)
                 gen_summary_ops.create_summary_db_writer = partial(IsTensorboardInit._patched_tb__init__,
@@ -883,7 +907,7 @@ class PatchTensorFlowEager(object):
         except ImportError:
             pass
         except Exception as ex:
-            LoggerRoot.get_base_logger().debug(str(ex))
+            LoggerRoot.get_base_logger(TensorflowBinding).debug(str(ex))

     @staticmethod
     def _get_event_writer(writer):
@@ -903,14 +927,40 @@
     def trains_object(self):
         return PatchTensorFlowEager.__trains_event_writer

+    @staticmethod
+    def _write_summary(writer, step, tensor, tag, summary_metadata, name=None, **kwargs):
+        event_writer = PatchTensorFlowEager._get_event_writer(writer)
+        if event_writer:
+            try:
+                plugin_type = summary_metadata.decode()
+                if plugin_type.endswith('scalars'):
+                    event_writer._add_scalar(tag=str(tag),
+                                             step=int(step.numpy()) if not isinstance(step, int) else step,
+                                             scalar_data=tensor.numpy())
+                elif plugin_type.endswith('images'):
+                    img_data_np = tensor.numpy()
+                    PatchTensorFlowEager._add_image_event_helper(event_writer, img_data_np=img_data_np,
+                                                                 tag=tag, step=step, **kwargs)
+                elif plugin_type.endswith('histograms'):
+                    PatchTensorFlowEager._add_histogram_event_helper(
+                        event_writer, tag=str(tag), step=int(step.numpy()) if not isinstance(step, int) else step,
+                        hist_data=tensor.numpy())
+                else:
+                    pass  # print('unsupported plugin_type', plugin_type)
+            except Exception as ex:
+                pass
+        return PatchTensorFlowEager.__write_summary(writer, step, tensor, tag, summary_metadata, name, **kwargs)
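
The new _write_summary hook keys its dispatch off the serialized SummaryMetadata bytes that gen_summary_ops.write_summary receives; endswith() is a cheap match that relies on the plugin name ('scalars', 'images', 'histograms') appearing at the tail of that small serialized message. A hedged way to pull the plugin name out explicitly (summary_metadata_bytes stands in for the raw argument):

    from tensorflow.core.framework.summary_pb2 import SummaryMetadata

    md = SummaryMetadata()
    md.ParseFromString(summary_metadata_bytes)  # the raw bytes handed to write_summary
    print(md.plugin_data.plugin_name)           # e.g. 'scalars'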
     @staticmethod
     def _write_scalar_summary(writer, step, tag, value, name=None, **kwargs):
         event_writer = PatchTensorFlowEager._get_event_writer(writer)
         if event_writer:
             try:
-                event_writer._add_scalar(tag=str(tag), step=int(step.numpy()), scalar_data=value.numpy())
+                event_writer._add_scalar(tag=str(tag),
+                                         step=int(step.numpy()) if not isinstance(step, int) else step,
+                                         scalar_data=value.numpy())
             except Exception as ex:
-                LoggerRoot.get_base_logger().warning(str(ex))
+                LoggerRoot.get_base_logger(TensorflowBinding).warning(str(ex))
         return PatchTensorFlowEager.__original_fn_scalar(writer, step, tag, value, name, **kwargs)
@@ -918,9 +968,11 @@ class PatchTensorFlowEager(object):
     @staticmethod
     def _write_hist_summary(writer, step, tag, values, name=None, **kwargs):
         event_writer = PatchTensorFlowEager._get_event_writer(writer)
         if event_writer:
             try:
-                event_writer._add_histogram(tag=str(tag), step=int(step.numpy()), histo_data=values.numpy())
+                PatchTensorFlowEager._add_histogram_event_helper(
+                    event_writer, tag=str(tag), step=int(step.numpy()) if not isinstance(step, int) else step,
+                    hist_data=values.numpy())
             except Exception as ex:
-                LoggerRoot.get_base_logger().warning(str(ex))
+                LoggerRoot.get_base_logger(TensorflowBinding).warning(str(ex))
         return PatchTensorFlowEager.__original_fn_hist(writer, step, tag, values, name, **kwargs)
@@ -928,13 +980,45 @@ class PatchTensorFlowEager(object):
     @staticmethod
     def _write_image_summary(writer, step, tag, tensor, bad_color, max_images, name=None, **kwargs):
         event_writer = PatchTensorFlowEager._get_event_writer(writer)
         if event_writer:
             try:
-                event_writer._add_image_numpy(tag=str(tag), step=int(step.numpy()), img_data_np=tensor.numpy(),
-                                              max_keep_images=max_images)
+                PatchTensorFlowEager._add_image_event_helper(event_writer, img_data_np=tensor.numpy(),
+                                                             tag=tag, step=step, max_images=max_images, **kwargs)
             except Exception as ex:
-                LoggerRoot.get_base_logger().warning(str(ex))
+                LoggerRoot.get_base_logger(TensorflowBinding).warning(str(ex))
         return PatchTensorFlowEager.__original_fn_image(writer, step, tag, tensor, bad_color, max_images, name,
                                                         **kwargs)
+    @staticmethod
+    def _add_histogram_event_helper(event_writer, hist_data, tag, step):
+        if isinstance(hist_data, dict):
+            event_writer._add_histogram(tag=tag, step=step, histo_data=hist_data)
+            return
+
+        # prepare the dictionary, assume numpy
+        # histo_data['bucketLimit'] is the histogram bucket right side limit, meaning X axis
+        # histo_data['bucket'] is the histogram height, meaning the Y axis
+        # notice hist_data[:, 1] is the right side limit; for backwards compatibility we take the left side
+        histo_data = {'bucketLimit': hist_data[:, 0].tolist(), 'bucket': hist_data[:, 2].tolist()}
+        event_writer._add_histogram(tag=tag, step=step, histo_data=histo_data)
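
_add_histogram_event_helper converts the eager-mode histogram tensor, whose rows the patch treats as (left edge, right edge, count), into the dict layout EventTrainsWriter already understands. A small illustration with made-up numbers:

    import numpy as np

    # made-up Nx3 histogram tensor: rows of [left_edge, right_edge, count]
    hist_data = np.array([[0.0, 0.5, 12.0],
                          [0.5, 1.0, 30.0],
                          [1.0, 1.5,  8.0]], dtype=np.float32)

    histo_data = {'bucketLimit': hist_data[:, 0].tolist(),  # bucket edges (X axis)
                  'bucket': hist_data[:, 2].tolist()}       # bucket heights (Y axis)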
+
+    @staticmethod
+    def _add_image_event_helper(event_writer, img_data_np, tag, step, **kwargs):
+        if img_data_np.ndim == 1 and img_data_np.size >= 3 and \
+                (len(img_data_np[0]) < 10 and len(img_data_np[1]) < 10):
+            # this is just for making sure these are actually valid numbers;
+            # the -1 width/height below lets _decode_image skip the reshape
+            width = int(img_data_np[0].decode())
+            height = int(img_data_np[1].decode())
+            for i in range(2, img_data_np.size):
+                img_data = {'width': -1, 'height': -1,
+                            'colorspace': 'RGB', 'encodedImageString': img_data_np[i]}
+                image_tag = str(tag) + '/sample_{}'.format(i - 2) if img_data_np.size > 3 else str(tag)
+                event_writer._add_image(tag=image_tag,
+                                        step=int(step.numpy()) if not isinstance(step, int) else step,
+                                        img_data=img_data)
+        else:
+            event_writer._add_image_numpy(tag=str(tag),
+                                          step=int(step.numpy()) if not isinstance(step, int) else step,
+                                          img_data_np=img_data_np,
+                                          max_keep_images=kwargs.get('max_images'))
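
The 1-D branch covers image summaries serialized the TensorBoard way: a rank-1 string tensor whose first two entries are the image dimensions as byte strings, followed by one encoded image per sample. A made-up instance of that layout (the PNG bytes are placeholders):

    import numpy as np

    png_0 = png_1 = b'\x89PNG\r\n\x1a\n...'  # placeholders for real encoded image bytes
    img_data_np = np.array([b'28', b'28', png_0, png_1], dtype=object)
    # size > 3, so two events would be reported: '<tag>/sample_0' and '<tag>/sample_1'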
 class PatchKerasModelIO(object):
     __main_task = None
@@ -1029,7 +1113,7 @@ class PatchKerasModelIO(object):
             keras_saving.save_model = _patched_call(keras_saving.save_model, PatchKerasModelIO._save_model)
             keras_saving.load_model = _patched_call(keras_saving.load_model, PatchKerasModelIO._load_model)
         except Exception as ex:
-            LoggerRoot.get_base_logger().warning(str(ex))
+            LoggerRoot.get_base_logger(TensorflowBinding).warning(str(ex))

     @staticmethod
     def _updated_config(original_fn, self):
@@ -1057,7 +1141,7 @@ class PatchKerasModelIO(object):
                     framework=Framework.keras,
                 )
         except Exception as ex:
-            LoggerRoot.get_base_logger().warning(str(ex))
+            LoggerRoot.get_base_logger(TensorflowBinding).warning(str(ex))

         return config
@@ -1107,7 +1191,7 @@ class PatchKerasModelIO(object):
             return model
         except Exception as ex:
-            LoggerRoot.get_base_logger().warning(str(ex))
+            LoggerRoot.get_base_logger(TensorflowBinding).warning(str(ex))

         return self
@@ -1189,7 +1273,7 @@ class PatchKerasModelIO(object):
             # if anyone asks, we were here
             self.trains_out_model._processed = True
         except Exception as ex:
-            LoggerRoot.get_base_logger().warning(str(ex))
+            LoggerRoot.get_base_logger(TensorflowBinding).warning(str(ex))

     @staticmethod
     def _save_model(original_fn, model, filepath, *args, **kwargs):
@@ -1261,7 +1345,7 @@ class PatchTensorflowModelIO(object):
         except ImportError:
             pass
         except Exception:
-            pass  # print('Failed patching tensorflow')
+            LoggerRoot.get_base_logger(TensorflowBinding).debug('Failed patching tensorflow')

         # noinspection PyBroadException
         try:
@@ -1282,10 +1366,9 @@ class PatchTensorflowModelIO(object):
                 saved_model = None
             except Exception:
                 saved_model = None
-                pass  # print('Failed patching tensorflow')
         except Exception:
             saved_model = None
-            pass  # print('Failed patching tensorflow')

         if saved_model is not None:
             saved_model.save = _patched_call(saved_model.save, PatchTensorflowModelIO._save_model)
@@ -1301,7 +1384,7 @@ class PatchTensorflowModelIO(object):
         except ImportError:
             pass
         except Exception:
-            pass  # print('Failed patching tensorflow')
+            LoggerRoot.get_base_logger(TensorflowBinding).debug('Failed patching tensorflow')

         # noinspection PyBroadException
         try:
@@ -1313,7 +1396,7 @@ class PatchTensorflowModelIO(object):
         except ImportError:
             pass
         except Exception:
-            pass  # print('Failed patching tensorflow')
+            LoggerRoot.get_base_logger(TensorflowBinding).debug('Failed patching tensorflow')

         # noinspection PyBroadException
         try:
@@ -1325,7 +1408,7 @@ class PatchTensorflowModelIO(object):
         except ImportError:
             pass
         except Exception:
-            pass  # print('Failed patching tensorflow')
+            LoggerRoot.get_base_logger(TensorflowBinding).debug('Failed patching tensorflow')

         # noinspection PyBroadException
         try:
@@ -1349,7 +1432,7 @@ class PatchTensorflowModelIO(object):
         except ImportError:
             pass
         except Exception:
-            pass  # print('Failed patching tensorflow')
+            LoggerRoot.get_base_logger(TensorflowBinding).debug('Failed patching tensorflow')

     @staticmethod
     def _save(original_fn, self, sess, save_path, *args, **kwargs):
@@ -1457,3 +1540,81 @@ class PatchTensorflowModelIO(object):
             pass
         return model
+
+
+class PatchTensorflow2ModelIO(object):
+    __main_task = None
+    __patched = None
+
+    @staticmethod
+    def update_current_task(task, **kwargs):
+        PatchTensorflow2ModelIO.__main_task = task
+        PatchTensorflow2ModelIO._patch_model_checkpoint()
+        PostImportHookPatching.add_on_import('tensorflow', PatchTensorflow2ModelIO._patch_model_checkpoint)
+
+    @staticmethod
+    def _patch_model_checkpoint():
+        if PatchTensorflow2ModelIO.__patched:
+            return
+
+        if 'tensorflow' not in sys.modules:
+            return
+
+        PatchTensorflow2ModelIO.__patched = True
+        # noinspection PyBroadException
+        try:
+            # hack: make sure tensorflow.__init__ is called
+            import tensorflow
+            from tensorflow.python.training.tracking import util
+            # noinspection PyBroadException
+            try:
+                util.TrackableSaver.save = _patched_call(util.TrackableSaver.save,
+                                                         PatchTensorflow2ModelIO._save)
+            except Exception:
+                pass
+            # noinspection PyBroadException
+            try:
+                util.TrackableSaver.restore = _patched_call(util.TrackableSaver.restore,
+                                                            PatchTensorflow2ModelIO._restore)
+            except Exception:
+                pass
+        except ImportError:
+            pass
+        except Exception:
+            LoggerRoot.get_base_logger(TensorflowBinding).debug('Failed patching tensorflow v2')
+
+    @staticmethod
+    def _save(original_fn, self, file_prefix, *args, **kwargs):
+        model = original_fn(self, file_prefix, *args, **kwargs)
+        # store output Model
+        try:
+            WeightsFileHandler.create_output_model(self, file_prefix, Framework.tensorflow,
+                                                   PatchTensorflow2ModelIO.__main_task)
+        except Exception:
+            pass
+        return model
+
+    @staticmethod
+    def _restore(original_fn, self, save_path, *args, **kwargs):
+        if PatchTensorflow2ModelIO.__main_task is None:
+            return original_fn(self, save_path, *args, **kwargs)
+
+        # Hack: disabled
+        if False and running_remotely():
+            # register/load model weights
+            try:
+                save_path = WeightsFileHandler.restore_weights_file(self, save_path, Framework.tensorflow,
+                                                                    PatchTensorflow2ModelIO.__main_task)
+            except Exception:
+                pass
+            # load model
+            return original_fn(self, save_path, *args, **kwargs)
+
+        # load the model; if something is wrong, an exception will be raised
+        # before we register the input model
+        model = original_fn(self, save_path, *args, **kwargs)
+        # register/load model weights
+        try:
+            WeightsFileHandler.restore_weights_file(self, save_path, Framework.tensorflow,
+                                                    PatchTensorflow2ModelIO.__main_task)
+        except Exception:
+            pass
+        return model
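
This closes the loop with the example above: tf.train.Checkpoint saving goes through the tracking util's TrackableSaver.save, so once the patch is installed a CheckpointManager.save() call should also register an output model on the current task. A minimal sketch, assuming a trains server is reachable:

    import tensorflow as tf
    from trains import Task

    task = Task.init(project_name='examples', task_name='checkpoint demo')  # installs the patches
    ckpt = tf.train.Checkpoint(step=tf.Variable(1))
    manager = tf.train.CheckpointManager(ckpt, '/tmp/tf_ckpts', max_to_keep=3)
    manager.save()  # funnels into TrackableSaver.save and thus PatchTensorflow2ModelIO._save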

View File

@@ -8,7 +8,7 @@
     }
     direct_access: [
-        # Objects matching are considered to be available for direct access, i.e. they will not eb downloaded
+        # Objects matching are considered to be available for direct access, i.e. they will not be downloaded
         # or cached, and any download request will return a direct reference.
         # Objects are specified in glob format, available for url and content_type.
         { url: "file://*" }  # file-urls are always directly referenced

View File

@@ -34,8 +34,7 @@ from .binding.absl_bind import PatchAbsl
 from .utilities.args import argparser_parseargs_called, get_argparser_last_args, \
     argparser_update_currenttask
 from .binding.frameworks.pytorch_bind import PatchPyTorchModelIO
-from .binding.frameworks.tensorflow_bind import PatchSummaryToEventTransformer, PatchTensorFlowEager, \
-    PatchKerasModelIO, PatchTensorflowModelIO
+from .binding.frameworks.tensorflow_bind import TensorflowBinding
 from .binding.frameworks.xgboost_bind import PatchXGBoostModelIO
 from .binding.matplotlib_bind import PatchedMatplotlib
 from .utilities.resource_monitor import ResourceMonitor
@@ -252,10 +251,7 @@ class Task(_Task):
             PatchedJoblib.update_current_task(task)
             PatchedMatplotlib.update_current_task(Task.__main_task)
             PatchAbsl.update_current_task(Task.__main_task)
-            PatchSummaryToEventTransformer.update_current_task(task)
-            PatchTensorFlowEager.update_current_task(task)
-            PatchKerasModelIO.update_current_task(task)
-            PatchTensorflowModelIO.update_current_task(task)
+            TensorflowBinding.update_current_task(task)
             PatchPyTorchModelIO.update_current_task(task)
             PatchXGBoostModelIO.update_current_task(task)
             if auto_resource_monitoring:
@@ -346,7 +342,7 @@ class Task(_Task):
     def _create_dev_task(cls, default_project_name, default_task_name, default_task_type, reuse_last_task_id):
         if not default_project_name or not default_task_name:
             # get project name and task name from repository name and entry_point
-            result = ScriptInfo.get(create_requirements=False, check_uncommitted=False)
+            result, _ = ScriptInfo.get(create_requirements=False, check_uncommitted=False)
             if not default_project_name:
                 # noinspection PyBroadException
                 try:
@@ -1016,7 +1012,14 @@ class Task(_Task):
         if self._detect_repo_async_thread:
             try:
                 if self._detect_repo_async_thread.is_alive():
+                    self.log.info('Waiting for repository detection and full package requirement analysis')
                     self._detect_repo_async_thread.join(timeout=timeout)
+                    # join() has no return value, so re-check is_alive() to tell a timeout from completion
+                    if self._detect_repo_async_thread.is_alive():
+                        self.log.info('Repository and package analysis timed out ({} sec), '
+                                      'giving up'.format(timeout))
+                    else:
+                        self.log.info('Finished repository detection and package analysis')
                 self._detect_repo_async_thread = None
             except Exception:
                 pass
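
The aliveness re-check is needed because Thread.join() returns None whether the thread finished or the timeout expired. The same pattern in isolation:

    import threading
    import time

    t = threading.Thread(target=time.sleep, args=(10,))
    t.start()
    t.join(timeout=1.0)   # returns None either way
    if t.is_alive():      # still running, so the join timed out
        print('timed out, giving up')
    else:
        print('finished')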