Merge pull request #25 from erezalg/master

fix deprecation warning

commit 9e166137b6
Allegro AI, 2019-07-24 01:10:35 +03:00 (committed by GitHub)
3 changed files with 19 additions and 20 deletions

File 1 of 3:

@@ -31,7 +31,7 @@ import tensorflow as tf
 from tensorflow.examples.tutorials.mnist import input_data
 from trains import Task
-tf.enable_eager_execution()
+tf.compat.v1.enable_eager_execution()
 task = Task.init(project_name='examples', task_name='Tensorflow eager mode')
@@ -160,11 +160,11 @@ def discriminator_loss(discriminator_real_outputs, discriminator_gen_outputs):
     A scalar loss Tensor.
   """
-  loss_on_real = tf.losses.sigmoid_cross_entropy(
+  loss_on_real = tf.compat.v1.losses.sigmoid_cross_entropy(
       tf.ones_like(discriminator_real_outputs),
       discriminator_real_outputs,
       label_smoothing=0.25)
-  loss_on_generated = tf.losses.sigmoid_cross_entropy(
+  loss_on_generated = tf.compat.v1.losses.sigmoid_cross_entropy(
       tf.zeros_like(discriminator_gen_outputs), discriminator_gen_outputs)
   loss = loss_on_real + loss_on_generated
   tf.contrib.summary.scalar('discriminator_loss', loss)
@@ -182,7 +182,7 @@ def generator_loss(discriminator_gen_outputs):
   Returns:
     A scalar loss Tensor.
   """
-  loss = tf.losses.sigmoid_cross_entropy(
+  loss = tf.compat.v1.losses.sigmoid_cross_entropy(
      tf.ones_like(discriminator_gen_outputs), discriminator_gen_outputs)
   tf.contrib.summary.scalar('generator_loss', loss)
   return loss
@@ -208,12 +208,12 @@ def train_one_epoch(generator, discriminator, generator_optimizer,
   total_discriminator_loss = 0.0
   for (batch_index, images) in enumerate(dataset):
     with tf.device('/cpu:0'):
-      tf.assign_add(step_counter, 1)
+      tf.compat.v1.assign_add(step_counter, 1)
     with tf.contrib.summary.record_summaries_every_n_global_steps(
         log_interval, global_step=step_counter):
       current_batch_size = images.shape[0]
-      noise = tf.random_uniform(
+      noise = tf.random.uniform(
           shape=[current_batch_size, noise_dim],
           minval=-1.,
           maxval=1.,
@@ -271,9 +271,9 @@ def main(_):
   model_objects = {
       'generator': Generator(data_format),
       'discriminator': Discriminator(data_format),
-      'generator_optimizer': tf.train.AdamOptimizer(FLAGS.lr),
-      'discriminator_optimizer': tf.train.AdamOptimizer(FLAGS.lr),
-      'step_counter': tf.train.get_or_create_global_step(),
+      'generator_optimizer': tf.compat.v1.train.AdamOptimizer(FLAGS.lr),
+      'discriminator_optimizer': tf.compat.v1.train.AdamOptimizer(FLAGS.lr),
+      'step_counter': tf.compat.v1.train.get_or_create_global_step(),
   }

   # Prepare summary writer and checkpoint info
@@ -355,4 +355,4 @@ if __name__ == '__main__':
   FLAGS, unparsed = parser.parse_known_args()
-  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
+  tf.compat.v1.app.run(main=main, argv=[sys.argv[0]] + unparsed)
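
Every change in this example file applies the same TF 1.14 deprecation fix: symbols retired from the top-level namespace move under tf.compat.v1 (and tf.random_uniform becomes tf.random.uniform). A minimal sketch, assuming you want a single code path that runs both on releases that ship the compat.v1 namespace and on older 1.x releases that do not:

    import tensorflow as tf

    # TF >= 1.13 ships the legacy API under tf.compat.v1; older 1.x
    # releases still expose the same symbols at the top level.
    v1 = tf.compat.v1 if hasattr(tf.compat, 'v1') else tf

    v1.enable_eager_execution()
    optimizer = v1.train.AdamOptimizer(1e-4)
    step_counter = v1.train.get_or_create_global_step()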

File 2 of 3:

@@ -154,7 +154,7 @@ class EventTrainsWriter(object):
         self.histogram_granularity = histogram_granularity
         self.histogram_update_freq_multiplier = histogram_update_freq_multiplier
         self._logger = logger
-        self._visualization_mode = 'RGB'  # 'BGR'
+        self._visualization_mode = 'BGR'
         self._variants = defaultdict(lambda: ())
         self._scalar_report_cache = {}
         self._hist_report_cache = {}
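
The default visualization mode flips from 'RGB' to 'BGR', presumably because the decoded image buffers arrive with OpenCV-style reversed channels. If a consumer needs the other order, a NumPy slice reverses the channel axis:

    import numpy as np

    bgr_image = np.zeros((64, 64, 3), dtype=np.uint8)  # placeholder image
    # reverse the last (channel) axis: BGR -> RGB, or vice versa
    rgb_image = bgr_image[..., ::-1]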
@@ -582,7 +582,7 @@ class PatchSummaryToEventTransformer(object):
                 setattr(SummaryToEventTransformer, 'trains',
                         property(PatchSummaryToEventTransformer.trains_object))
             except Exception as ex:
-                getLogger(TrainsFrameworkAdapter).debug(str(ex))
+                getLogger(TrainsFrameworkAdapter).warning(str(ex))

         if 'torch' in sys.modules:
             try:
@@ -596,7 +596,7 @@ class PatchSummaryToEventTransformer(object):
                     # this is a new version of TensorflowX
                     pass
             except Exception as ex:
-                getLogger(TrainsFrameworkAdapter).debug(str(ex))
+                getLogger(TrainsFrameworkAdapter).warning(str(ex))

         if 'tensorboardX' in sys.modules:
             try:
@@ -612,7 +612,7 @@ class PatchSummaryToEventTransformer(object):
                     # this is a new version of TensorflowX
                     pass
             except Exception as ex:
-                getLogger(TrainsFrameworkAdapter).debug(str(ex))
+                getLogger(TrainsFrameworkAdapter).warning(str(ex))

         if PatchSummaryToEventTransformer.__original_getattributeX is None:
             try:
@@ -626,7 +626,7 @@ class PatchSummaryToEventTransformer(object):
                     # this is a new version of TensorflowX
                     pass
             except Exception as ex:
-                getLogger(TrainsFrameworkAdapter).debug(str(ex))
+                getLogger(TrainsFrameworkAdapter).warning(str(ex))

     @staticmethod
     def _patched_add_eventT(self, *args, **kwargs):
@@ -871,7 +871,7 @@ class PatchTensorFlowEager(object):
         except ImportError:
             pass
         except Exception as ex:
-            getLogger(TrainsFrameworkAdapter).debug(str(ex))
+            getLogger(TrainsFrameworkAdapter).warning(str(ex))

     @staticmethod
     def _get_event_writer(writer):
@@ -1244,13 +1244,12 @@ class PatchTensorflowModelIO(object):
         try:
             # make sure we import the correct version of save
             import tensorflow
-            from tensorflow.saved_model.experimental import save
+            from tensorflow.saved_model import save
             # actual import
             import tensorflow.saved_model.experimental as saved_model
         except ImportError:
             # noinspection PyBroadException
             try:
-                # TODO: we might want to reverse the order, so we do not get the deprecated warning
                 # make sure we import the correct version of save
                 import tensorflow
                 from tensorflow.saved_model import save
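
The removed TODO states the intent of this hunk: try the current tensorflow.saved_model.save first, so the deprecated experimental module is only imported as a fallback and never triggers the warning on modern installs. A minimal standalone sketch of that fallback pattern, outside the class:

    try:
        # newer TF: importing the promoted API does not warn
        from tensorflow.saved_model import save  # noqa: F401
        from tensorflow import saved_model
    except ImportError:
        # older TF 1.x only ships the experimental variant
        from tensorflow.saved_model.experimental import save  # noqa: F401
        import tensorflow.saved_model.experimental as saved_model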

File 3 of 3:

@@ -64,11 +64,11 @@ def make_deterministic(seed=1337, cudnn_deterministic=False):
     if not eager_mode_bypass:
         try:
-            tf.set_random_seed(seed)
+            tf.compat.v1.set_random_seed(seed)
         except Exception:
             pass
         try:
-            tf.random.set_random_seed(seed)
+            tf.compat.v1.set_random_seed(seed)
         except Exception:
             pass
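
Both seeding calls collapse onto tf.compat.v1.set_random_seed, each wrapped in try/except so any API the installed TensorFlow lacks is simply skipped. A minimal standalone sketch of the same version tolerance, using a hypothetical seed_everything helper that also covers the stdlib and NumPy RNGs:

    import random

    import numpy as np

    def seed_everything(seed=1337):
        # hypothetical helper: seed every RNG source we can reach
        random.seed(seed)
        np.random.seed(seed)
        try:
            import tensorflow as tf
        except ImportError:
            return  # TensorFlow not installed; nothing more to seed
        try:
            tf.compat.v1.set_random_seed(seed)  # TF 1.13+ and 2.x
        except Exception:
            try:
                tf.set_random_seed(seed)  # very old TF 1.x
            except Exception:
                pass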