Add initial Tensorflow v2 support (2.0.0rc1)
This commit is contained in:
parent c44638c8d9 · commit a7eb8476ce
examples/tensorflow_v2_mnist.py (new file, 130 lines)
@@ -0,0 +1,130 @@
from __future__ import absolute_import, division, print_function, unicode_literals

import tensorflow as tf

from tensorflow.keras.layers import Dense, Flatten, Conv2D
from tensorflow.keras import Model

from trains import Task


task = Task.init(project_name='examples',
                 task_name='Tensorflow v2 mnist with summaries')


# Load and prepare the MNIST dataset.
mnist = tf.keras.datasets.mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

# Add a channels dimension
x_train = x_train[..., tf.newaxis]
x_test = x_test[..., tf.newaxis]

# Use tf.data to batch and shuffle the dataset
train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(10000).batch(32)
test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(32)


# Build the tf.keras model using the Keras model subclassing API
class MyModel(Model):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv1 = Conv2D(32, 3, activation='relu')
        self.flatten = Flatten()
        self.d1 = Dense(128, activation='relu')
        self.d2 = Dense(10, activation='softmax')

    def call(self, x):
        x = self.conv1(x)
        x = self.flatten(x)
        x = self.d1(x)
        return self.d2(x)


# Create an instance of the model
model = MyModel()

# Choose an optimizer and loss function for training
loss_object = tf.keras.losses.SparseCategoricalCrossentropy()
optimizer = tf.keras.optimizers.Adam()

# Select metrics to measure the loss and the accuracy of the model.
# These metrics accumulate the values over epochs and then print the overall result.
train_loss = tf.keras.metrics.Mean(name='train_loss', dtype=tf.float32)
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')

test_loss = tf.keras.metrics.Mean(name='test_loss', dtype=tf.float32)
test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')


# Use tf.GradientTape to train the model
@tf.function
def train_step(images, labels):
    with tf.GradientTape() as tape:
        predictions = model(images)
        loss = loss_object(labels, predictions)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

    train_loss(loss)
    train_accuracy(labels, predictions)


# Test the model
@tf.function
def test_step(images, labels):
    predictions = model(images)
    t_loss = loss_object(labels, predictions)

    test_loss(t_loss)
    test_accuracy(labels, predictions)


# Set up summary writers to write the summaries to disk in a different logs directory
train_log_dir = '/tmp/logs/gradient_tape/train'
test_log_dir = '/tmp/logs/gradient_tape/test'
train_summary_writer = tf.summary.create_file_writer(train_log_dir)
test_summary_writer = tf.summary.create_file_writer(test_log_dir)

# Set up checkpoints manager
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=optimizer, net=model)
manager = tf.train.CheckpointManager(ckpt, '/tmp/tf_ckpts', max_to_keep=3)
ckpt.restore(manager.latest_checkpoint)
if manager.latest_checkpoint:
    print("Restored from {}".format(manager.latest_checkpoint))
else:
    print("Initializing from scratch.")

# Start training
EPOCHS = 5
for epoch in range(EPOCHS):
    for images, labels in train_ds:
        train_step(images, labels)
    with train_summary_writer.as_default():
        tf.summary.scalar('loss', train_loss.result(), step=epoch)
        tf.summary.scalar('accuracy', train_accuracy.result(), step=epoch)

    ckpt.step.assign_add(1)
    if int(ckpt.step) % 1 == 0:
        save_path = manager.save()
        print("Saved checkpoint for step {}: {}".format(int(ckpt.step), save_path))

    for test_images, test_labels in test_ds:
        test_step(test_images, test_labels)
    with test_summary_writer.as_default():
        tf.summary.scalar('loss', test_loss.result(), step=epoch)
        tf.summary.scalar('accuracy', test_accuracy.result(), step=epoch)

    template = 'Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, Test Accuracy: {}'
    print(template.format(epoch + 1,
                          train_loss.result(),
                          train_accuracy.result() * 100,
                          test_loss.result(),
                          test_accuracy.result() * 100))

    # Reset the metrics for the next epoch
    train_loss.reset_states()
    train_accuracy.reset_states()
    test_loss.reset_states()
    test_accuracy.reset_states()
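Note: the example writes TensorBoard logs under /tmp/logs/gradient_tape and checkpoints under /tmp/tf_ckpts. A quick sanity check after a run (a sketch, assuming the same checkpoint path as above):

    import tensorflow as tf

    # The CheckpointManager above keeps at most 3 checkpoints (max_to_keep=3);
    # this prints the path of the newest one, or None if nothing was saved
    print(tf.train.latest_checkpoint('/tmp/tf_ckpts'))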
@@ -8,6 +8,7 @@ from io import BytesIO
 from typing import Any
 
 import numpy as np
+import six
 from PIL import Image
 
 from ...debugging.log import LoggerRoot
@@ -22,6 +23,16 @@ except ImportError:
     MessageToDict = None
 
 
+class TensorflowBinding(object):
+    @classmethod
+    def update_current_task(cls, task):
+        PatchSummaryToEventTransformer.update_current_task(task)
+        PatchTensorFlowEager.update_current_task(task)
+        PatchKerasModelIO.update_current_task(task)
+        PatchTensorflowModelIO.update_current_task(task)
+        PatchTensorflow2ModelIO.update_current_task(task)
+
+
 class IsTensorboardInit(object):
     _tensorboard_initialized = False
 
@@ -156,6 +167,7 @@ class EventTrainsWriter(object):
         self.image_report_freq = image_report_freq if image_report_freq else report_freq
         self.histogram_granularity = histogram_granularity
         self.histogram_update_freq_multiplier = histogram_update_freq_multiplier
+        self._histogram_update_call_counter = 0
         self._logger = logger
         self._visualization_mode = 'RGB'  # 'BGR'
         self._variants = defaultdict(lambda: ())
@@ -168,12 +180,18 @@ class EventTrainsWriter(object):
     def _decode_image(self, img_str, width, height, color_channels):
         # noinspection PyBroadException
         try:
-            imdata = base64.b64decode(img_str)
+            if isinstance(img_str, bytes):
+                imdata = img_str
+            else:
+                imdata = base64.b64decode(img_str)
             output = BytesIO(imdata)
             im = Image.open(output)
             image = np.asarray(im)
             output.close()
-            val = image.reshape(height, width, -1).astype(np.uint8)
+            if height > 0 and width > 0:
+                val = image.reshape(height, width, -1).astype(np.uint8)
+            else:
+                val = image.astype(np.uint8)
             if val.ndim == 3 and val.shape[2] == 3:
                 if self._visualization_mode == 'BGR':
                     val = val[:, :, [2, 1, 0]]
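Note: the two changes to _decode_image above make it accept raw encoded-image bytes in addition to base64 strings, and fall back to the decoded shape when width/height are not positive. A minimal standalone sketch of that decode path, assuming only PIL and numpy (decode_image is a hypothetical name, not the method itself):

    import base64
    from io import BytesIO

    import numpy as np
    from PIL import Image

    def decode_image(img_str, width, height):
        # raw PNG/JPEG bytes pass straight through; str input is assumed base64
        imdata = img_str if isinstance(img_str, bytes) else base64.b64decode(img_str)
        image = np.asarray(Image.open(BytesIO(imdata)))
        if height > 0 and width > 0:
            return image.reshape(height, width, -1).astype(np.uint8)
        # some summaries report non-positive dimensions; trust the decoded shape
        return image.astype(np.uint8)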
@@ -187,7 +205,7 @@ class EventTrainsWriter(object):
             else:
                 val = val[:, :, [0, 1, 2]]
         except Exception:
-            LoggerRoot.get_base_logger().warning('Failed decoding debug image [%d, %d, %d]'
+            LoggerRoot.get_base_logger(TensorflowBinding).warning('Failed decoding debug image [%d, %d, %d]'
                                                  % (width, height, color_channels))
             val = None
         return val
@@ -281,7 +299,9 @@ class EventTrainsWriter(object):
             return _cur_idx
 
         # only collect histogram every specific interval
-        if step % self.report_freq != 0 or step < self.report_freq - 1:
+        self._histogram_update_call_counter += 1
+        if self._histogram_update_call_counter % self.report_freq != 0 or \
+                self._histogram_update_call_counter < self.report_freq - 1:
             return None
 
         # generate forward matrix of the histograms
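Note: the gating above now counts calls into the histogram writer instead of trusting the incoming step value, which keeps reports rate-limited even when steps repeat or do not advance one-by-one. The idea in isolation (a sketch, not the actual class):

    class CallRateLimiter(object):
        # allow roughly one report per `freq` calls, regardless of step values
        def __init__(self, freq):
            self.freq = max(1, int(freq))
            self.calls = 0

        def allow(self):
            self.calls += 1
            return self.calls % self.freq == 0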
@@ -306,6 +326,8 @@ class EventTrainsWriter(object):
 
         # add current sample, if not already here
         hist_iters = np.append(hist_iters, step)
+        # histo_data['bucketLimit'] is the histogram bucket right side limit, meaning X axis
+        # histo_data['bucket'] is the histogram height, meaning the Y axis
         hist = np.array(list(zip(histo_data['bucketLimit'], histo_data['bucket'])), dtype=np.float32)
         hist = hist[~np.isinf(hist[:, 0]), :]
         hist_list.append(hist)
@@ -422,7 +444,7 @@ class EventTrainsWriter(object):
             msg_dict.pop('wallTime', None)
             keys_list = [key for key in msg_dict.keys() if len(key) > 0]
             keys_list = ', '.join(keys_list)
-            LoggerRoot.get_base_logger().debug('event summary not found, message type unsupported: %s' % keys_list)
+            LoggerRoot.get_base_logger(TensorflowBinding).debug('event summary not found, message type unsupported: %s' % keys_list)
             return
         value_dicts = summary.get('value')
         walltime = walltime or msg_dict.get('step')
@@ -434,19 +456,19 @@ class EventTrainsWriter(object):
                 step = int(event.step)
             else:
                 step = 0
-                LoggerRoot.get_base_logger().debug('Received event without step, assuming step = {}'.format(step))
+                LoggerRoot.get_base_logger(TensorflowBinding).debug('Received event without step, assuming step = {}'.format(step))
         else:
             step = int(step)
         self._max_step = max(self._max_step, step)
         if value_dicts is None:
-            LoggerRoot.get_base_logger().debug("Summary arrived without 'value'")
+            LoggerRoot.get_base_logger(TensorflowBinding).debug("Summary arrived without 'value'")
             return
 
         for vdict in value_dicts:
             tag = vdict.pop('tag', None)
             if tag is None:
                 # we should not get here
-                LoggerRoot.get_base_logger().debug('No tag for \'value\' existing keys %s'
+                LoggerRoot.get_base_logger(TensorflowBinding).debug('No tag for \'value\' existing keys %s'
                                                    % ', '.join(vdict.keys()))
                 continue
             metric, values = get_data(vdict, supported_metrics)
@@ -463,7 +485,7 @@ class EventTrainsWriter(object):
             elif metric == 'tensor' and values.get('dtype') == 'DT_FLOAT':
                 self._add_plot(tag, step, values, vdict)
             else:
-                LoggerRoot.get_base_logger().debug('Event unsupported. tag = %s, vdict keys [%s]'
+                LoggerRoot.get_base_logger(TensorflowBinding).debug('Event unsupported. tag = %s, vdict keys [%s]'
                                                    % (tag, ', '.join(vdict.keys)))
                 continue
 
@@ -594,7 +616,7 @@ class PatchSummaryToEventTransformer(object):
                 setattr(SummaryToEventTransformer, 'trains',
                         property(PatchSummaryToEventTransformer.trains_object))
             except Exception as ex:
-                LoggerRoot.get_base_logger().debug(str(ex))
+                LoggerRoot.get_base_logger(TensorflowBinding).debug(str(ex))
 
         if 'torch' in sys.modules:
             try:
@@ -608,7 +630,7 @@ class PatchSummaryToEventTransformer(object):
                 # this is a new version of TensorflowX
                 pass
             except Exception as ex:
-                LoggerRoot.get_base_logger().debug(str(ex))
+                LoggerRoot.get_base_logger(TensorflowBinding).debug(str(ex))
 
         if 'tensorboardX' in sys.modules:
             try:
@@ -624,7 +646,7 @@ class PatchSummaryToEventTransformer(object):
                 # this is a new version of TensorflowX
                 pass
             except Exception as ex:
-                LoggerRoot.get_base_logger().debug(str(ex))
+                LoggerRoot.get_base_logger(TensorflowBinding).debug(str(ex))
 
         if PatchSummaryToEventTransformer.__original_getattributeX is None:
             try:
@@ -638,7 +660,7 @@ class PatchSummaryToEventTransformer(object):
                 # this is a new version of TensorflowX
                 pass
             except Exception as ex:
-                LoggerRoot.get_base_logger().debug(str(ex))
+                LoggerRoot.get_base_logger(TensorflowBinding).debug(str(ex))
 
     @staticmethod
     def _patched_add_eventT(self, *args, **kwargs):
@@ -722,7 +744,7 @@ class _ModelAdapter(object):
         super(_ModelAdapter, self).__init__()
         super(_ModelAdapter, self).__setattr__('_model', model)
         super(_ModelAdapter, self).__setattr__('_output_model', output_model)
-        super(_ModelAdapter, self).__setattr__('_logger', LoggerRoot.get_base_logger())
+        super(_ModelAdapter, self).__setattr__('_logger', LoggerRoot.get_base_logger(TensorflowBinding))
 
     def __getattr__(self, attr):
         return getattr(self._model, attr)
@@ -805,7 +827,7 @@ class PatchModelCheckPointCallback(object):
                     property(PatchModelCheckPointCallback.trains_object))
 
         except Exception as ex:
-            LoggerRoot.get_base_logger().warning(str(ex))
+            LoggerRoot.get_base_logger(TensorflowBinding).warning(str(ex))
 
     @staticmethod
     def _patched_getattribute(self, attr):
@@ -876,6 +898,8 @@ class PatchTensorFlowEager(object):
                 gen_summary_ops.write_image_summary = PatchTensorFlowEager._write_image_summary
                 PatchTensorFlowEager.__original_fn_hist = gen_summary_ops.write_histogram_summary
                 gen_summary_ops.write_histogram_summary = PatchTensorFlowEager._write_hist_summary
+                PatchTensorFlowEager.__write_summary = gen_summary_ops.write_summary
+                gen_summary_ops.write_summary = PatchTensorFlowEager._write_summary
                 gen_summary_ops.create_summary_file_writer = partial(IsTensorboardInit._patched_tb__init__,
                                                                      gen_summary_ops.create_summary_file_writer)
                 gen_summary_ops.create_summary_db_writer = partial(IsTensorboardInit._patched_tb__init__,
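Note: in TF 2.x eager mode the high-level tf.summary.scalar/image/histogram calls funnel into gen_summary_ops.write_summary, so capturing and replacing that one function (lines above) is enough to observe every summary. The capture-and-forward pattern in isolation (a sketch; `ops` stands in for whichever module owns the function):

    class SummaryHook(object):
        _original_write = None

        @classmethod
        def install(cls, ops):
            # keep a reference to the original, then install the interceptor
            cls._original_write = ops.write_summary
            ops.write_summary = cls._write_summary

        @classmethod
        def _write_summary(cls, *args, **kwargs):
            # ... inspect/report the summary here ...
            # always forward, so the TensorBoard event files are still written
            return cls._original_write(*args, **kwargs)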
@@ -883,7 +907,7 @@ class PatchTensorFlowEager(object):
             except ImportError:
                 pass
             except Exception as ex:
-                LoggerRoot.get_base_logger().debug(str(ex))
+                LoggerRoot.get_base_logger(TensorflowBinding).debug(str(ex))
 
     @staticmethod
     def _get_event_writer(writer):
@@ -903,14 +927,40 @@ class PatchTensorFlowEager(object):
     def trains_object(self):
         return PatchTensorFlowEager.__trains_event_writer
 
+    @staticmethod
+    def _write_summary(writer, step, tensor, tag, summary_metadata, name=None, **kwargs):
+        event_writer = PatchTensorFlowEager._get_event_writer(writer)
+        if event_writer:
+            try:
+                plugin_type = summary_metadata.decode()
+                if plugin_type.endswith('scalars'):
+                    event_writer._add_scalar(tag=str(tag),
+                                             step=int(step.numpy()) if not isinstance(step, int) else step,
+                                             scalar_data=tensor.numpy())
+                elif plugin_type.endswith('images'):
+                    img_data_np = tensor.numpy()
+                    PatchTensorFlowEager._add_image_event_helper(event_writer, img_data_np=img_data_np,
+                                                                 tag=tag, step=step, **kwargs)
+                elif plugin_type.endswith('histograms'):
+                    PatchTensorFlowEager._add_histogram_event_helper(
+                        event_writer, tag=str(tag), step=int(step.numpy()) if not isinstance(step, int) else step,
+                        hist_data=tensor.numpy())
+                else:
+                    pass  # print('unsupported plugin_type', plugin_type)
+            except Exception as ex:
+                pass
+        return PatchTensorFlowEager.__write_summary(writer, step, tensor, tag, summary_metadata, name, **kwargs)
+
     @staticmethod
     def _write_scalar_summary(writer, step, tag, value, name=None, **kwargs):
         event_writer = PatchTensorFlowEager._get_event_writer(writer)
         if event_writer:
             try:
-                event_writer._add_scalar(tag=str(tag), step=int(step.numpy()), scalar_data=value.numpy())
+                event_writer._add_scalar(tag=str(tag),
+                                         step=int(step.numpy()) if not isinstance(step, int) else step,
+                                         scalar_data=value.numpy())
             except Exception as ex:
-                LoggerRoot.get_base_logger().warning(str(ex))
+                LoggerRoot.get_base_logger(TensorflowBinding).warning(str(ex))
         return PatchTensorFlowEager.__original_fn_scalar(writer, step, tag, value, name, **kwargs)
 
     @staticmethod
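Note: the repeated `int(step.numpy()) if not isinstance(step, int) else step` above exists because eager callers pass `step` as a tensor while other callers pass a plain int. Extracted as a one-liner (a sketch):

    def as_int_step(step):
        # eager summaries pass step as a tensor; plain ints pass through
        return step if isinstance(step, int) else int(step.numpy())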
@@ -918,9 +968,11 @@ class PatchTensorFlowEager(object):
         event_writer = PatchTensorFlowEager._get_event_writer(writer)
         if event_writer:
             try:
-                event_writer._add_histogram(tag=str(tag), step=int(step.numpy()), histo_data=values.numpy())
+                PatchTensorFlowEager._add_histogram_event_helper(
+                    event_writer, tag=str(tag), step=int(step.numpy()) if not isinstance(step, int) else step,
+                    hist_data=values.numpy())
             except Exception as ex:
-                LoggerRoot.get_base_logger().warning(str(ex))
+                LoggerRoot.get_base_logger(TensorflowBinding).warning(str(ex))
         return PatchTensorFlowEager.__original_fn_hist(writer, step, tag, values, name, **kwargs)
 
     @staticmethod
@@ -928,13 +980,45 @@ class PatchTensorFlowEager(object):
         event_writer = PatchTensorFlowEager._get_event_writer(writer)
         if event_writer:
             try:
-                event_writer._add_image_numpy(tag=str(tag), step=int(step.numpy()), img_data_np=tensor.numpy(),
-                                              max_keep_images=max_images)
+                PatchTensorFlowEager._add_image_event_helper(event_writer, img_data_np=tensor.numpy(),
+                                                             tag=tag, step=step, **kwargs)
             except Exception as ex:
-                LoggerRoot.get_base_logger().warning(str(ex))
+                LoggerRoot.get_base_logger(TensorflowBinding).warning(str(ex))
         return PatchTensorFlowEager.__original_fn_image(writer, step, tag, tensor, bad_color, max_images, name,
                                                         **kwargs)
 
+    @staticmethod
+    def _add_histogram_event_helper(event_writer, hist_data, tag, step):
+        if isinstance(hist_data, dict):
+            event_writer._add_histogram(tag=tag, step=step, histo_data=hist_data)
+            return
+
+        # prepare the dictionary, assume numpy
+        # histo_data['bucketLimit'] is the histogram bucket right side limit, meaning X axis
+        # histo_data['bucket'] is the histogram height, meaning the Y axis
+        # notice hist_data[:, 1] is the right side limit, for backwards compatibility we take the left side
+        histo_data = {'bucketLimit': hist_data[:, 0].tolist(), 'bucket': hist_data[:, 2].tolist()}
+        event_writer._add_histogram(tag=tag, step=step, histo_data=histo_data)
+
+    @staticmethod
+    def _add_image_event_helper(event_writer, img_data_np, tag, step, **kwargs):
+        if img_data_np.ndim == 1 and img_data_np.size >= 3 and \
+                (len(img_data_np[0]) < 10 and len(img_data_np[1]) < 10):
+            # this is just for making sure these are actually valid numbers
+            width = int(img_data_np[0].decode())
+            height = int(img_data_np[1].decode())
+            for i in range(2, img_data_np.size):
+                img_data = {'width': -1, 'height': -1,
+                            'colorspace': 'RGB', 'encodedImageString': img_data_np[i]}
+                image_tag = str(tag)+'/sample_{}'.format(i-2) if img_data_np.size > 3 else str(tag)
+                event_writer._add_image(tag=image_tag,
+                                        step=int(step.numpy()) if not isinstance(step, int) else step,
+                                        img_data=img_data)
+        else:
+            event_writer._add_image_numpy(tag=str(tag),
+                                          step=int(step.numpy()) if not isinstance(step, int) else step,
+                                          img_data_np=img_data_np,
+                                          max_keep_images=kwargs.get('max_images'))
+
 class PatchKerasModelIO(object):
     __main_task = None
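Note: per the comments in _add_histogram_event_helper above, a TF 2.x histogram summary tensor has shape (k, 3) with rows of (left edge, right edge, count); column 0 is kept as 'bucketLimit' for backwards compatibility and column 2 as the bucket height. The conversion in isolation (a sketch assuming numpy input):

    import numpy as np

    def histogram_tensor_to_dict(hist_data):
        hist_data = np.asarray(hist_data, dtype=np.float32)
        # column 0: bucket left edge (X axis), column 2: count (Y axis)
        return {'bucketLimit': hist_data[:, 0].tolist(),
                'bucket': hist_data[:, 2].tolist()}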
@@ -1029,7 +1113,7 @@ class PatchKerasModelIO(object):
             keras_saving.save_model = _patched_call(keras_saving.save_model, PatchKerasModelIO._save_model)
             keras_saving.load_model = _patched_call(keras_saving.load_model, PatchKerasModelIO._load_model)
         except Exception as ex:
-            LoggerRoot.get_base_logger().warning(str(ex))
+            LoggerRoot.get_base_logger(TensorflowBinding).warning(str(ex))
 
     @staticmethod
     def _updated_config(original_fn, self):
@@ -1057,7 +1141,7 @@ class PatchKerasModelIO(object):
                 framework=Framework.keras,
             )
         except Exception as ex:
-            LoggerRoot.get_base_logger().warning(str(ex))
+            LoggerRoot.get_base_logger(TensorflowBinding).warning(str(ex))
 
         return config
 
@@ -1107,7 +1191,7 @@ class PatchKerasModelIO(object):
             return model
 
         except Exception as ex:
-            LoggerRoot.get_base_logger().warning(str(ex))
+            LoggerRoot.get_base_logger(TensorflowBinding).warning(str(ex))
 
         return self
 
@@ -1189,7 +1273,7 @@ class PatchKerasModelIO(object):
             # if anyone asks, we were here
             self.trains_out_model._processed = True
         except Exception as ex:
-            LoggerRoot.get_base_logger().warning(str(ex))
+            LoggerRoot.get_base_logger(TensorflowBinding).warning(str(ex))
 
     @staticmethod
     def _save_model(original_fn, model, filepath, *args, **kwargs):
@@ -1261,7 +1345,7 @@ class PatchTensorflowModelIO(object):
         except ImportError:
             pass
         except Exception:
-            pass  # print('Failed patching tensorflow')
+            LoggerRoot.get_base_logger(TensorflowBinding).debug('Failed patching tensorflow')
 
         # noinspection PyBroadException
         try:
@@ -1282,10 +1366,9 @@ class PatchTensorflowModelIO(object):
                 saved_model = None
             except Exception:
                 saved_model = None
-                pass  # print('Failed patching tensorflow')
         except Exception:
             saved_model = None
-            pass  # print('Failed patching tensorflow')
+
 
         if saved_model is not None:
             saved_model.save = _patched_call(saved_model.save, PatchTensorflowModelIO._save_model)
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
except Exception:
|
except Exception:
|
||||||
pass # print('Failed patching tensorflow')
|
LoggerRoot.get_base_logger(TensorflowBinding).debug('Failed patching tensorflow')
|
||||||
|
|
||||||
# noinspection PyBroadException
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
@@ -1313,7 +1396,7 @@ class PatchTensorflowModelIO(object):
         except ImportError:
             pass
         except Exception:
-            pass  # print('Failed patching tensorflow')
+            LoggerRoot.get_base_logger(TensorflowBinding).debug('Failed patching tensorflow')
 
         # noinspection PyBroadException
         try:
@@ -1325,7 +1408,7 @@ class PatchTensorflowModelIO(object):
         except ImportError:
             pass
         except Exception:
-            pass  # print('Failed patching tensorflow')
+            LoggerRoot.get_base_logger(TensorflowBinding).debug('Failed patching tensorflow')
 
         # noinspection PyBroadException
         try:
@@ -1349,7 +1432,7 @@ class PatchTensorflowModelIO(object):
         except ImportError:
             pass
         except Exception:
-            pass  # print('Failed patching tensorflow')
+            LoggerRoot.get_base_logger(TensorflowBinding).debug('Failed patching tensorflow')
 
     @staticmethod
     def _save(original_fn, self, sess, save_path, *args, **kwargs):
@@ -1457,3 +1540,81 @@ class PatchTensorflowModelIO(object):
                 pass
         return model
+
+
+class PatchTensorflow2ModelIO(object):
+    __main_task = None
+    __patched = None
+
+    @staticmethod
+    def update_current_task(task, **kwargs):
+        PatchTensorflow2ModelIO.__main_task = task
+        PatchTensorflow2ModelIO._patch_model_checkpoint()
+        PostImportHookPatching.add_on_import('tensorflow', PatchTensorflow2ModelIO._patch_model_checkpoint)
+
+    @staticmethod
+    def _patch_model_checkpoint():
+        if PatchTensorflow2ModelIO.__patched:
+            return
+
+        if 'tensorflow' not in sys.modules:
+            return
+
+        PatchTensorflow2ModelIO.__patched = True
+        # noinspection PyBroadException
+        try:
+            # hack: make sure tensorflow.__init__ is called
+            import tensorflow
+            from tensorflow.python.training.tracking import util
+            # noinspection PyBroadException
+            try:
+                util.TrackableSaver.save = _patched_call(util.TrackableSaver.save,
+                                                         PatchTensorflow2ModelIO._save)
+            except Exception:
+                pass
+            # noinspection PyBroadException
+            try:
+                util.TrackableSaver.restore = _patched_call(util.TrackableSaver.restore,
+                                                            PatchTensorflow2ModelIO._restore)
+            except Exception:
+                pass
+        except ImportError:
+            pass
+        except Exception:
+            LoggerRoot.get_base_logger(TensorflowBinding).debug('Failed patching tensorflow v2')
+
+    @staticmethod
+    def _save(original_fn, self, file_prefix, *args, **kwargs):
+        model = original_fn(self, file_prefix, *args, **kwargs)
+        # store output Model
+        try:
+            WeightsFileHandler.create_output_model(self, file_prefix, Framework.tensorflow,
+                                                   PatchTensorflow2ModelIO.__main_task)
+        except Exception:
+            pass
+        return model
+
+    @staticmethod
+    def _restore(original_fn, self, save_path, *args, **kwargs):
+        if PatchTensorflow2ModelIO.__main_task is None:
+            return original_fn(self, save_path, *args, **kwargs)
+
+        # Hack: disabled
+        if False and running_remotely():
+            # register/load model weights
+            try:
+                save_path = WeightsFileHandler.restore_weights_file(self, save_path, Framework.tensorflow,
+                                                                    PatchTensorflow2ModelIO.__main_task)
+            except Exception:
+                pass
+            # load model
+            return original_fn(self, save_path, *args, **kwargs)
+
+        # load model, if something is wrong, exception will be raised before we register the input model
+        model = original_fn(self, save_path, *args, **kwargs)
+        # register/load model weights
+        try:
+            WeightsFileHandler.restore_weights_file(self, save_path, Framework.tensorflow,
+                                                    PatchTensorflow2ModelIO.__main_task)
+        except Exception:
+            pass
+        return model
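Note: _patched_call is imported from the bindings package rather than defined in this diff. Judging from the call sites above, where _save and _restore receive original_fn as their first argument, its semantics are roughly the following sketch (an assumption, not the actual helper):

    import functools

    def _patched_call(original_fn, patched_fn):
        # return a wrapper that hands the original callable to the patch,
        # letting the patch decide when and how to invoke it
        @functools.wraps(original_fn)
        def _inner(*args, **kwargs):
            return patched_fn(original_fn, *args, **kwargs)
        return _inner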
@@ -8,7 +8,7 @@
     }
 
     direct_access: [
-        # Objects matching are considered to be available for direct access, i.e. they will not eb downloaded
+        # Objects matching are considered to be available for direct access, i.e. they will not be downloaded
         # or cached, and any download request will return a direct reference.
         # Objects are specified in glob format, available for url and content_type.
         { url: "file://*" }  # file-urls are always directly referenced
@@ -34,8 +34,7 @@ from .binding.absl_bind import PatchAbsl
 from .utilities.args import argparser_parseargs_called, get_argparser_last_args, \
     argparser_update_currenttask
 from .binding.frameworks.pytorch_bind import PatchPyTorchModelIO
-from .binding.frameworks.tensorflow_bind import PatchSummaryToEventTransformer, PatchTensorFlowEager, \
-    PatchKerasModelIO, PatchTensorflowModelIO
+from .binding.frameworks.tensorflow_bind import TensorflowBinding
 from .binding.frameworks.xgboost_bind import PatchXGBoostModelIO
 from .binding.matplotlib_bind import PatchedMatplotlib
 from .utilities.resource_monitor import ResourceMonitor
@@ -252,10 +251,7 @@ class Task(_Task):
             PatchedJoblib.update_current_task(task)
             PatchedMatplotlib.update_current_task(Task.__main_task)
             PatchAbsl.update_current_task(Task.__main_task)
-            PatchSummaryToEventTransformer.update_current_task(task)
-            PatchTensorFlowEager.update_current_task(task)
-            PatchKerasModelIO.update_current_task(task)
-            PatchTensorflowModelIO.update_current_task(task)
+            TensorflowBinding.update_current_task(task)
             PatchPyTorchModelIO.update_current_task(task)
             PatchXGBoostModelIO.update_current_task(task)
             if auto_resource_monitoring:
@@ -346,7 +342,7 @@ class Task(_Task):
     def _create_dev_task(cls, default_project_name, default_task_name, default_task_type, reuse_last_task_id):
         if not default_project_name or not default_task_name:
             # get project name and task name from repository name and entry_point
-            result = ScriptInfo.get(create_requirements=False, check_uncommitted=False)
+            result, _ = ScriptInfo.get(create_requirements=False, check_uncommitted=False)
             if not default_project_name:
                 # noinspection PyBroadException
                 try:
@@ -1016,7 +1012,14 @@ class Task(_Task):
         if self._detect_repo_async_thread:
             try:
                 if self._detect_repo_async_thread.is_alive():
+                    self.log.info('Waiting for repository detection and full package requirement analysis')
                     self._detect_repo_async_thread.join(timeout=timeout)
+                    # because join has no return value
+                    if self._detect_repo_async_thread.is_alive():
+                        self.log.info('Repository and package analysis timed out ({} sec), '
+                                      'giving up'.format(timeout))
+                    else:
+                        self.log.info('Finished repository detection and package analysis')
                 self._detect_repo_async_thread = None
             except Exception:
                 pass
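Note: the second is_alive() check above is needed because Thread.join() returns None whether or not it timed out; re-checking liveness is the only way to tell a timeout from completion. The pattern in isolation (a runnable sketch):

    import threading
    import time

    def join_with_timeout(thread, timeout):
        # Thread.join() has no return value, so re-check is_alive()
        thread.join(timeout=timeout)
        return not thread.is_alive()

    worker = threading.Thread(target=time.sleep, args=(0.1,))
    worker.start()
    print('finished' if join_with_timeout(worker, timeout=1.0) else 'timed out')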