Mirror of https://github.com/clearml/clearml, synced 2025-06-23 01:55:38 +00:00

Add initial Tensorflow v2 support (2.0.0rc1)

parent c44638c8d9
commit a7eb8476ce

examples/tensorflow_v2_mnist.py (new file, 130 lines)
@@ -0,0 +1,130 @@
from __future__ import absolute_import, division, print_function, unicode_literals

import tensorflow as tf

from tensorflow.keras.layers import Dense, Flatten, Conv2D
from tensorflow.keras import Model

from trains import Task


task = Task.init(project_name='examples',
                 task_name='Tensorflow v2 mnist with summaries')


# Load and prepare the MNIST dataset.
mnist = tf.keras.datasets.mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

# Add a channels dimension
x_train = x_train[..., tf.newaxis]
x_test = x_test[..., tf.newaxis]

# Use tf.data to batch and shuffle the dataset
train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(10000).batch(32)
test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(32)


# Build the tf.keras model using the Keras model subclassing API
class MyModel(Model):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv1 = Conv2D(32, 3, activation='relu')
        self.flatten = Flatten()
        self.d1 = Dense(128, activation='relu')
        self.d2 = Dense(10, activation='softmax')

    def call(self, x):
        x = self.conv1(x)
        x = self.flatten(x)
        x = self.d1(x)
        return self.d2(x)


# Create an instance of the model
model = MyModel()

# Choose an optimizer and loss function for training
loss_object = tf.keras.losses.SparseCategoricalCrossentropy()
optimizer = tf.keras.optimizers.Adam()

# Select metrics to measure the loss and the accuracy of the model.
# These metrics accumulate the values over epochs and then print the overall result.
train_loss = tf.keras.metrics.Mean(name='train_loss', dtype=tf.float32)
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')

test_loss = tf.keras.metrics.Mean(name='test_loss', dtype=tf.float32)
test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')


# Use tf.GradientTape to train the model
@tf.function
def train_step(images, labels):
    with tf.GradientTape() as tape:
        predictions = model(images)
        loss = loss_object(labels, predictions)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

    train_loss(loss)
    train_accuracy(labels, predictions)


# Test the model
@tf.function
def test_step(images, labels):
    predictions = model(images)
    t_loss = loss_object(labels, predictions)

    test_loss(t_loss)
    test_accuracy(labels, predictions)


# Set up summary writers to write the summaries to disk in a different logs directory
train_log_dir = '/tmp/logs/gradient_tape/train'
test_log_dir = '/tmp/logs/gradient_tape/test'
train_summary_writer = tf.summary.create_file_writer(train_log_dir)
test_summary_writer = tf.summary.create_file_writer(test_log_dir)

# Set up checkpoints manager
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=optimizer, net=model)
manager = tf.train.CheckpointManager(ckpt, '/tmp/tf_ckpts', max_to_keep=3)
ckpt.restore(manager.latest_checkpoint)
if manager.latest_checkpoint:
    print("Restored from {}".format(manager.latest_checkpoint))
else:
    print("Initializing from scratch.")

# Start training
EPOCHS = 5
for epoch in range(EPOCHS):
    for images, labels in train_ds:
        train_step(images, labels)
    with train_summary_writer.as_default():
        tf.summary.scalar('loss', train_loss.result(), step=epoch)
        tf.summary.scalar('accuracy', train_accuracy.result(), step=epoch)

    ckpt.step.assign_add(1)
    if int(ckpt.step) % 1 == 0:
        save_path = manager.save()
        print("Saved checkpoint for step {}: {}".format(int(ckpt.step), save_path))

    for test_images, test_labels in test_ds:
        test_step(test_images, test_labels)
    with test_summary_writer.as_default():
        tf.summary.scalar('loss', test_loss.result(), step=epoch)
        tf.summary.scalar('accuracy', test_accuracy.result(), step=epoch)

    template = 'Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, Test Accuracy: {}'
    print(template.format(epoch+1,
                          train_loss.result(),
                          train_accuracy.result()*100,
                          test_loss.result(),
                          test_accuracy.result()*100))

    # Reset the metrics for the next epoch
    train_loss.reset_states()
    train_accuracy.reset_states()
    test_loss.reset_states()
    test_accuracy.reset_states()
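Running the script is all that is needed: Task.init() is called before training starts, so the trains hooks added in this commit pick up the tf.summary scalars and the checkpoints written by the CheckpointManager transparently. A rough local sanity-check sketch after a run (the /tmp paths are the ones hard-coded in the example above):

import glob

# event files produced by the two summary writers
print(glob.glob('/tmp/logs/gradient_tape/*/events.out.tfevents.*'))
# checkpoint files rotated by the CheckpointManager (max_to_keep=3)
print(glob.glob('/tmp/tf_ckpts/ckpt-*'))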
@@ -8,6 +8,7 @@ from io import BytesIO
from typing import Any

import numpy as np
import six
from PIL import Image

from ...debugging.log import LoggerRoot
@@ -22,6 +23,16 @@ except ImportError:
    MessageToDict = None
+
+
+class TensorflowBinding(object):
+    @classmethod
+    def update_current_task(cls, task):
+        PatchSummaryToEventTransformer.update_current_task(task)
+        PatchTensorFlowEager.update_current_task(task)
+        PatchKerasModelIO.update_current_task(task)
+        PatchTensorflowModelIO.update_current_task(task)
+        PatchTensorflow2ModelIO.update_current_task(task)


class IsTensorboardInit(object):
    _tensorboard_initialized = False

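TensorflowBinding above is a facade: Task.init can now make a single call that fans out to all five TensorFlow/Keras patchers (the Task hunk near the end of this diff switches over to it). A minimal sketch of the same pattern, with hypothetical PatchFoo/PatchBar classes standing in for the real patchers:

class PatchFoo(object):
    # hypothetical patcher, for illustration only
    @staticmethod
    def update_current_task(task):
        print('patching foo for', task)


class PatchBar(object):
    # hypothetical patcher, for illustration only
    @staticmethod
    def update_current_task(task):
        print('patching bar for', task)


class Binding(object):
    # single entry point hiding the individual patchers
    @classmethod
    def update_current_task(cls, task):
        for patcher in (PatchFoo, PatchBar):
            patcher.update_current_task(task)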
@@ -156,6 +167,7 @@ class EventTrainsWriter(object):
        self.image_report_freq = image_report_freq if image_report_freq else report_freq
        self.histogram_granularity = histogram_granularity
        self.histogram_update_freq_multiplier = histogram_update_freq_multiplier
+        self._histogram_update_call_counter = 0
        self._logger = logger
        self._visualization_mode = 'RGB'  # 'BGR'
        self._variants = defaultdict(lambda: ())
@@ -168,12 +180,18 @@ class EventTrainsWriter(object):
    def _decode_image(self, img_str, width, height, color_channels):
        # noinspection PyBroadException
        try:
-            imdata = base64.b64decode(img_str)
+            if isinstance(img_str, bytes):
+                imdata = img_str
+            else:
+                imdata = base64.b64decode(img_str)
            output = BytesIO(imdata)
            im = Image.open(output)
            image = np.asarray(im)
            output.close()
-            val = image.reshape(height, width, -1).astype(np.uint8)
+            if height > 0 and width > 0:
+                val = image.reshape(height, width, -1).astype(np.uint8)
+            else:
+                val = image.astype(np.uint8)
            if val.ndim == 3 and val.shape[2] == 3:
                if self._visualization_mode == 'BGR':
                    val = val[:, :, [2, 1, 0]]
@@ -187,7 +205,7 @@ class EventTrainsWriter(object):
                else:
                    val = val[:, :, [0, 1, 2]]
        except Exception:
-            LoggerRoot.get_base_logger().warning('Failed decoding debug image [%d, %d, %d]'
+            LoggerRoot.get_base_logger(TensorflowBinding).warning('Failed decoding debug image [%d, %d, %d]'
                                                  % (width, height, color_channels))
            val = None
        return val
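The reworked _decode_image accepts raw bytes as well as base64-encoded strings, and skips the reshape when TensorBoard supplied no explicit dimensions. A self-contained sketch of that decode path (numpy and PIL only, no trains imports), assuming img_str came out of an image summary event:

import base64
from io import BytesIO

import numpy as np
from PIL import Image


def decode_image(img_str, width=0, height=0):
    # raw bytes pass through; strings are assumed base64-encoded
    imdata = img_str if isinstance(img_str, bytes) else base64.b64decode(img_str)
    image = np.asarray(Image.open(BytesIO(imdata)))
    # reshape only when the event carried explicit dimensions
    if height > 0 and width > 0:
        return image.reshape(height, width, -1).astype(np.uint8)
    return image.astype(np.uint8)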
@@ -281,7 +299,9 @@ class EventTrainsWriter(object):
            return _cur_idx

        # only collect histogram every specific interval
-        if step % self.report_freq != 0 or step < self.report_freq - 1:
+        self._histogram_update_call_counter += 1
+        if self._histogram_update_call_counter % self.report_freq != 0 or \
+                self._histogram_update_call_counter < self.report_freq - 1:
            return None

        # generate forward matrix of the histograms
@@ -306,6 +326,8 @@ class EventTrainsWriter(object):

        # add current sample, if not already here
        hist_iters = np.append(hist_iters, step)
+        # histo_data['bucketLimit'] is the histogram bucket right side limit, meaning X axis
+        # histo_data['bucket'] is the histogram height, meaning the Y axis
        hist = np.array(list(zip(histo_data['bucketLimit'], histo_data['bucket'])), dtype=np.float32)
        hist = hist[~np.isinf(hist[:, 0]), :]
        hist_list.append(hist)
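The two new comments pin down the TensorBoard histogram layout: bucketLimit holds the right edge of each bin (the X axis) and bucket the bin height (the Y axis), with open-ended inf bins at both extremes. A small standalone sketch of the same clean-up the writer performs, using a hypothetical parsed event dict:

import numpy as np

# hypothetical histogram event, as produced by MessageToDict
histo_data = {'bucketLimit': [-np.inf, 0.5, 1.0, np.inf], 'bucket': [0, 3, 7, 1]}

hist = np.array(list(zip(histo_data['bucketLimit'], histo_data['bucket'])), dtype=np.float32)
hist = hist[~np.isinf(hist[:, 0]), :]  # drop the open-ended inf bins
print(hist)  # [[0.5 3.], [1. 7.]]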
@@ -422,7 +444,7 @@ class EventTrainsWriter(object):
            msg_dict.pop('wallTime', None)
            keys_list = [key for key in msg_dict.keys() if len(key) > 0]
            keys_list = ', '.join(keys_list)
-            LoggerRoot.get_base_logger().debug('event summary not found, message type unsupported: %s' % keys_list)
+            LoggerRoot.get_base_logger(TensorflowBinding).debug('event summary not found, message type unsupported: %s' % keys_list)
            return
        value_dicts = summary.get('value')
        walltime = walltime or msg_dict.get('step')
@@ -434,19 +456,19 @@ class EventTrainsWriter(object):
                step = int(event.step)
            else:
                step = 0
-                LoggerRoot.get_base_logger().debug('Received event without step, assuming step = {}'.format(step))
+                LoggerRoot.get_base_logger(TensorflowBinding).debug('Received event without step, assuming step = {}'.format(step))
        else:
            step = int(step)
        self._max_step = max(self._max_step, step)
        if value_dicts is None:
-            LoggerRoot.get_base_logger().debug("Summary arrived without 'value'")
+            LoggerRoot.get_base_logger(TensorflowBinding).debug("Summary arrived without 'value'")
            return

        for vdict in value_dicts:
            tag = vdict.pop('tag', None)
            if tag is None:
                # we should not get here
-                LoggerRoot.get_base_logger().debug('No tag for \'value\' existing keys %s'
+                LoggerRoot.get_base_logger(TensorflowBinding).debug('No tag for \'value\' existing keys %s'
                                                   % ', '.join(vdict.keys()))
                continue
            metric, values = get_data(vdict, supported_metrics)
@@ -463,7 +485,7 @@ class EventTrainsWriter(object):
            elif metric == 'tensor' and values.get('dtype') == 'DT_FLOAT':
                self._add_plot(tag, step, values, vdict)
            else:
-                LoggerRoot.get_base_logger().debug('Event unsupported. tag = %s, vdict keys [%s]'
+                LoggerRoot.get_base_logger(TensorflowBinding).debug('Event unsupported. tag = %s, vdict keys [%s]'
                                                   % (tag, ', '.join(vdict.keys())))
                continue
@@ -594,7 +616,7 @@ class PatchSummaryToEventTransformer(object):
                setattr(SummaryToEventTransformer, 'trains',
                        property(PatchSummaryToEventTransformer.trains_object))
            except Exception as ex:
-                LoggerRoot.get_base_logger().debug(str(ex))
+                LoggerRoot.get_base_logger(TensorflowBinding).debug(str(ex))

        if 'torch' in sys.modules:
            try:
@@ -608,7 +630,7 @@ class PatchSummaryToEventTransformer(object):
                # this is a new version of TensorflowX
                pass
            except Exception as ex:
-                LoggerRoot.get_base_logger().debug(str(ex))
+                LoggerRoot.get_base_logger(TensorflowBinding).debug(str(ex))

        if 'tensorboardX' in sys.modules:
            try:
@@ -624,7 +646,7 @@ class PatchSummaryToEventTransformer(object):
                # this is a new version of TensorflowX
                pass
            except Exception as ex:
-                LoggerRoot.get_base_logger().debug(str(ex))
+                LoggerRoot.get_base_logger(TensorflowBinding).debug(str(ex))

            if PatchSummaryToEventTransformer.__original_getattributeX is None:
                try:
@@ -638,7 +660,7 @@ class PatchSummaryToEventTransformer(object):
                # this is a new version of TensorflowX
                pass
            except Exception as ex:
-                LoggerRoot.get_base_logger().debug(str(ex))
+                LoggerRoot.get_base_logger(TensorflowBinding).debug(str(ex))

    @staticmethod
    def _patched_add_eventT(self, *args, **kwargs):
@@ -722,7 +744,7 @@ class _ModelAdapter(object):
        super(_ModelAdapter, self).__init__()
        super(_ModelAdapter, self).__setattr__('_model', model)
        super(_ModelAdapter, self).__setattr__('_output_model', output_model)
-        super(_ModelAdapter, self).__setattr__('_logger', LoggerRoot.get_base_logger())
+        super(_ModelAdapter, self).__setattr__('_logger', LoggerRoot.get_base_logger(TensorflowBinding))

    def __getattr__(self, attr):
        return getattr(self._model, attr)
@@ -805,7 +827,7 @@ class PatchModelCheckPointCallback(object):
                    property(PatchModelCheckPointCallback.trains_object))

        except Exception as ex:
-            LoggerRoot.get_base_logger().warning(str(ex))
+            LoggerRoot.get_base_logger(TensorflowBinding).warning(str(ex))

    @staticmethod
    def _patched_getattribute(self, attr):
@@ -876,6 +898,8 @@ class PatchTensorFlowEager(object):
            gen_summary_ops.write_image_summary = PatchTensorFlowEager._write_image_summary
            PatchTensorFlowEager.__original_fn_hist = gen_summary_ops.write_histogram_summary
            gen_summary_ops.write_histogram_summary = PatchTensorFlowEager._write_hist_summary
+            PatchTensorFlowEager.__write_summary = gen_summary_ops.write_summary
+            gen_summary_ops.write_summary = PatchTensorFlowEager._write_summary
            gen_summary_ops.create_summary_file_writer = partial(IsTensorboardInit._patched_tb__init__,
                                                                 gen_summary_ops.create_summary_file_writer)
            gen_summary_ops.create_summary_db_writer = partial(IsTensorboardInit._patched_tb__init__,
@@ -883,7 +907,7 @@ class PatchTensorFlowEager(object):
        except ImportError:
            pass
        except Exception as ex:
-            LoggerRoot.get_base_logger().debug(str(ex))
+            LoggerRoot.get_base_logger(TensorflowBinding).debug(str(ex))

    @staticmethod
    def _get_event_writer(writer):
@@ -903,14 +927,40 @@ class PatchTensorFlowEager(object):
    def trains_object(self):
        return PatchTensorFlowEager.__trains_event_writer

+    @staticmethod
+    def _write_summary(writer, step, tensor, tag, summary_metadata, name=None, **kwargs):
+        event_writer = PatchTensorFlowEager._get_event_writer(writer)
+        if event_writer:
+            try:
+                plugin_type = summary_metadata.decode()
+                if plugin_type.endswith('scalars'):
+                    event_writer._add_scalar(tag=str(tag),
+                                             step=int(step.numpy()) if not isinstance(step, int) else step,
+                                             scalar_data=tensor.numpy())
+                elif plugin_type.endswith('images'):
+                    img_data_np = tensor.numpy()
+                    PatchTensorFlowEager._add_image_event_helper(event_writer, img_data_np=img_data_np,
+                                                                 tag=tag, step=step, **kwargs)
+                elif plugin_type.endswith('histograms'):
+                    PatchTensorFlowEager._add_histogram_event_helper(
+                        event_writer, tag=str(tag), step=int(step.numpy()) if not isinstance(step, int) else step,
+                        hist_data=tensor.numpy())
+                else:
+                    pass  # print('unsupported plugin_type', plugin_type)
+            except Exception:
+                pass
+        return PatchTensorFlowEager.__write_summary(writer, step, tensor, tag, summary_metadata, name, **kwargs)
+
    @staticmethod
    def _write_scalar_summary(writer, step, tag, value, name=None, **kwargs):
        event_writer = PatchTensorFlowEager._get_event_writer(writer)
        if event_writer:
            try:
-                event_writer._add_scalar(tag=str(tag), step=int(step.numpy()), scalar_data=value.numpy())
+                event_writer._add_scalar(tag=str(tag),
+                                         step=int(step.numpy()) if not isinstance(step, int) else step,
+                                         scalar_data=value.numpy())
            except Exception as ex:
-                LoggerRoot.get_base_logger().warning(str(ex))
+                LoggerRoot.get_base_logger(TensorflowBinding).warning(str(ex))
        return PatchTensorFlowEager.__original_fn_scalar(writer, step, tag, value, name, **kwargs)

    @staticmethod
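In TF2 eager mode most summaries funnel through the generic write_summary op, so the new hook has to recover the summary kind from the serialized metadata; the check is a plain suffix match on the plugin name. A stripped-down sketch of the same dispatch, with a hypothetical report() sink standing in for the event writer:

def dispatch_summary(plugin_type, tag, step, value, report):
    # plugin_type is the decoded SummaryMetadata plugin name,
    # e.g. 'scalars', 'images' or 'histograms'
    if plugin_type.endswith('scalars'):
        report('scalar', tag, step, value)
    elif plugin_type.endswith('images'):
        report('image', tag, step, value)
    elif plugin_type.endswith('histograms'):
        report('histogram', tag, step, value)
    # anything else is silently ignored, mirroring the patcher

dispatch_summary('scalars', 'loss', 0, 0.25, lambda *args: print(args))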
@@ -918,9 +968,11 @@ class PatchTensorFlowEager(object):
        event_writer = PatchTensorFlowEager._get_event_writer(writer)
        if event_writer:
            try:
-                event_writer._add_histogram(tag=str(tag), step=int(step.numpy()), histo_data=values.numpy())
+                PatchTensorFlowEager._add_histogram_event_helper(
+                    event_writer, tag=str(tag), step=int(step.numpy()) if not isinstance(step, int) else step,
+                    hist_data=values.numpy())
            except Exception as ex:
-                LoggerRoot.get_base_logger().warning(str(ex))
+                LoggerRoot.get_base_logger(TensorflowBinding).warning(str(ex))
        return PatchTensorFlowEager.__original_fn_hist(writer, step, tag, values, name, **kwargs)

    @staticmethod
@@ -928,13 +980,45 @@ class PatchTensorFlowEager(object):
        event_writer = PatchTensorFlowEager._get_event_writer(writer)
        if event_writer:
            try:
-                event_writer._add_image_numpy(tag=str(tag), step=int(step.numpy()), img_data_np=tensor.numpy(),
-                                              max_keep_images=max_images)
+                PatchTensorFlowEager._add_image_event_helper(event_writer, img_data_np=tensor.numpy(),
+                                                             tag=tag, step=step, **kwargs)
            except Exception as ex:
-                LoggerRoot.get_base_logger().warning(str(ex))
+                LoggerRoot.get_base_logger(TensorflowBinding).warning(str(ex))
        return PatchTensorFlowEager.__original_fn_image(writer, step, tag, tensor, bad_color, max_images, name,
                                                        **kwargs)

+    @staticmethod
+    def _add_histogram_event_helper(event_writer, hist_data, tag, step):
+        if isinstance(hist_data, dict):
+            event_writer._add_histogram(tag=tag, step=step, histo_data=hist_data)
+            return
+
+        # prepare the dictionary, assume numpy
+        # histo_data['bucketLimit'] is the histogram bucket right side limit, meaning X axis
+        # histo_data['bucket'] is the histogram height, meaning the Y axis
+        # notice hist_data[:, 1] is the right side limit, for backwards compatibility we take the left side
+        histo_data = {'bucketLimit': hist_data[:, 0].tolist(), 'bucket': hist_data[:, 2].tolist()}
+        event_writer._add_histogram(tag=tag, step=step, histo_data=histo_data)
+
+    @staticmethod
+    def _add_image_event_helper(event_writer, img_data_np, tag, step, **kwargs):
+        if img_data_np.ndim == 1 and img_data_np.size >= 3 and \
+                (len(img_data_np[0]) < 10 and len(img_data_np[1]) < 10):
+            # this is just for making sure these are actually valid numbers
+            width = int(img_data_np[0].decode())
+            height = int(img_data_np[1].decode())
+            for i in range(2, img_data_np.size):
+                img_data = {'width': -1, 'height': -1,
+                            'colorspace': 'RGB', 'encodedImageString': img_data_np[i]}
+                image_tag = str(tag)+'/sample_{}'.format(i-2) if img_data_np.size > 3 else str(tag)
+                event_writer._add_image(tag=image_tag,
+                                        step=int(step.numpy()) if not isinstance(step, int) else step,
+                                        img_data=img_data)
+        else:
+            event_writer._add_image_numpy(tag=str(tag),
+                                          step=int(step.numpy()) if not isinstance(step, int) else step,
+                                          img_data_np=img_data_np,
+                                          max_keep_images=kwargs.get('max_images'))
+
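_add_image_event_helper covers the two layouts TensorBoard uses for image tensors: a 1-D array of [width, height, encoded_image, encoded_image, ...] holding already-encoded images, or a plain numeric array of raw pixels. A sketch of just that branch decision, assuming img_data_np came out of tensor.numpy():

import numpy as np

def is_encoded_image_list(img_data_np):
    # encoded form: 1-D object array whose first two entries are short
    # byte strings with the width and height, the rest are image blobs
    return (img_data_np.ndim == 1 and img_data_np.size >= 3 and
            len(img_data_np[0]) < 10 and len(img_data_np[1]) < 10)

encoded = np.array([b'28', b'28', b'\x89PNG...'], dtype=object)
raw = np.zeros((4, 28, 28, 1), dtype=np.float32)
print(is_encoded_image_list(encoded))  # True
print(is_encoded_image_list(raw))      # False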
class PatchKerasModelIO(object):
    __main_task = None
@@ -1029,7 +1113,7 @@ class PatchKerasModelIO(object):
            keras_saving.save_model = _patched_call(keras_saving.save_model, PatchKerasModelIO._save_model)
            keras_saving.load_model = _patched_call(keras_saving.load_model, PatchKerasModelIO._load_model)
        except Exception as ex:
-            LoggerRoot.get_base_logger().warning(str(ex))
+            LoggerRoot.get_base_logger(TensorflowBinding).warning(str(ex))

    @staticmethod
    def _updated_config(original_fn, self):
@@ -1057,7 +1141,7 @@ class PatchKerasModelIO(object):
                    framework=Framework.keras,
                )
            except Exception as ex:
-                LoggerRoot.get_base_logger().warning(str(ex))
+                LoggerRoot.get_base_logger(TensorflowBinding).warning(str(ex))

            return config

@@ -1107,7 +1191,7 @@ class PatchKerasModelIO(object):
            return model

        except Exception as ex:
-            LoggerRoot.get_base_logger().warning(str(ex))
+            LoggerRoot.get_base_logger(TensorflowBinding).warning(str(ex))

        return self

@@ -1189,7 +1273,7 @@ class PatchKerasModelIO(object):
            # if anyone asks, we were here
            self.trains_out_model._processed = True
        except Exception as ex:
-            LoggerRoot.get_base_logger().warning(str(ex))
+            LoggerRoot.get_base_logger(TensorflowBinding).warning(str(ex))

    @staticmethod
    def _save_model(original_fn, model, filepath, *args, **kwargs):
@@ -1261,7 +1345,7 @@ class PatchTensorflowModelIO(object):
        except ImportError:
            pass
        except Exception:
-            pass  # print('Failed patching tensorflow')
+            LoggerRoot.get_base_logger(TensorflowBinding).debug('Failed patching tensorflow')

        # noinspection PyBroadException
        try:
@@ -1282,10 +1366,9 @@ class PatchTensorflowModelIO(object):
                saved_model = None
            except Exception:
                saved_model = None
                pass  # print('Failed patching tensorflow')

        except Exception:
            saved_model = None
            pass  # print('Failed patching tensorflow')

        if saved_model is not None:
            saved_model.save = _patched_call(saved_model.save, PatchTensorflowModelIO._save_model)
@@ -1301,7 +1384,7 @@ class PatchTensorflowModelIO(object):
        except ImportError:
            pass
        except Exception:
-            pass  # print('Failed patching tensorflow')
+            LoggerRoot.get_base_logger(TensorflowBinding).debug('Failed patching tensorflow')

        # noinspection PyBroadException
        try:
@@ -1313,7 +1396,7 @@ class PatchTensorflowModelIO(object):
        except ImportError:
            pass
        except Exception:
-            pass  # print('Failed patching tensorflow')
+            LoggerRoot.get_base_logger(TensorflowBinding).debug('Failed patching tensorflow')

        # noinspection PyBroadException
        try:
@@ -1325,7 +1408,7 @@ class PatchTensorflowModelIO(object):
        except ImportError:
            pass
        except Exception:
-            pass  # print('Failed patching tensorflow')
+            LoggerRoot.get_base_logger(TensorflowBinding).debug('Failed patching tensorflow')

        # noinspection PyBroadException
        try:
@@ -1349,7 +1432,7 @@ class PatchTensorflowModelIO(object):
        except ImportError:
            pass
        except Exception:
-            pass  # print('Failed patching tensorflow')
+            LoggerRoot.get_base_logger(TensorflowBinding).debug('Failed patching tensorflow')

    @staticmethod
    def _save(original_fn, self, sess, save_path, *args, **kwargs):
@@ -1457,3 +1540,81 @@ class PatchTensorflowModelIO(object):
            pass
        return model
+
+
+class PatchTensorflow2ModelIO(object):
+    __main_task = None
+    __patched = None
+
+    @staticmethod
+    def update_current_task(task, **kwargs):
+        PatchTensorflow2ModelIO.__main_task = task
+        PatchTensorflow2ModelIO._patch_model_checkpoint()
+        PostImportHookPatching.add_on_import('tensorflow', PatchTensorflow2ModelIO._patch_model_checkpoint)
+
+    @staticmethod
+    def _patch_model_checkpoint():
+        if PatchTensorflow2ModelIO.__patched:
+            return
+
+        if 'tensorflow' not in sys.modules:
+            return
+
+        PatchTensorflow2ModelIO.__patched = True
+        # noinspection PyBroadException
+        try:
+            # hack: make sure tensorflow.__init__ is called
+            import tensorflow
+            from tensorflow.python.training.tracking import util
+            # noinspection PyBroadException
+            try:
+                util.TrackableSaver.save = _patched_call(util.TrackableSaver.save,
+                                                         PatchTensorflow2ModelIO._save)
+            except Exception:
+                pass
+            # noinspection PyBroadException
+            try:
+                util.TrackableSaver.restore = _patched_call(util.TrackableSaver.restore,
+                                                            PatchTensorflow2ModelIO._restore)
+            except Exception:
+                pass
+        except ImportError:
+            pass
+        except Exception:
+            LoggerRoot.get_base_logger(TensorflowBinding).debug('Failed patching tensorflow v2')
+
+    @staticmethod
+    def _save(original_fn, self, file_prefix, *args, **kwargs):
+        model = original_fn(self, file_prefix, *args, **kwargs)
+        # store output Model
+        try:
+            WeightsFileHandler.create_output_model(self, file_prefix, Framework.tensorflow,
+                                                   PatchTensorflow2ModelIO.__main_task)
+        except Exception:
+            pass
+        return model
+
+    @staticmethod
+    def _restore(original_fn, self, save_path, *args, **kwargs):
+        if PatchTensorflow2ModelIO.__main_task is None:
+            return original_fn(self, save_path, *args, **kwargs)
+
+        # Hack: disabled
+        if False and running_remotely():
+            # register/load model weights
+            try:
+                save_path = WeightsFileHandler.restore_weights_file(self, save_path, Framework.tensorflow,
+                                                                    PatchTensorflow2ModelIO.__main_task)
+            except Exception:
+                pass
+            # load model
+            return original_fn(self, save_path, *args, **kwargs)
+
+        # load model, if something is wrong, exception will be raised before we register the input model
+        model = original_fn(self, save_path, *args, **kwargs)
+        # register/load model weights
+        try:
+            WeightsFileHandler.restore_weights_file(self, save_path, Framework.tensorflow,
+                                                    PatchTensorflow2ModelIO.__main_task)
+        except Exception:
+            pass
+        return model
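All of the TF2 model IO patching rides on the _patched_call helper, which replaces a function with a wrapper that hands the original to the patch as its first argument; its implementation is not part of this diff. A minimal sketch consistent with how it is used above (names are illustrative):

def _patched_call_sketch(original_fn, patched_fn):
    def _inner(*args, **kwargs):
        # hand the original to the patch so it can call through
        return patched_fn(original_fn, *args, **kwargs)
    return _inner


def _save(original_fn, self, file_prefix, *args, **kwargs):
    result = original_fn(self, file_prefix, *args, **kwargs)
    print('checkpoint written to', file_prefix)  # report instead of upload
    return result

# e.g. util.TrackableSaver.save = _patched_call_sketch(util.TrackableSaver.save, _save)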
@@ -8,7 +8,7 @@
    }

    direct_access: [
-        # Objects matching are considered to be available for direct access, i.e. they will not eb downloaded
+        # Objects matching are considered to be available for direct access, i.e. they will not be downloaded
        # or cached, and any download request will return a direct reference.
        # Objects are specified in glob format, available for url and content_type.
        { url: "file://*" }  # file-urls are always directly referenced
@@ -34,8 +34,7 @@ from .binding.absl_bind import PatchAbsl
from .utilities.args import argparser_parseargs_called, get_argparser_last_args, \
    argparser_update_currenttask
from .binding.frameworks.pytorch_bind import PatchPyTorchModelIO
-from .binding.frameworks.tensorflow_bind import PatchSummaryToEventTransformer, PatchTensorFlowEager, \
-    PatchKerasModelIO, PatchTensorflowModelIO
+from .binding.frameworks.tensorflow_bind import TensorflowBinding
from .binding.frameworks.xgboost_bind import PatchXGBoostModelIO
from .binding.matplotlib_bind import PatchedMatplotlib
from .utilities.resource_monitor import ResourceMonitor
@@ -252,10 +251,7 @@ class Task(_Task):
            PatchedJoblib.update_current_task(task)
            PatchedMatplotlib.update_current_task(Task.__main_task)
            PatchAbsl.update_current_task(Task.__main_task)
-            PatchSummaryToEventTransformer.update_current_task(task)
-            PatchTensorFlowEager.update_current_task(task)
-            PatchKerasModelIO.update_current_task(task)
-            PatchTensorflowModelIO.update_current_task(task)
+            TensorflowBinding.update_current_task(task)
            PatchPyTorchModelIO.update_current_task(task)
            PatchXGBoostModelIO.update_current_task(task)
            if auto_resource_monitoring:
@@ -346,7 +342,7 @@ class Task(_Task):
    def _create_dev_task(cls, default_project_name, default_task_name, default_task_type, reuse_last_task_id):
        if not default_project_name or not default_task_name:
            # get project name and task name from repository name and entry_point
-            result = ScriptInfo.get(create_requirements=False, check_uncommitted=False)
+            result, _ = ScriptInfo.get(create_requirements=False, check_uncommitted=False)
            if not default_project_name:
                # noinspection PyBroadException
                try:
@@ -1016,7 +1012,14 @@ class Task(_Task):
        if self._detect_repo_async_thread:
            try:
                if self._detect_repo_async_thread.is_alive():
+                    self.log.info('Waiting for repository detection and full package requirement analysis')
                    self._detect_repo_async_thread.join(timeout=timeout)
+                    # because join has no return value
+                    if self._detect_repo_async_thread.is_alive():
+                        self.log.info('Repository and package analysis timed out ({} sec), '
+                                      'giving up'.format(timeout))
+                    else:
+                        self.log.info('Finished repository detection and package analysis')
                self._detect_repo_async_thread = None
            except Exception:
                pass
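The added lines work around the fact that Thread.join(timeout=...) returns None whether or not the thread finished, so a timeout has to be distinguished from completion by re-checking is_alive() afterwards. A tiny standalone demonstration of the pattern:

import threading
import time

worker = threading.Thread(target=time.sleep, args=(10,))
worker.daemon = True  # let the interpreter exit without waiting
worker.start()
worker.join(timeout=0.5)
# join() gives no return value; re-check liveness to detect the timeout
print('timed out' if worker.is_alive() else 'finished')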