From a0e42c66a11809d2197207c335c66b27567937c7 Mon Sep 17 00:00:00 2001 From: Erez Schanider Date: Wed, 17 Jul 2019 09:10:32 +0300 Subject: [PATCH] fix deprecation warning --- examples/tensorflow_eager.py | 20 ++++++++++---------- trains/binding/frameworks/tensorflow_bind.py | 15 +++++++-------- trains/utilities/seed.py | 4 ++-- 3 files changed, 19 insertions(+), 20 deletions(-) diff --git a/examples/tensorflow_eager.py b/examples/tensorflow_eager.py index 5a9599a5..407e184a 100644 --- a/examples/tensorflow_eager.py +++ b/examples/tensorflow_eager.py @@ -31,7 +31,7 @@ import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data from trains import Task -tf.enable_eager_execution() +tf.compat.v1.enable_eager_execution() task = Task.init(project_name='examples', task_name='Tensorflow eager mode') @@ -160,11 +160,11 @@ def discriminator_loss(discriminator_real_outputs, discriminator_gen_outputs): A scalar loss Tensor. """ - loss_on_real = tf.losses.sigmoid_cross_entropy( + loss_on_real = tf.compat.v1.losses.sigmoid_cross_entropy( tf.ones_like(discriminator_real_outputs), discriminator_real_outputs, label_smoothing=0.25) - loss_on_generated = tf.losses.sigmoid_cross_entropy( + loss_on_generated = tf.compat.v1.losses.sigmoid_cross_entropy( tf.zeros_like(discriminator_gen_outputs), discriminator_gen_outputs) loss = loss_on_real + loss_on_generated tf.contrib.summary.scalar('discriminator_loss', loss) @@ -182,7 +182,7 @@ def generator_loss(discriminator_gen_outputs): Returns: A scalar loss Tensor. """ - loss = tf.losses.sigmoid_cross_entropy( + loss = tf.compat.v1.losses.sigmoid_cross_entropy( tf.ones_like(discriminator_gen_outputs), discriminator_gen_outputs) tf.contrib.summary.scalar('generator_loss', loss) return loss @@ -208,12 +208,12 @@ def train_one_epoch(generator, discriminator, generator_optimizer, total_discriminator_loss = 0.0 for (batch_index, images) in enumerate(dataset): with tf.device('/cpu:0'): - tf.assign_add(step_counter, 1) + tf.compat.v1.assign_add(step_counter, 1) with tf.contrib.summary.record_summaries_every_n_global_steps( log_interval, global_step=step_counter): current_batch_size = images.shape[0] - noise = tf.random_uniform( + noise = tf.random.uniform( shape=[current_batch_size, noise_dim], minval=-1., maxval=1., @@ -271,9 +271,9 @@ def main(_): model_objects = { 'generator': Generator(data_format), 'discriminator': Discriminator(data_format), - 'generator_optimizer': tf.train.AdamOptimizer(FLAGS.lr), - 'discriminator_optimizer': tf.train.AdamOptimizer(FLAGS.lr), - 'step_counter': tf.train.get_or_create_global_step(), + 'generator_optimizer': tf.compat.v1.train.AdamOptimizer(FLAGS.lr), + 'discriminator_optimizer': tf.compat.v1.train.AdamOptimizer(FLAGS.lr), + 'step_counter': tf.compat.v1.train.get_or_create_global_step(), } # Prepare summary writer and checkpoint info @@ -355,4 +355,4 @@ if __name__ == '__main__': FLAGS, unparsed = parser.parse_known_args() -tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) +tf.compat.v1.app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/trains/binding/frameworks/tensorflow_bind.py b/trains/binding/frameworks/tensorflow_bind.py index 1f8697cd..b2d42104 100644 --- a/trains/binding/frameworks/tensorflow_bind.py +++ b/trains/binding/frameworks/tensorflow_bind.py @@ -93,7 +93,7 @@ class EventTrainsWriter(object): self.histogram_granularity = histogram_granularity self.histogram_update_freq_multiplier = histogram_update_freq_multiplier self._logger = logger - self._visualization_mode = 'RGB' # 'BGR' + self._visualization_mode = 'BGR' self._variants = defaultdict(lambda: ()) self._scalar_report_cache = {} self._hist_report_cache = {} @@ -519,7 +519,7 @@ class PatchSummaryToEventTransformer(object): setattr(SummaryToEventTransformer, 'trains', property(PatchSummaryToEventTransformer.trains_object)) except Exception as ex: - getLogger(TrainsFrameworkAdapter).debug(str(ex)) + getLogger(TrainsFrameworkAdapter).warning(str(ex)) if 'torch' in sys.modules: try: @@ -533,7 +533,7 @@ class PatchSummaryToEventTransformer(object): # this is a new version of TensorflowX pass except Exception as ex: - getLogger(TrainsFrameworkAdapter).debug(str(ex)) + getLogger(TrainsFrameworkAdapter).warning(str(ex)) if 'tensorboardX' in sys.modules: try: @@ -549,7 +549,7 @@ class PatchSummaryToEventTransformer(object): # this is a new version of TensorflowX pass except Exception as ex: - getLogger(TrainsFrameworkAdapter).debug(str(ex)) + getLogger(TrainsFrameworkAdapter).warning(str(ex)) if PatchSummaryToEventTransformer.__original_getattributeX is None: try: @@ -563,7 +563,7 @@ class PatchSummaryToEventTransformer(object): # this is a new version of TensorflowX pass except Exception as ex: - getLogger(TrainsFrameworkAdapter).debug(str(ex)) + getLogger(TrainsFrameworkAdapter).warning(str(ex)) @staticmethod def _patched_add_eventT(self, *args, **kwargs): @@ -795,7 +795,7 @@ class PatchTensorFlowEager(object): except ImportError: pass except Exception as ex: - getLogger(TrainsFrameworkAdapter).debug(str(ex)) + getLogger(TrainsFrameworkAdapter).warning(str(ex)) @staticmethod def _get_event_writer(): @@ -1163,13 +1163,12 @@ class PatchTensorflowModelIO(object): try: # make sure we import the correct version of save import tensorflow - from tensorflow.saved_model.experimental import save + from tf.saved_model import save # actual import import tensorflow.saved_model.experimental as saved_model except ImportError: # noinspection PyBroadException try: - # TODO: we might want to reverse the order, so we do not get the deprecated warning # make sure we import the correct version of save import tensorflow from tensorflow.saved_model import save diff --git a/trains/utilities/seed.py b/trains/utilities/seed.py index fff6c169..808d91c0 100644 --- a/trains/utilities/seed.py +++ b/trains/utilities/seed.py @@ -64,11 +64,11 @@ def make_deterministic(seed=1337, cudnn_deterministic=False): if not eager_mode_bypass: try: - tf.set_random_seed(seed) + tf.compat.v1.set_random_seed(seed) except Exception: pass try: - tf.random.set_random_seed(seed) + tf.compat.v1.set_random_seed(seed) except Exception: pass