From a7eb8476ceae0498cb57c1870163aa36ce61efc6 Mon Sep 17 00:00:00 2001
From: allegroai <>
Date: Fri, 27 Sep 2019 13:24:04 +0300
Subject: [PATCH] Add initial Tensorflow v2 support (2.0.0rc1)

---
 examples/tensorflow_v2_mnist.py              | 130 +++++++++++
 trains/binding/frameworks/tensorflow_bind.py | 229 ++++++++++++++++---
 trains/config/default/sdk.conf               |   2 +-
 trains/task.py                               |  17 +-
 4 files changed, 336 insertions(+), 42 deletions(-)
 create mode 100644 examples/tensorflow_v2_mnist.py

diff --git a/examples/tensorflow_v2_mnist.py b/examples/tensorflow_v2_mnist.py
new file mode 100644
index 00000000..96a23869
--- /dev/null
+++ b/examples/tensorflow_v2_mnist.py
@@ -0,0 +1,130 @@
+from __future__ import absolute_import, division, print_function, unicode_literals
+
+import tensorflow as tf
+
+from tensorflow.keras.layers import Dense, Flatten, Conv2D
+from tensorflow.keras import Model
+
+from trains import Task
+
+
+task = Task.init(project_name='examples',
+                 task_name='Tensorflow v2 mnist with summaries')
+
+
+# Load and prepare the MNIST dataset.
+mnist = tf.keras.datasets.mnist
+
+(x_train, y_train), (x_test, y_test) = mnist.load_data()
+x_train, x_test = x_train / 255.0, x_test / 255.0
+
+# Add a channels dimension
+x_train = x_train[..., tf.newaxis]
+x_test = x_test[..., tf.newaxis]
+
+# Use tf.data to batch and shuffle the dataset
+train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(10000).batch(32)
+test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(32)
+
+
+# Build the tf.keras model using the Keras model subclassing API
+class MyModel(Model):
+    def __init__(self):
+        super(MyModel, self).__init__()
+        self.conv1 = Conv2D(32, 3, activation='relu')
+        self.flatten = Flatten()
+        self.d1 = Dense(128, activation='relu')
+        self.d2 = Dense(10, activation='softmax')
+
+    def call(self, x):
+        x = self.conv1(x)
+        x = self.flatten(x)
+        x = self.d1(x)
+        return self.d2(x)
+
+
+# Create an instance of the model
+model = MyModel()
+
+# Choose an optimizer and loss function for training
+loss_object = tf.keras.losses.SparseCategoricalCrossentropy()
+optimizer = tf.keras.optimizers.Adam()
+
+# Select metrics to measure the loss and the accuracy of the model.
+# These metrics accumulate the values over epochs and then print the overall result.
+train_loss = tf.keras.metrics.Mean(name='train_loss', dtype=tf.float32)
+train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')
+
+test_loss = tf.keras.metrics.Mean(name='test_loss', dtype=tf.float32)
+test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')
+
+
+# Use tf.GradientTape to train the model
+@tf.function
+def train_step(images, labels):
+    with tf.GradientTape() as tape:
+        predictions = model(images)
+        loss = loss_object(labels, predictions)
+    gradients = tape.gradient(loss, model.trainable_variables)
+    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
+
+    train_loss(loss)
+    train_accuracy(labels, predictions)
+
+
+# Test the model
+@tf.function
+def test_step(images, labels):
+    predictions = model(images)
+    t_loss = loss_object(labels, predictions)
+
+    test_loss(t_loss)
+    test_accuracy(labels, predictions)
+
+
+# Set up summary writers to write the summaries to disk in a different logs directory
+train_log_dir = '/tmp/logs/gradient_tape/train'
+test_log_dir = '/tmp/logs/gradient_tape/test'
+train_summary_writer = tf.summary.create_file_writer(train_log_dir)
+test_summary_writer = tf.summary.create_file_writer(test_log_dir)
+
+# Set up checkpoints manager
+ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=optimizer, net=model)
+manager = tf.train.CheckpointManager(ckpt, '/tmp/tf_ckpts', max_to_keep=3)
+ckpt.restore(manager.latest_checkpoint)
+if manager.latest_checkpoint:
+    print("Restored from {}".format(manager.latest_checkpoint))
+else:
+    print("Initializing from scratch.")
+
+# Start training
+EPOCHS = 5
+for epoch in range(EPOCHS):
+    for images, labels in train_ds:
+        train_step(images, labels)
+    with train_summary_writer.as_default():
+        tf.summary.scalar('loss', train_loss.result(), step=epoch)
+        tf.summary.scalar('accuracy', train_accuracy.result(), step=epoch)
+
+    ckpt.step.assign_add(1)
+    if int(ckpt.step) % 1 == 0:
+        save_path = manager.save()
+        print("Saved checkpoint for step {}: {}".format(int(ckpt.step), save_path))
+
+    for test_images, test_labels in test_ds:
+        test_step(test_images, test_labels)
+    with test_summary_writer.as_default():
+        tf.summary.scalar('loss', test_loss.result(), step=epoch)
+        tf.summary.scalar('accuracy', test_accuracy.result(), step=epoch)
+
+    template = 'Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, Test Accuracy: {}'
+    print(template.format(epoch + 1,
+                          train_loss.result(),
+                          train_accuracy.result() * 100,
+                          test_loss.result(),
+                          test_accuracy.result() * 100))
+
+    # Reset the metrics for the next epoch
+    train_loss.reset_states()
+    train_accuracy.reset_states()
+    test_loss.reset_states()
+    test_accuracy.reset_states()
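
Note that the example contains no trains reporting calls beyond Task.init(); the scalars and checkpoints above reach the server because the binding in the next file monkey-patches TensorFlow's eager summary ops. A rough sketch of that wrap-and-forward pattern (illustrative only, with a print standing in for the real EventTrainsWriter reporting; the actual implementation is PatchTensorFlowEager._write_summary below):

    from tensorflow.python.ops import gen_summary_ops

    _original_write_summary = gen_summary_ops.write_summary

    def _patched_write_summary(writer, step, tensor, tag, summary_metadata, name=None, **kwargs):
        try:
            # summary_metadata encodes the plugin name, e.g. '...scalars'
            if summary_metadata.decode().endswith('scalars'):
                print('captured scalar %s step=%d value=%s' % (tag, int(step.numpy()), tensor.numpy()))
        except Exception:
            pass  # never break the user's training loop over a reporting failure
        # always forward to the original op so the TensorBoard event file is still written
        return _original_write_summary(writer, step, tensor, tag, summary_metadata, name, **kwargs)

    gen_summary_ops.write_summary = _patched_write_summary
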
diff --git a/trains/binding/frameworks/tensorflow_bind.py b/trains/binding/frameworks/tensorflow_bind.py
index 63f6d8f2..0535b9c5 100644
--- a/trains/binding/frameworks/tensorflow_bind.py
+++ b/trains/binding/frameworks/tensorflow_bind.py
@@ -8,6 +8,7 @@ from io import BytesIO
 from typing import Any
 
 import numpy as np
+import six
 from PIL import Image
 
 from ...debugging.log import LoggerRoot
@@ -22,6 +23,16 @@ except ImportError:
     MessageToDict = None
 
 
+class TensorflowBinding(object):
+    @classmethod
+    def update_current_task(cls, task):
+        PatchSummaryToEventTransformer.update_current_task(task)
+        PatchTensorFlowEager.update_current_task(task)
+        PatchKerasModelIO.update_current_task(task)
+        PatchTensorflowModelIO.update_current_task(task)
+        PatchTensorflow2ModelIO.update_current_task(task)
+
+
 class IsTensorboardInit(object):
     _tensorboard_initialized = False
@@ -156,6 +167,7 @@ class EventTrainsWriter(object):
         self.image_report_freq = image_report_freq if image_report_freq else report_freq
         self.histogram_granularity = histogram_granularity
         self.histogram_update_freq_multiplier = histogram_update_freq_multiplier
+        self._histogram_update_call_counter = 0
         self._logger = logger
         self._visualization_mode = 'RGB'  # 'BGR'
         self._variants = defaultdict(lambda: ())
@@ -168,12 +180,18 @@ class EventTrainsWriter(object):
     def _decode_image(self, img_str, width, height, color_channels):
         # noinspection PyBroadException
         try:
-            imdata = base64.b64decode(img_str)
+            if isinstance(img_str, bytes):
+                imdata = img_str
+            else:
+                imdata = base64.b64decode(img_str)
             output = BytesIO(imdata)
             im = Image.open(output)
             image = np.asarray(im)
             output.close()
-            val = image.reshape(height, width, -1).astype(np.uint8)
+            if height > 0 and width > 0:
+                val = image.reshape(height, width, -1).astype(np.uint8)
+            else:
+                val = image.astype(np.uint8)
             if val.ndim == 3 and val.shape[2] == 3:
                 if self._visualization_mode == 'BGR':
                     val = val[:, :, [2, 1, 0]]
@@ -187,7 +205,7 @@ class EventTrainsWriter(object):
                 else:
                     val = val[:, :, [0, 1, 2]]
         except Exception:
-            LoggerRoot.get_base_logger().warning('Failed decoding debug image [%d, %d, %d]'
+            LoggerRoot.get_base_logger(TensorflowBinding).warning('Failed decoding debug image [%d, %d, %d]'
                                                  % (width, height, color_channels))
             val = None
         return val
@@ -281,7 +299,9 @@ class EventTrainsWriter(object):
             return _cur_idx
 
         # only collect histogram every specific interval
-        if step % self.report_freq != 0 or step < self.report_freq - 1:
+        self._histogram_update_call_counter += 1
+        if self._histogram_update_call_counter % self.report_freq != 0 or \
+                self._histogram_update_call_counter < self.report_freq - 1:
             return None
 
         # generate forward matrix of the histograms
@@ -306,6 +326,8 @@ class EventTrainsWriter(object):
         # add current sample, if not already here
         hist_iters = np.append(hist_iters, step)
+        # histo_data['bucketLimit'] is the histogram bucket right side limit, meaning X axis
+        # histo_data['bucket'] is the histogram height, meaning the Y axis
         hist = np.array(list(zip(histo_data['bucketLimit'], histo_data['bucket'])), dtype=np.float32)
         hist = hist[~np.isinf(hist[:, 0]), :]
         hist_list.append(hist)
@@ -422,7 +444,7 @@ class EventTrainsWriter(object):
             msg_dict.pop('wallTime', None)
             keys_list = [key for key in msg_dict.keys() if len(key) > 0]
             keys_list = ', '.join(keys_list)
-            LoggerRoot.get_base_logger().debug('event summary not found, message type unsupported: %s' % keys_list)
+            LoggerRoot.get_base_logger(TensorflowBinding).debug('event summary not found, message type unsupported: %s' % keys_list)
             return
         value_dicts = summary.get('value')
         walltime = walltime or msg_dict.get('step')
@@ -434,19 +456,19 @@ class EventTrainsWriter(object):
                 step = int(event.step)
             else:
                 step = 0
-                LoggerRoot.get_base_logger().debug('Received event without step, assuming step = {}'.format(step))
+                LoggerRoot.get_base_logger(TensorflowBinding).debug('Received event without step, assuming step = {}'.format(step))
         else:
             step = int(step)
         self._max_step = max(self._max_step, step)
         if value_dicts is None:
-            LoggerRoot.get_base_logger().debug("Summary arrived without 'value'")
+            LoggerRoot.get_base_logger(TensorflowBinding).debug("Summary arrived without 'value'")
             return
 
         for vdict in value_dicts:
             tag = vdict.pop('tag', None)
             if tag is None:
                 # we should not get here
-                LoggerRoot.get_base_logger().debug('No tag for \'value\' existing keys %s'
+                LoggerRoot.get_base_logger(TensorflowBinding).debug('No tag for \'value\' existing keys %s'
                                                                     % ', '.join(vdict.keys()))
                 continue
             metric, values = get_data(vdict, supported_metrics)
@@ -463,7 +485,7 @@ class EventTrainsWriter(object):
             elif metric == 'tensor' and values.get('dtype') == 'DT_FLOAT':
                 self._add_plot(tag, step, values, vdict)
             else:
-                LoggerRoot.get_base_logger().debug('Event unsupported. tag = %s, vdict keys [%s]'
+                LoggerRoot.get_base_logger(TensorflowBinding).debug('Event unsupported. tag = %s, vdict keys [%s]'
                                                    % (tag, ', '.join(vdict.keys())))
             continue
@@ -594,7 +616,7 @@ class PatchSummaryToEventTransformer(object):
                 setattr(SummaryToEventTransformer, 'trains',
                         property(PatchSummaryToEventTransformer.trains_object))
             except Exception as ex:
-                LoggerRoot.get_base_logger().debug(str(ex))
+                LoggerRoot.get_base_logger(TensorflowBinding).debug(str(ex))
 
         if 'torch' in sys.modules:
             try:
@@ -608,7 +630,7 @@ class PatchSummaryToEventTransformer(object):
                 # this is a new version of TensorflowX
                 pass
             except Exception as ex:
-                LoggerRoot.get_base_logger().debug(str(ex))
+                LoggerRoot.get_base_logger(TensorflowBinding).debug(str(ex))
 
         if 'tensorboardX' in sys.modules:
             try:
@@ -624,7 +646,7 @@ class PatchSummaryToEventTransformer(object):
                 # this is a new version of TensorflowX
                 pass
             except Exception as ex:
-                LoggerRoot.get_base_logger().debug(str(ex))
+                LoggerRoot.get_base_logger(TensorflowBinding).debug(str(ex))
 
         if PatchSummaryToEventTransformer.__original_getattributeX is None:
             try:
@@ -638,7 +660,7 @@ class PatchSummaryToEventTransformer(object):
                 # this is a new version of TensorflowX
                 pass
             except Exception as ex:
-                LoggerRoot.get_base_logger().debug(str(ex))
+                LoggerRoot.get_base_logger(TensorflowBinding).debug(str(ex))
 
     @staticmethod
     def _patched_add_eventT(self, *args, **kwargs):
@@ -722,7 +744,7 @@ class _ModelAdapter(object):
         super(_ModelAdapter, self).__init__()
         super(_ModelAdapter, self).__setattr__('_model', model)
         super(_ModelAdapter, self).__setattr__('_output_model', output_model)
-        super(_ModelAdapter, self).__setattr__('_logger', LoggerRoot.get_base_logger())
+        super(_ModelAdapter, self).__setattr__('_logger', LoggerRoot.get_base_logger(TensorflowBinding))
 
     def __getattr__(self, attr):
         return getattr(self._model, attr)
@@ -805,7 +827,7 @@ class PatchModelCheckPointCallback(object):
                     property(PatchModelCheckPointCallback.trains_object))
 
         except Exception as ex:
-            LoggerRoot.get_base_logger().warning(str(ex))
+            LoggerRoot.get_base_logger(TensorflowBinding).warning(str(ex))
 
     @staticmethod
     def _patched_getattribute(self, attr):
@@ -876,6 +898,8 @@ class PatchTensorFlowEager(object):
                 gen_summary_ops.write_image_summary = PatchTensorFlowEager._write_image_summary
                 PatchTensorFlowEager.__original_fn_hist = gen_summary_ops.write_histogram_summary
                 gen_summary_ops.write_histogram_summary = PatchTensorFlowEager._write_hist_summary
+                PatchTensorFlowEager.__write_summary = gen_summary_ops.write_summary
+                gen_summary_ops.write_summary = PatchTensorFlowEager._write_summary
                 gen_summary_ops.create_summary_file_writer = partial(IsTensorboardInit._patched_tb__init__,
                                                                      gen_summary_ops.create_summary_file_writer)
                 gen_summary_ops.create_summary_db_writer = partial(IsTensorboardInit._patched_tb__init__,
@@ -883,7 +907,7 @@
             except ImportError:
                 pass
             except Exception as ex:
-                LoggerRoot.get_base_logger().debug(str(ex))
+                LoggerRoot.get_base_logger(TensorflowBinding).debug(str(ex))
 
     @staticmethod
     def _get_event_writer(writer):
@@ -903,14 +927,40 @@
     def trains_object(self):
         return PatchTensorFlowEager.__trains_event_writer
 
+    @staticmethod
+    def _write_summary(writer, step, tensor, tag, summary_metadata, name=None, **kwargs):
+        event_writer = PatchTensorFlowEager._get_event_writer(writer)
+        if event_writer:
+            try:
+                plugin_type = summary_metadata.decode()
+                if plugin_type.endswith('scalars'):
+                    event_writer._add_scalar(tag=str(tag),
+                                             step=int(step.numpy()) if not isinstance(step, int) else step,
+                                             scalar_data=tensor.numpy())
+                elif plugin_type.endswith('images'):
+                    img_data_np = tensor.numpy()
+                    PatchTensorFlowEager._add_image_event_helper(event_writer, img_data_np=img_data_np,
+                                                                 tag=tag, step=step, **kwargs)
+                elif plugin_type.endswith('histograms'):
+                    PatchTensorFlowEager._add_histogram_event_helper(
+                        event_writer, tag=str(tag), step=int(step.numpy()) if not isinstance(step, int) else step,
+                        hist_data=tensor.numpy())
+                else:
+                    pass  # print('unsupported plugin_type', plugin_type)
+            except Exception:
+                pass
+        return PatchTensorFlowEager.__write_summary(writer, step, tensor, tag, summary_metadata, name, **kwargs)
+
     @staticmethod
     def _write_scalar_summary(writer, step, tag, value, name=None, **kwargs):
         event_writer = PatchTensorFlowEager._get_event_writer(writer)
         if event_writer:
             try:
-                event_writer._add_scalar(tag=str(tag), step=int(step.numpy()), scalar_data=value.numpy())
+                event_writer._add_scalar(tag=str(tag),
+                                         step=int(step.numpy()) if not isinstance(step, int) else step,
+                                         scalar_data=value.numpy())
             except Exception as ex:
-                LoggerRoot.get_base_logger().warning(str(ex))
+                LoggerRoot.get_base_logger(TensorflowBinding).warning(str(ex))
         return PatchTensorFlowEager.__original_fn_scalar(writer, step, tag, value, name, **kwargs)
 
     @staticmethod
@@ -918,9 +968,11 @@
         event_writer = PatchTensorFlowEager._get_event_writer(writer)
         if event_writer:
             try:
-                event_writer._add_histogram(tag=str(tag), step=int(step.numpy()), histo_data=values.numpy())
+                PatchTensorFlowEager._add_histogram_event_helper(
+                    event_writer, tag=str(tag), step=int(step.numpy()) if not isinstance(step, int) else step,
+                    hist_data=values.numpy())
             except Exception as ex:
-                LoggerRoot.get_base_logger().warning(str(ex))
+                LoggerRoot.get_base_logger(TensorflowBinding).warning(str(ex))
         return PatchTensorFlowEager.__original_fn_hist(writer, step, tag, values, name, **kwargs)
 
     @staticmethod
@@ -928,13 +980,45 @@
         event_writer = PatchTensorFlowEager._get_event_writer(writer)
         if event_writer:
             try:
-                event_writer._add_image_numpy(tag=str(tag), step=int(step.numpy()), img_data_np=tensor.numpy(),
-                                              max_keep_images=max_images)
+                PatchTensorFlowEager._add_image_event_helper(event_writer, img_data_np=tensor.numpy(),
+                                                             tag=tag, step=step, **kwargs)
             except Exception as ex:
-                LoggerRoot.get_base_logger().warning(str(ex))
+                LoggerRoot.get_base_logger(TensorflowBinding).warning(str(ex))
         return PatchTensorFlowEager.__original_fn_image(writer, step, tag, tensor, bad_color, max_images, name,
                                                         **kwargs)
 
+    @staticmethod
+    def _add_histogram_event_helper(event_writer, hist_data, tag, step):
+        if isinstance(hist_data, dict):
+            event_writer._add_histogram(tag=tag, step=step, histo_data=hist_data)
+            return
+
+        # prepare the dictionary, assume numpy
+        # histo_data['bucketLimit'] is the histogram bucket right side limit, meaning X axis
+        # histo_data['bucket'] is the histogram height, meaning the Y axis
+        # notice hist_data[:, 1] is the right side limit, for backwards compatibility we take the left side
+        histo_data = {'bucketLimit': hist_data[:, 0].tolist(), 'bucket': hist_data[:, 2].tolist()}
+        event_writer._add_histogram(tag=tag, step=step, histo_data=histo_data)
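
To make the conversion above concrete, a small worked example with made-up numbers: in TF2 eager mode a histogram summary arrives as an Nx3 array of (left edge, right edge, count) rows, while EventTrainsWriter._add_histogram expects the TensorBoard-style dict:

    import numpy as np

    # three buckets of a fake histogram: (left edge, right edge, count)
    hist_data = np.array([[0.0, 0.5, 4.0],
                          [0.5, 1.0, 9.0],
                          [1.0, 1.5, 2.0]], dtype=np.float32)

    # the same conversion _add_histogram_event_helper performs above:
    histo_data = {'bucketLimit': hist_data[:, 0].tolist(),  # X axis, left edges kept for backwards compatibility
                  'bucket': hist_data[:, 2].tolist()}       # Y axis, bucket heights
    assert histo_data == {'bucketLimit': [0.0, 0.5, 1.0], 'bucket': [4.0, 9.0, 2.0]}
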
+
+    @staticmethod
+    def _add_image_event_helper(event_writer, img_data_np, tag, step, **kwargs):
+        if img_data_np.ndim == 1 and img_data_np.size >= 3 and \
+                (len(img_data_np[0]) < 10 and len(img_data_np[1]) < 10):
+            # the first two entries are the encoded width/height;
+            # decoding them here is just for making sure these are actually valid numbers
+            width = int(img_data_np[0].decode())
+            height = int(img_data_np[1].decode())
+            for i in range(2, img_data_np.size):
+                img_data = {'width': -1, 'height': -1,
+                            'colorspace': 'RGB', 'encodedImageString': img_data_np[i]}
+                image_tag = str(tag) + '/sample_{}'.format(i - 2) if img_data_np.size > 3 else str(tag)
+                event_writer._add_image(tag=image_tag,
+                                        step=int(step.numpy()) if not isinstance(step, int) else step,
+                                        img_data=img_data)
+        else:
+            event_writer._add_image_numpy(tag=str(tag),
+                                          step=int(step.numpy()) if not isinstance(step, int) else step,
+                                          img_data_np=img_data_np,
+                                          max_keep_images=kwargs.get('max_images'))
 
 
 class PatchKerasModelIO(object):
     __main_task = None
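
For reference, the 1-D branch above handles the layout TensorBoard's image plugin uses for a batch: a string tensor of [width, height, encoded_image, encoded_image, ...]. Passing width=-1/height=-1 defers the shape to the extended _decode_image above, which now falls back to the image's own shape when the dimensions are not positive. A small illustration of that layout (fake bytes, not real PNG data):

    import numpy as np

    # fake eager image-summary payload: width, height, then one encoded image per sample
    img_data_np = np.array([b'28', b'28', b'<png-bytes-0>', b'<png-bytes-1>'])

    width, height = int(img_data_np[0].decode()), int(img_data_np[1].decode())
    encoded_images = list(img_data_np[2:])  # reported as tag/sample_0, tag/sample_1
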
@@ -1029,7 +1113,7 @@ class PatchKerasModelIO(object):
             keras_saving.save_model = _patched_call(keras_saving.save_model, PatchKerasModelIO._save_model)
             keras_saving.load_model = _patched_call(keras_saving.load_model, PatchKerasModelIO._load_model)
         except Exception as ex:
-            LoggerRoot.get_base_logger().warning(str(ex))
+            LoggerRoot.get_base_logger(TensorflowBinding).warning(str(ex))
 
     @staticmethod
     def _updated_config(original_fn, self):
@@ -1057,7 +1141,7 @@ class PatchKerasModelIO(object):
                 framework=Framework.keras,
             )
         except Exception as ex:
-            LoggerRoot.get_base_logger().warning(str(ex))
+            LoggerRoot.get_base_logger(TensorflowBinding).warning(str(ex))
 
         return config
 
@@ -1107,7 +1191,7 @@ class PatchKerasModelIO(object):
 
             return model
         except Exception as ex:
-            LoggerRoot.get_base_logger().warning(str(ex))
+            LoggerRoot.get_base_logger(TensorflowBinding).warning(str(ex))
 
         return self
 
@@ -1189,7 +1273,7 @@ class PatchKerasModelIO(object):
             # if anyone asks, we were here
             self.trains_out_model._processed = True
         except Exception as ex:
-            LoggerRoot.get_base_logger().warning(str(ex))
+            LoggerRoot.get_base_logger(TensorflowBinding).warning(str(ex))
 
     @staticmethod
     def _save_model(original_fn, model, filepath, *args, **kwargs):
@@ -1261,7 +1345,7 @@ class PatchTensorflowModelIO(object):
         except ImportError:
             pass
         except Exception:
-            pass  # print('Failed patching tensorflow')
+            LoggerRoot.get_base_logger(TensorflowBinding).debug('Failed patching tensorflow')
 
         # noinspection PyBroadException
         try:
@@ -1282,10 +1366,9 @@ class PatchTensorflowModelIO(object):
                     saved_model = None
             except Exception:
                 saved_model = None
-                pass  # print('Failed patching tensorflow')
+
         except Exception:
             saved_model = None
-            pass  # print('Failed patching tensorflow')
 
         if saved_model is not None:
             saved_model.save = _patched_call(saved_model.save, PatchTensorflowModelIO._save_model)
@@ -1301,7 +1384,7 @@ class PatchTensorflowModelIO(object):
         except ImportError:
             pass
         except Exception:
-            pass  # print('Failed patching tensorflow')
+            LoggerRoot.get_base_logger(TensorflowBinding).debug('Failed patching tensorflow')
 
         # noinspection PyBroadException
         try:
@@ -1313,7 +1396,7 @@ class PatchTensorflowModelIO(object):
         except ImportError:
             pass
         except Exception:
-            pass  # print('Failed patching tensorflow')
+            LoggerRoot.get_base_logger(TensorflowBinding).debug('Failed patching tensorflow')
 
         # noinspection PyBroadException
         try:
@@ -1325,7 +1408,7 @@ class PatchTensorflowModelIO(object):
         except ImportError:
             pass
         except Exception:
-            pass  # print('Failed patching tensorflow')
+            LoggerRoot.get_base_logger(TensorflowBinding).debug('Failed patching tensorflow')
 
         # noinspection PyBroadException
         try:
@@ -1349,7 +1432,7 @@ class PatchTensorflowModelIO(object):
         except ImportError:
             pass
         except Exception:
-            pass  # print('Failed patching tensorflow')
+            LoggerRoot.get_base_logger(TensorflowBinding).debug('Failed patching tensorflow')
 
     @staticmethod
     def _save(original_fn, self, sess, save_path, *args, **kwargs):
@@ -1457,3 +1540,81 @@ class PatchTensorflowModelIO(object):
             pass
         return model
+
+
+class PatchTensorflow2ModelIO(object):
+    __main_task = None
+    __patched = None
+
+    @staticmethod
+    def update_current_task(task, **kwargs):
+        PatchTensorflow2ModelIO.__main_task = task
+        PatchTensorflow2ModelIO._patch_model_checkpoint()
+        PostImportHookPatching.add_on_import('tensorflow', PatchTensorflow2ModelIO._patch_model_checkpoint)
+
+    @staticmethod
+    def _patch_model_checkpoint():
+        if PatchTensorflow2ModelIO.__patched:
+            return
+
+        if 'tensorflow' not in sys.modules:
+            return
+
+        PatchTensorflow2ModelIO.__patched = True
+        # noinspection PyBroadException
+        try:
+            # hack: make sure tensorflow.__init__ is called
+            import tensorflow
+            from tensorflow.python.training.tracking import util
+            # noinspection PyBroadException
+            try:
+                util.TrackableSaver.save = _patched_call(util.TrackableSaver.save,
+                                                         PatchTensorflow2ModelIO._save)
+            except Exception:
+                pass
+            # noinspection PyBroadException
+            try:
+                util.TrackableSaver.restore = _patched_call(util.TrackableSaver.restore,
+                                                            PatchTensorflow2ModelIO._restore)
+            except Exception:
+                pass
+        except ImportError:
+            pass
+        except Exception:
+            LoggerRoot.get_base_logger(TensorflowBinding).debug('Failed patching tensorflow v2')
+
+    @staticmethod
+    def _save(original_fn, self, file_prefix, *args, **kwargs):
+        model = original_fn(self, file_prefix, *args, **kwargs)
+        # store output Model
+        try:
+            WeightsFileHandler.create_output_model(self, file_prefix, Framework.tensorflow,
+                                                   PatchTensorflow2ModelIO.__main_task)
+        except Exception:
+            pass
+        return model
+
+    @staticmethod
+    def _restore(original_fn, self, save_path, *args, **kwargs):
+        if PatchTensorflow2ModelIO.__main_task is None:
+            return original_fn(self, save_path, *args, **kwargs)
+
+        # Hack: disabled
+        if False and running_remotely():
+            # register/load model weights
+            try:
+                save_path = WeightsFileHandler.restore_weights_file(self, save_path, Framework.tensorflow,
+                                                                    PatchTensorflow2ModelIO.__main_task)
+            except Exception:
+                pass
+            # load model
+            return original_fn(self, save_path, *args, **kwargs)
+
+        # load model, if something is wrong, exception will be raised before we register the input model
+        model = original_fn(self, save_path, *args, **kwargs)
+        # register/load model weights
+        try:
+            WeightsFileHandler.restore_weights_file(self, save_path, Framework.tensorflow,
+                                                    PatchTensorflow2ModelIO.__main_task)
+        except Exception:
+            pass
+        return model
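
With PatchTensorflow2ModelIO in place, standard TF2 checkpointing is enough for checkpoints to be registered on the task: tf.train.CheckpointManager.save() goes through TrackableSaver.save(), which is now wrapped. A minimal sketch of the resulting user-facing flow (assumes a configured trains setup; the model is arbitrary):

    import tensorflow as tf
    from trains import Task

    task = Task.init(project_name='examples', task_name='tf2 checkpoint demo')  # installs the patches

    net = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
    ckpt = tf.train.Checkpoint(step=tf.Variable(1), net=net)
    manager = tf.train.CheckpointManager(ckpt, '/tmp/tf2_ckpt_demo', max_to_keep=3)

    # CheckpointManager.save() calls TrackableSaver.save() under the hood, so the
    # patched _save() above also records the checkpoint as an output model on the task
    manager.save()
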
diff --git a/trains/config/default/sdk.conf b/trains/config/default/sdk.conf
index b0af0358..96ddf0b4 100644
--- a/trains/config/default/sdk.conf
+++ b/trains/config/default/sdk.conf
@@ -8,7 +8,7 @@
         }
 
         direct_access: [
-            # Objects matching are considered to be available for direct access, i.e. they will not eb downloaded
+            # Objects matching are considered to be available for direct access, i.e. they will not be downloaded
             # or cached, and any download request will return a direct reference.
             # Objects are specified in glob format, available for url and content_type.
             { url: "file://*" }  # file-urls are always directly referenced
diff --git a/trains/task.py b/trains/task.py
index c91a9b7e..24dde7c1 100644
--- a/trains/task.py
+++ b/trains/task.py
@@ -34,8 +34,7 @@ from .binding.absl_bind import PatchAbsl
 from .utilities.args import argparser_parseargs_called, get_argparser_last_args, \
     argparser_update_currenttask
 from .binding.frameworks.pytorch_bind import PatchPyTorchModelIO
-from .binding.frameworks.tensorflow_bind import PatchSummaryToEventTransformer, PatchTensorFlowEager, \
-    PatchKerasModelIO, PatchTensorflowModelIO
+from .binding.frameworks.tensorflow_bind import TensorflowBinding
 from .binding.frameworks.xgboost_bind import PatchXGBoostModelIO
 from .binding.matplotlib_bind import PatchedMatplotlib
 from .utilities.resource_monitor import ResourceMonitor
@@ -252,10 +251,7 @@ class Task(_Task):
                 PatchedJoblib.update_current_task(task)
                 PatchedMatplotlib.update_current_task(Task.__main_task)
                 PatchAbsl.update_current_task(Task.__main_task)
-                PatchSummaryToEventTransformer.update_current_task(task)
-                PatchTensorFlowEager.update_current_task(task)
-                PatchKerasModelIO.update_current_task(task)
-                PatchTensorflowModelIO.update_current_task(task)
+                TensorflowBinding.update_current_task(task)
                 PatchPyTorchModelIO.update_current_task(task)
                 PatchXGBoostModelIO.update_current_task(task)
             if auto_resource_monitoring:
@@ -346,7 +342,7 @@ class Task(_Task):
     def _create_dev_task(cls, default_project_name, default_task_name, default_task_type, reuse_last_task_id):
         if not default_project_name or not default_task_name:
             # get project name and task name from repository name and entry_point
-            result = ScriptInfo.get(create_requirements=False, check_uncommitted=False)
+            result, _ = ScriptInfo.get(create_requirements=False, check_uncommitted=False)
             if not default_project_name:
                 # noinspection PyBroadException
                 try:
@@ -1016,7 +1012,14 @@ class Task(_Task):
         if self._detect_repo_async_thread:
             try:
                 if self._detect_repo_async_thread.is_alive():
+                    self.log.info('Waiting for repository detection and full package requirement analysis')
                     self._detect_repo_async_thread.join(timeout=timeout)
+                    # because join has no return value
+                    if self._detect_repo_async_thread.is_alive():
+                        self.log.info('Repository and package analysis timed out ({} sec), '
+                                      'giving up'.format(timeout))
+                    else:
+                        self.log.info('Finished repository detection and package analysis')
                 self._detect_repo_async_thread = None
             except Exception:
                 pass
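
A note on the final task.py hunk: threading.Thread.join() returns None, so the added code re-checks is_alive() after the join to tell a timeout apart from normal completion. The same pattern in isolation, as a self-contained sketch:

    import threading
    import time

    def join_with_timeout(thread, timeout):
        # join() has no return value, so the only way to know whether it
        # timed out is to check is_alive() again after the join
        thread.join(timeout)
        return not thread.is_alive()

    worker = threading.Thread(target=time.sleep, args=(5,))
    worker.start()
    print('finished' if join_with_timeout(worker, timeout=1.0) else 'timed out, giving up')
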