Add initial Tensorflow v2 support (2.0.0rc1)

allegroai 2019-09-27 13:24:04 +03:00
parent c44638c8d9
commit a7eb8476ce
4 changed files with 336 additions and 42 deletions

View File

@ -0,0 +1,130 @@
from __future__ import absolute_import, division, print_function, unicode_literals

import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten, Conv2D
from tensorflow.keras import Model

from trains import Task

task = Task.init(project_name='examples',
                 task_name='Tensorflow v2 mnist with summaries')

# Load and prepare the MNIST dataset.
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

# Add a channels dimension
x_train = x_train[..., tf.newaxis]
x_test = x_test[..., tf.newaxis]

# Use tf.data to batch and shuffle the dataset
train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(10000).batch(32)
test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(32)


# Build the tf.keras model using the Keras model subclassing API
class MyModel(Model):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv1 = Conv2D(32, 3, activation='relu')
        self.flatten = Flatten()
        self.d1 = Dense(128, activation='relu')
        self.d2 = Dense(10, activation='softmax')

    def call(self, x):
        x = self.conv1(x)
        x = self.flatten(x)
        x = self.d1(x)
        return self.d2(x)


# Create an instance of the model
model = MyModel()

# Choose an optimizer and loss function for training
loss_object = tf.keras.losses.SparseCategoricalCrossentropy()
optimizer = tf.keras.optimizers.Adam()

# Select metrics to measure the loss and the accuracy of the model.
# These metrics accumulate the values over epochs and then print the overall result.
train_loss = tf.keras.metrics.Mean(name='train_loss', dtype=tf.float32)
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')
test_loss = tf.keras.metrics.Mean(name='test_loss', dtype=tf.float32)
test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')


# Use tf.GradientTape to train the model
@tf.function
def train_step(images, labels):
    with tf.GradientTape() as tape:
        predictions = model(images)
        loss = loss_object(labels, predictions)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

    train_loss(loss)
    train_accuracy(labels, predictions)


# Test the model
@tf.function
def test_step(images, labels):
    predictions = model(images)
    t_loss = loss_object(labels, predictions)

    test_loss(t_loss)
    test_accuracy(labels, predictions)


# Set up summary writers to write the summaries to disk in separate logs directories
train_log_dir = '/tmp/logs/gradient_tape/train'
test_log_dir = '/tmp/logs/gradient_tape/test'
train_summary_writer = tf.summary.create_file_writer(train_log_dir)
test_summary_writer = tf.summary.create_file_writer(test_log_dir)

# Set up the checkpoints manager
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=optimizer, net=model)
manager = tf.train.CheckpointManager(ckpt, '/tmp/tf_ckpts', max_to_keep=3)
ckpt.restore(manager.latest_checkpoint)
if manager.latest_checkpoint:
    print("Restored from {}".format(manager.latest_checkpoint))
else:
    print("Initializing from scratch.")

# Start training
EPOCHS = 5
for epoch in range(EPOCHS):
    for images, labels in train_ds:
        train_step(images, labels)
    with train_summary_writer.as_default():
        tf.summary.scalar('loss', train_loss.result(), step=epoch)
        tf.summary.scalar('accuracy', train_accuracy.result(), step=epoch)

    ckpt.step.assign_add(1)
    # note: `% 1 == 0` is always true, so a checkpoint is saved after every epoch
    if int(ckpt.step) % 1 == 0:
        save_path = manager.save()
        print("Saved checkpoint for step {}: {}".format(int(ckpt.step), save_path))

    for test_images, test_labels in test_ds:
        test_step(test_images, test_labels)
    with test_summary_writer.as_default():
        tf.summary.scalar('loss', test_loss.result(), step=epoch)
        tf.summary.scalar('accuracy', test_accuracy.result(), step=epoch)

    template = 'Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, Test Accuracy: {}'
    print(template.format(epoch + 1,
                          train_loss.result(),
                          train_accuracy.result() * 100,
                          test_loss.result(),
                          test_accuracy.result() * 100))

    # Reset the metrics for the next epoch
    train_loss.reset_states()
    train_accuracy.reset_states()
    test_loss.reset_states()
    test_accuracy.reset_states()
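When this script runs, Task.init() hooks the TF 2 summary ops (see the binding changes below), so every tf.summary.scalar() call above is reported to the trains server without extra code. The raw event files can still be inspected locally with TensorBoard, pointed at the log directory the example chose:

    tensorboard --logdir /tmp/logs/gradient_tape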

View File

@ -8,6 +8,7 @@ from io import BytesIO
from typing import Any
import numpy as np
import six
from PIL import Image
from ...debugging.log import LoggerRoot
@ -22,6 +23,16 @@ except ImportError:
    MessageToDict = None


class TensorflowBinding(object):
    @classmethod
    def update_current_task(cls, task):
        PatchSummaryToEventTransformer.update_current_task(task)
        PatchTensorFlowEager.update_current_task(task)
        PatchKerasModelIO.update_current_task(task)
        PatchTensorflowModelIO.update_current_task(task)
        PatchTensorflow2ModelIO.update_current_task(task)


class IsTensorboardInit(object):
    _tensorboard_initialized = False
@ -156,6 +167,7 @@ class EventTrainsWriter(object):
        self.image_report_freq = image_report_freq if image_report_freq else report_freq
        self.histogram_granularity = histogram_granularity
        self.histogram_update_freq_multiplier = histogram_update_freq_multiplier
        self._histogram_update_call_counter = 0
        self._logger = logger
        self._visualization_mode = 'RGB'  # 'BGR'
        self._variants = defaultdict(lambda: ())
@ -168,12 +180,18 @@ class EventTrainsWriter(object):
    def _decode_image(self, img_str, width, height, color_channels):
        # noinspection PyBroadException
        try:
            if isinstance(img_str, bytes):
                imdata = img_str
            else:
                imdata = base64.b64decode(img_str)
            output = BytesIO(imdata)
            im = Image.open(output)
            image = np.asarray(im)
            output.close()
            if height > 0 and width > 0:
                val = image.reshape(height, width, -1).astype(np.uint8)
            else:
                # width/height of -1 mean: infer the size from the decoded image
                val = image.astype(np.uint8)
            if val.ndim == 3 and val.shape[2] == 3:
                if self._visualization_mode == 'BGR':
                    val = val[:, :, [2, 1, 0]]
@ -187,7 +205,7 @@ class EventTrainsWriter(object):
                else:
                    val = val[:, :, [0, 1, 2]]
        except Exception:
            LoggerRoot.get_base_logger(TensorflowBinding).warning(
                'Failed decoding debug image [%d, %d, %d]' % (width, height, color_channels))
            val = None
        return val
@ -281,7 +299,9 @@ class EventTrainsWriter(object):
            return _cur_idx

        # only collect histogram every specific interval
        self._histogram_update_call_counter += 1
        if self._histogram_update_call_counter % self.report_freq != 0 or \
                self._histogram_update_call_counter < self.report_freq - 1:
            return None

        # generate forward matrix of the histograms
@ -306,6 +326,8 @@ class EventTrainsWriter(object):
        # add current sample, if not already here
        hist_iters = np.append(hist_iters, step)
        # histo_data['bucketLimit'] is the histogram bucket right side limit, meaning the X axis
        # histo_data['bucket'] is the histogram height, meaning the Y axis
        hist = np.array(list(zip(histo_data['bucketLimit'], histo_data['bucket'])), dtype=np.float32)
        hist = hist[~np.isinf(hist[:, 0]), :]
        hist_list.append(hist)
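For intuition, a minimal standalone sketch of the bucket handling above, with made-up sample values: each (bucketLimit, bucket) pair becomes one (x, y) point, and the open-ended right-most bucket, whose limit is +inf, is dropped before reporting:

import numpy as np

histo_data = {'bucketLimit': [0.5, 1.5, float('inf')], 'bucket': [3.0, 7.0, 2.0]}
hist = np.array(list(zip(histo_data['bucketLimit'], histo_data['bucket'])), dtype=np.float32)
hist = hist[~np.isinf(hist[:, 0]), :]  # drop the +inf catch-all bucket
print(hist.tolist())  # [[0.5, 3.0], [1.5, 7.0]]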
@ -422,7 +444,7 @@ class EventTrainsWriter(object):
            msg_dict.pop('wallTime', None)
            keys_list = [key for key in msg_dict.keys() if len(key) > 0]
            keys_list = ', '.join(keys_list)
            LoggerRoot.get_base_logger(TensorflowBinding).debug(
                'event summary not found, message type unsupported: %s' % keys_list)
            return
        value_dicts = summary.get('value')
        walltime = walltime or msg_dict.get('step')
@ -434,19 +456,19 @@ class EventTrainsWriter(object):
                step = int(event.step)
            else:
                step = 0
                LoggerRoot.get_base_logger(TensorflowBinding).debug(
                    'Received event without step, assuming step = {}'.format(step))
        else:
            step = int(step)
        self._max_step = max(self._max_step, step)
        if value_dicts is None:
            LoggerRoot.get_base_logger(TensorflowBinding).debug("Summary arrived without 'value'")
            return

        for vdict in value_dicts:
            tag = vdict.pop('tag', None)
            if tag is None:
                # we should not get here
                LoggerRoot.get_base_logger(TensorflowBinding).debug(
                    'No tag for \'value\' existing keys %s' % ', '.join(vdict.keys()))
                continue
            metric, values = get_data(vdict, supported_metrics)
@ -463,7 +485,7 @@ class EventTrainsWriter(object):
            elif metric == 'tensor' and values.get('dtype') == 'DT_FLOAT':
                self._add_plot(tag, step, values, vdict)
            else:
                # note: vdict.keys() needs the call parentheses, otherwise join() raises TypeError
                LoggerRoot.get_base_logger(TensorflowBinding).debug(
                    'Event unsupported. tag = %s, vdict keys [%s]' % (tag, ', '.join(vdict.keys())))
                continue
@ -594,7 +616,7 @@ class PatchSummaryToEventTransformer(object):
                setattr(SummaryToEventTransformer, 'trains',
                        property(PatchSummaryToEventTransformer.trains_object))
            except Exception as ex:
                LoggerRoot.get_base_logger(TensorflowBinding).debug(str(ex))
        if 'torch' in sys.modules:
            try:
@ -608,7 +630,7 @@ class PatchSummaryToEventTransformer(object):
                # this is a new version of tensorboardX
                pass
            except Exception as ex:
                LoggerRoot.get_base_logger(TensorflowBinding).debug(str(ex))
        if 'tensorboardX' in sys.modules:
            try:
@ -624,7 +646,7 @@ class PatchSummaryToEventTransformer(object):
                # this is a new version of tensorboardX
                pass
            except Exception as ex:
                LoggerRoot.get_base_logger(TensorflowBinding).debug(str(ex))
            if PatchSummaryToEventTransformer.__original_getattributeX is None:
                try:
@ -638,7 +660,7 @@ class PatchSummaryToEventTransformer(object):
                    # this is a new version of tensorboardX
                    pass
                except Exception as ex:
                    LoggerRoot.get_base_logger(TensorflowBinding).debug(str(ex))
    @staticmethod
    def _patched_add_eventT(self, *args, **kwargs):
@ -722,7 +744,7 @@ class _ModelAdapter(object):
        super(_ModelAdapter, self).__init__()
        super(_ModelAdapter, self).__setattr__('_model', model)
        super(_ModelAdapter, self).__setattr__('_output_model', output_model)
        super(_ModelAdapter, self).__setattr__('_logger', LoggerRoot.get_base_logger(TensorflowBinding))

    def __getattr__(self, attr):
        return getattr(self._model, attr)
@ -805,7 +827,7 @@ class PatchModelCheckPointCallback(object):
                    property(PatchModelCheckPointCallback.trains_object))
        except Exception as ex:
            LoggerRoot.get_base_logger(TensorflowBinding).warning(str(ex))

    @staticmethod
    def _patched_getattribute(self, attr):
@ -876,6 +898,8 @@ class PatchTensorFlowEager(object):
            gen_summary_ops.write_image_summary = PatchTensorFlowEager._write_image_summary
            PatchTensorFlowEager.__original_fn_hist = gen_summary_ops.write_histogram_summary
            gen_summary_ops.write_histogram_summary = PatchTensorFlowEager._write_hist_summary
            PatchTensorFlowEager.__write_summary = gen_summary_ops.write_summary
            gen_summary_ops.write_summary = PatchTensorFlowEager._write_summary
            gen_summary_ops.create_summary_file_writer = partial(IsTensorboardInit._patched_tb__init__,
                                                                 gen_summary_ops.create_summary_file_writer)
            gen_summary_ops.create_summary_db_writer = partial(IsTensorboardInit._patched_tb__init__,
@ -883,7 +907,7 @@ class PatchTensorFlowEager(object):
        except ImportError:
            pass
        except Exception as ex:
            LoggerRoot.get_base_logger(TensorflowBinding).debug(str(ex))
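For readers unfamiliar with the pattern: the block above tees TF 2 eager summaries into trains by swapping functions on tensorflow's generated gen_summary_ops module, keeping references to the originals so every call still reaches TensorFlow. A minimal standalone sketch of that interception idea (the class and op below are stand-ins, not the real gen_summary_ops API):

class _FakeSummaryOps(object):
    @staticmethod
    def write_scalar_summary(writer, step, tag, value):
        print('original op:', tag, value)

_original_op = _FakeSummaryOps.write_scalar_summary

def _patched_op(writer, step, tag, value):
    print('captured for trains:', tag, value)  # report to the task here
    return _original_op(writer, step, tag, value)  # always call through to TF

_FakeSummaryOps.write_scalar_summary = _patched_op
_FakeSummaryOps.write_scalar_summary(None, 0, 'loss', 0.5)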
    @staticmethod
    def _get_event_writer(writer):
@ -903,14 +927,40 @@ class PatchTensorFlowEager(object):
    def trains_object(self):
        return PatchTensorFlowEager.__trains_event_writer
    @staticmethod
    def _write_summary(writer, step, tensor, tag, summary_metadata, name=None, **kwargs):
        event_writer = PatchTensorFlowEager._get_event_writer(writer)
        if event_writer:
            try:
                plugin_type = summary_metadata.decode()
                if plugin_type.endswith('scalars'):
                    event_writer._add_scalar(tag=str(tag),
                                             step=int(step.numpy()) if not isinstance(step, int) else step,
                                             scalar_data=tensor.numpy())
                elif plugin_type.endswith('images'):
                    img_data_np = tensor.numpy()
                    PatchTensorFlowEager._add_image_event_helper(event_writer, img_data_np=img_data_np,
                                                                 tag=tag, step=step, **kwargs)
                elif plugin_type.endswith('histograms'):
                    PatchTensorFlowEager._add_histogram_event_helper(
                        event_writer, tag=str(tag), step=int(step.numpy()) if not isinstance(step, int) else step,
                        hist_data=tensor.numpy())
                else:
                    pass  # print('unsupported plugin_type', plugin_type)
            except Exception:
                pass
        return PatchTensorFlowEager.__write_summary(writer, step, tensor, tag, summary_metadata, name, **kwargs)
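_write_summary is new in this commit because TF 2.0 routes tf.summary.scalar/image/histogram through the generic write_summary op. The plugin_type suffix checks above rest on one assumption: the op receives the SummaryMetadata proto serialized to bytes, and the plugin name ('scalars', 'images', 'histograms') sits at the tail of that serialization. A small sketch of that assumption (requires tensorflow):

import tensorflow as tf

md = tf.compat.v1.SummaryMetadata()
md.plugin_data.plugin_name = 'scalars'
# the plugin name ends the serialized proto, so a suffix check is enough to route the event
print(md.SerializeToString().endswith(b'scalars'))  # True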
    @staticmethod
    def _write_scalar_summary(writer, step, tag, value, name=None, **kwargs):
        event_writer = PatchTensorFlowEager._get_event_writer(writer)
        if event_writer:
            try:
                event_writer._add_scalar(tag=str(tag),
                                         step=int(step.numpy()) if not isinstance(step, int) else step,
                                         scalar_data=value.numpy())
            except Exception as ex:
                LoggerRoot.get_base_logger(TensorflowBinding).warning(str(ex))
        return PatchTensorFlowEager.__original_fn_scalar(writer, step, tag, value, name, **kwargs)
    @staticmethod
@ -918,9 +968,11 @@ class PatchTensorFlowEager(object):
        event_writer = PatchTensorFlowEager._get_event_writer(writer)
        if event_writer:
            try:
                PatchTensorFlowEager._add_histogram_event_helper(
                    event_writer, tag=str(tag), step=int(step.numpy()) if not isinstance(step, int) else step,
                    hist_data=values.numpy())
            except Exception as ex:
                LoggerRoot.get_base_logger(TensorflowBinding).warning(str(ex))
        return PatchTensorFlowEager.__original_fn_hist(writer, step, tag, values, name, **kwargs)
    @staticmethod
@ -928,13 +980,45 @@ class PatchTensorFlowEager(object):
        event_writer = PatchTensorFlowEager._get_event_writer(writer)
        if event_writer:
            try:
                PatchTensorFlowEager._add_image_event_helper(event_writer, img_data_np=tensor.numpy(),
                                                             tag=tag, step=step, **kwargs)
            except Exception as ex:
                LoggerRoot.get_base_logger(TensorflowBinding).warning(str(ex))
        return PatchTensorFlowEager.__original_fn_image(writer, step, tag, tensor, bad_color, max_images, name,
                                                        **kwargs)
    @staticmethod
    def _add_histogram_event_helper(event_writer, hist_data, tag, step):
        if isinstance(hist_data, dict):
            event_writer._add_histogram(tag=tag, step=step, histo_data=hist_data)
            return

        # prepare the dictionary, assume numpy
        # hist_data is a (k, 3) matrix: column 0 is the bucket left edge, column 1 the right edge
        # (the X axis) and column 2 the bucket height (the Y axis);
        # for backwards compatibility we report the left edge as 'bucketLimit'
        histo_data = {'bucketLimit': hist_data[:, 0].tolist(), 'bucket': hist_data[:, 2].tolist()}
        event_writer._add_histogram(tag=tag, step=step, histo_data=histo_data)
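Per the indexing above, the histogram tensor arriving from TF 2 eager mode is assumed to be a (k, 3) float matrix of [left_edge, right_edge, count] rows; a tiny illustration with made-up buckets:

import numpy as np

hist_data = np.array([[0.0, 0.5, 3.0],
                      [0.5, 1.0, 7.0],
                      [1.0, 1.5, 2.0]], dtype=np.float32)  # [left, right, count]
histo_data = {'bucketLimit': hist_data[:, 0].tolist(), 'bucket': hist_data[:, 2].tolist()}
print(histo_data)  # {'bucketLimit': [0.0, 0.5, 1.0], 'bucket': [3.0, 7.0, 2.0]}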
    @staticmethod
    def _add_image_event_helper(event_writer, img_data_np, tag, step, **kwargs):
        if img_data_np.ndim == 1 and img_data_np.size >= 3 and \
                (len(img_data_np[0]) < 10 and len(img_data_np[1]) < 10):
            # this is just for making sure these are actually valid numbers
            width = int(img_data_np[0].decode())
            height = int(img_data_np[1].decode())
            for i in range(2, img_data_np.size):
                # width/height of -1 tell the decoder to infer the size from the encoded image
                img_data = {'width': -1, 'height': -1,
                            'colorspace': 'RGB', 'encodedImageString': img_data_np[i]}
                image_tag = str(tag) + '/sample_{}'.format(i - 2) if img_data_np.size > 3 else str(tag)
                event_writer._add_image(tag=image_tag,
                                        step=int(step.numpy()) if not isinstance(step, int) else step,
                                        img_data=img_data)
        else:
            event_writer._add_image_numpy(tag=str(tag),
                                          step=int(step.numpy()) if not isinstance(step, int) else step,
                                          img_data_np=img_data_np,
                                          max_keep_images=kwargs.get('max_images'))
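The first branch above matches the 1-D byte-string tensor produced by tf.summary.image in eager mode: elements 0 and 1 are the width and height as ASCII strings, followed by one encoded (e.g. PNG) image per sample. A small illustration of that assumed layout (byte values are made up):

import numpy as np

img_data_np = np.array([b'28', b'28', b'\x89PNG...fake-0', b'\x89PNG...fake-1'], dtype=object)
width, height = int(img_data_np[0].decode()), int(img_data_np[1].decode())
print(width, height, img_data_np.size - 2)  # 28 28 2  -> two encoded samples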
class PatchKerasModelIO(object):
    __main_task = None
@ -1029,7 +1113,7 @@ class PatchKerasModelIO(object):
            keras_saving.save_model = _patched_call(keras_saving.save_model, PatchKerasModelIO._save_model)
            keras_saving.load_model = _patched_call(keras_saving.load_model, PatchKerasModelIO._load_model)
        except Exception as ex:
            LoggerRoot.get_base_logger(TensorflowBinding).warning(str(ex))
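_patched_call is a trains-internal helper imported elsewhere in this module; judging from the call sites in this diff (e.g. _save_model(original_fn, model, filepath, ...)), its contract is to hand the replacement the original function as its first argument. A minimal sketch of that assumed contract:

def _patched_call_sketch(original_fn, patched_fn):
    # assumed semantics: the wrapper calls patched_fn(original_fn, *args, **kwargs)
    def _inner(*args, **kwargs):
        return patched_fn(original_fn, *args, **kwargs)
    return _inner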
    @staticmethod
    def _updated_config(original_fn, self):
@ -1057,7 +1141,7 @@ class PatchKerasModelIO(object):
                framework=Framework.keras,
            )
        except Exception as ex:
            LoggerRoot.get_base_logger(TensorflowBinding).warning(str(ex))

        return config
@ -1107,7 +1191,7 @@ class PatchKerasModelIO(object):
            return model
        except Exception as ex:
            LoggerRoot.get_base_logger(TensorflowBinding).warning(str(ex))

        return self
@ -1189,7 +1273,7 @@ class PatchKerasModelIO(object):
            # if anyone asks, we were here
            self.trains_out_model._processed = True
        except Exception as ex:
            LoggerRoot.get_base_logger(TensorflowBinding).warning(str(ex))
    @staticmethod
    def _save_model(original_fn, model, filepath, *args, **kwargs):
@ -1261,7 +1345,7 @@ class PatchTensorflowModelIO(object):
        except ImportError:
            pass
        except Exception:
            LoggerRoot.get_base_logger(TensorflowBinding).debug('Failed patching tensorflow')
        # noinspection PyBroadException
        try:
@ -1282,10 +1366,9 @@ class PatchTensorflowModelIO(object):
                saved_model = None
            except Exception:
                saved_model = None
        except Exception:
            saved_model = None

        if saved_model is not None:
            saved_model.save = _patched_call(saved_model.save, PatchTensorflowModelIO._save_model)
@ -1301,7 +1384,7 @@ class PatchTensorflowModelIO(object):
        except ImportError:
            pass
        except Exception:
            LoggerRoot.get_base_logger(TensorflowBinding).debug('Failed patching tensorflow')
        # noinspection PyBroadException
        try:
@ -1313,7 +1396,7 @@ class PatchTensorflowModelIO(object):
        except ImportError:
            pass
        except Exception:
            LoggerRoot.get_base_logger(TensorflowBinding).debug('Failed patching tensorflow')
        # noinspection PyBroadException
        try:
@ -1325,7 +1408,7 @@ class PatchTensorflowModelIO(object):
        except ImportError:
            pass
        except Exception:
            LoggerRoot.get_base_logger(TensorflowBinding).debug('Failed patching tensorflow')
        # noinspection PyBroadException
        try:
@ -1349,7 +1432,7 @@ class PatchTensorflowModelIO(object):
        except ImportError:
            pass
        except Exception:
            LoggerRoot.get_base_logger(TensorflowBinding).debug('Failed patching tensorflow')
    @staticmethod
    def _save(original_fn, self, sess, save_path, *args, **kwargs):
@ -1457,3 +1540,81 @@ class PatchTensorflowModelIO(object):
            pass
        return model


class PatchTensorflow2ModelIO(object):
    __main_task = None
    __patched = None

    @staticmethod
    def update_current_task(task, **kwargs):
        PatchTensorflow2ModelIO.__main_task = task
        PatchTensorflow2ModelIO._patch_model_checkpoint()
        PostImportHookPatching.add_on_import('tensorflow', PatchTensorflow2ModelIO._patch_model_checkpoint)

    @staticmethod
    def _patch_model_checkpoint():
        if PatchTensorflow2ModelIO.__patched:
            return

        if 'tensorflow' not in sys.modules:
            return

        PatchTensorflow2ModelIO.__patched = True
        # noinspection PyBroadException
        try:
            # hack: make sure tensorflow.__init__ is called
            import tensorflow
            from tensorflow.python.training.tracking import util
            # noinspection PyBroadException
            try:
                util.TrackableSaver.save = _patched_call(util.TrackableSaver.save,
                                                         PatchTensorflow2ModelIO._save)
            except Exception:
                pass
            # noinspection PyBroadException
            try:
                util.TrackableSaver.restore = _patched_call(util.TrackableSaver.restore,
                                                            PatchTensorflow2ModelIO._restore)
            except Exception:
                pass
        except ImportError:
            pass
        except Exception:
            LoggerRoot.get_base_logger(TensorflowBinding).debug('Failed patching tensorflow v2')

    @staticmethod
    def _save(original_fn, self, file_prefix, *args, **kwargs):
        model = original_fn(self, file_prefix, *args, **kwargs)
        # store the output model
        try:
            WeightsFileHandler.create_output_model(self, file_prefix, Framework.tensorflow,
                                                   PatchTensorflow2ModelIO.__main_task)
        except Exception:
            pass
        return model

    @staticmethod
    def _restore(original_fn, self, save_path, *args, **kwargs):
        if PatchTensorflow2ModelIO.__main_task is None:
            return original_fn(self, save_path, *args, **kwargs)

        # Hack: disabled
        if False and running_remotely():
            # register/load model weights
            try:
                save_path = WeightsFileHandler.restore_weights_file(self, save_path, Framework.tensorflow,
                                                                    PatchTensorflow2ModelIO.__main_task)
            except Exception:
                pass
            # load model
            return original_fn(self, save_path, *args, **kwargs)

        # load the model; if anything goes wrong, the exception is raised before we register the input model
        model = original_fn(self, save_path, *args, **kwargs)
        # register/load model weights
        try:
            WeightsFileHandler.restore_weights_file(self, save_path, Framework.tensorflow,
                                                    PatchTensorflow2ModelIO.__main_task)
        except Exception:
            pass
        return model
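Tying this back to the example file above: tf.train.CheckpointManager.save() goes through TrackableSaver.save() internally, so once Task.init() has run, each checkpoint the example writes should be picked up by the _save hook and registered as an output model on the task (an expectation based on this diff, not a guarantee for every TF 2.x version). A minimal sketch; the project/task names and the /tmp path are illustrative:

from trains import Task
import tensorflow as tf

task = Task.init(project_name='examples', task_name='checkpoint auto-logging sketch')
ckpt = tf.train.Checkpoint(step=tf.Variable(1))
manager = tf.train.CheckpointManager(ckpt, '/tmp/tf_ckpts_sketch', max_to_keep=3)
manager.save()  # routed through the patched TrackableSaver.save, registering an output model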

View File

@ -8,7 +8,7 @@
    }
    direct_access: [
        # Objects matching are considered to be available for direct access, i.e. they will not be downloaded
        # or cached, and any download request will return a direct reference.
        # Objects are specified in glob format, available for url and content_type.
        { url: "file://*" }  # file-urls are always directly referenced

View File

@ -34,8 +34,7 @@ from .binding.absl_bind import PatchAbsl
from .utilities.args import argparser_parseargs_called, get_argparser_last_args, \
    argparser_update_currenttask
from .binding.frameworks.pytorch_bind import PatchPyTorchModelIO
from .binding.frameworks.tensorflow_bind import TensorflowBinding
from .binding.frameworks.xgboost_bind import PatchXGBoostModelIO
from .binding.matplotlib_bind import PatchedMatplotlib
from .utilities.resource_monitor import ResourceMonitor
@ -252,10 +251,7 @@ class Task(_Task):
            PatchedJoblib.update_current_task(task)
            PatchedMatplotlib.update_current_task(Task.__main_task)
            PatchAbsl.update_current_task(Task.__main_task)
            TensorflowBinding.update_current_task(task)
            PatchPyTorchModelIO.update_current_task(task)
            PatchXGBoostModelIO.update_current_task(task)
            if auto_resource_monitoring:
@ -346,7 +342,7 @@ class Task(_Task):
    def _create_dev_task(cls, default_project_name, default_task_name, default_task_type, reuse_last_task_id):
        if not default_project_name or not default_task_name:
            # get project name and task name from repository name and entry_point
            result, _ = ScriptInfo.get(create_requirements=False, check_uncommitted=False)
            if not default_project_name:
                # noinspection PyBroadException
                try:
@ -1016,7 +1012,14 @@ class Task(_Task):
        if self._detect_repo_async_thread:
            try:
                if self._detect_repo_async_thread.is_alive():
                    self.log.info('Waiting for repository detection and full package requirement analysis')
                    self._detect_repo_async_thread.join(timeout=timeout)
                    # join() returns None, so re-check is_alive() to detect a timeout
                    if self._detect_repo_async_thread.is_alive():
                        self.log.info('Repository and package analysis timed out ({} sec), '
                                      'giving up'.format(timeout))
                    else:
                        self.log.info('Finished repository detection and package analysis')
                self._detect_repo_async_thread = None
            except Exception:
                pass
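The added re-check after join() is the standard way to detect a join timeout, since Thread.join() always returns None; a standalone sketch of the pattern:

import threading
import time

worker = threading.Thread(target=time.sleep, args=(10,))
worker.start()
worker.join(timeout=0.1)  # returns None whether or not the thread finished
if worker.is_alive():
    print('join timed out, worker is still running')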