mirror of
https://github.com/clearml/clearml
synced 2025-04-29 10:43:16 +00:00
commit
9e166137b6
@ -31,7 +31,7 @@ import tensorflow as tf
|
|||||||
from tensorflow.examples.tutorials.mnist import input_data
|
from tensorflow.examples.tutorials.mnist import input_data
|
||||||
from trains import Task
|
from trains import Task
|
||||||
|
|
||||||
tf.enable_eager_execution()
|
tf.compat.v1.enable_eager_execution()
|
||||||
|
|
||||||
task = Task.init(project_name='examples', task_name='Tensorflow eager mode')
|
task = Task.init(project_name='examples', task_name='Tensorflow eager mode')
|
||||||
|
|
||||||
@ -160,11 +160,11 @@ def discriminator_loss(discriminator_real_outputs, discriminator_gen_outputs):
|
|||||||
A scalar loss Tensor.
|
A scalar loss Tensor.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
loss_on_real = tf.losses.sigmoid_cross_entropy(
|
loss_on_real = tf.compat.v1.losses.sigmoid_cross_entropy(
|
||||||
tf.ones_like(discriminator_real_outputs),
|
tf.ones_like(discriminator_real_outputs),
|
||||||
discriminator_real_outputs,
|
discriminator_real_outputs,
|
||||||
label_smoothing=0.25)
|
label_smoothing=0.25)
|
||||||
loss_on_generated = tf.losses.sigmoid_cross_entropy(
|
loss_on_generated = tf.compat.v1.losses.sigmoid_cross_entropy(
|
||||||
tf.zeros_like(discriminator_gen_outputs), discriminator_gen_outputs)
|
tf.zeros_like(discriminator_gen_outputs), discriminator_gen_outputs)
|
||||||
loss = loss_on_real + loss_on_generated
|
loss = loss_on_real + loss_on_generated
|
||||||
tf.contrib.summary.scalar('discriminator_loss', loss)
|
tf.contrib.summary.scalar('discriminator_loss', loss)
|
||||||
@ -182,7 +182,7 @@ def generator_loss(discriminator_gen_outputs):
|
|||||||
Returns:
|
Returns:
|
||||||
A scalar loss Tensor.
|
A scalar loss Tensor.
|
||||||
"""
|
"""
|
||||||
loss = tf.losses.sigmoid_cross_entropy(
|
loss = tf.compat.v1.losses.sigmoid_cross_entropy(
|
||||||
tf.ones_like(discriminator_gen_outputs), discriminator_gen_outputs)
|
tf.ones_like(discriminator_gen_outputs), discriminator_gen_outputs)
|
||||||
tf.contrib.summary.scalar('generator_loss', loss)
|
tf.contrib.summary.scalar('generator_loss', loss)
|
||||||
return loss
|
return loss
|
||||||
@ -208,12 +208,12 @@ def train_one_epoch(generator, discriminator, generator_optimizer,
|
|||||||
total_discriminator_loss = 0.0
|
total_discriminator_loss = 0.0
|
||||||
for (batch_index, images) in enumerate(dataset):
|
for (batch_index, images) in enumerate(dataset):
|
||||||
with tf.device('/cpu:0'):
|
with tf.device('/cpu:0'):
|
||||||
tf.assign_add(step_counter, 1)
|
tf.compat.v1.assign_add(step_counter, 1)
|
||||||
|
|
||||||
with tf.contrib.summary.record_summaries_every_n_global_steps(
|
with tf.contrib.summary.record_summaries_every_n_global_steps(
|
||||||
log_interval, global_step=step_counter):
|
log_interval, global_step=step_counter):
|
||||||
current_batch_size = images.shape[0]
|
current_batch_size = images.shape[0]
|
||||||
noise = tf.random_uniform(
|
noise = tf.random.uniform(
|
||||||
shape=[current_batch_size, noise_dim],
|
shape=[current_batch_size, noise_dim],
|
||||||
minval=-1.,
|
minval=-1.,
|
||||||
maxval=1.,
|
maxval=1.,
|
||||||
@ -271,9 +271,9 @@ def main(_):
|
|||||||
model_objects = {
|
model_objects = {
|
||||||
'generator': Generator(data_format),
|
'generator': Generator(data_format),
|
||||||
'discriminator': Discriminator(data_format),
|
'discriminator': Discriminator(data_format),
|
||||||
'generator_optimizer': tf.train.AdamOptimizer(FLAGS.lr),
|
'generator_optimizer': tf.compat.v1.train.AdamOptimizer(FLAGS.lr),
|
||||||
'discriminator_optimizer': tf.train.AdamOptimizer(FLAGS.lr),
|
'discriminator_optimizer': tf.compat.v1.train.AdamOptimizer(FLAGS.lr),
|
||||||
'step_counter': tf.train.get_or_create_global_step(),
|
'step_counter': tf.compat.v1.train.get_or_create_global_step(),
|
||||||
}
|
}
|
||||||
|
|
||||||
# Prepare summary writer and checkpoint info
|
# Prepare summary writer and checkpoint info
|
||||||
@ -355,4 +355,4 @@ if __name__ == '__main__':
|
|||||||
|
|
||||||
FLAGS, unparsed = parser.parse_known_args()
|
FLAGS, unparsed = parser.parse_known_args()
|
||||||
|
|
||||||
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
|
tf.compat.v1.app.run(main=main, argv=[sys.argv[0]] + unparsed)
|
||||||
|
@ -154,7 +154,7 @@ class EventTrainsWriter(object):
|
|||||||
self.histogram_granularity = histogram_granularity
|
self.histogram_granularity = histogram_granularity
|
||||||
self.histogram_update_freq_multiplier = histogram_update_freq_multiplier
|
self.histogram_update_freq_multiplier = histogram_update_freq_multiplier
|
||||||
self._logger = logger
|
self._logger = logger
|
||||||
self._visualization_mode = 'RGB' # 'BGR'
|
self._visualization_mode = 'BGR'
|
||||||
self._variants = defaultdict(lambda: ())
|
self._variants = defaultdict(lambda: ())
|
||||||
self._scalar_report_cache = {}
|
self._scalar_report_cache = {}
|
||||||
self._hist_report_cache = {}
|
self._hist_report_cache = {}
|
||||||
@ -582,7 +582,7 @@ class PatchSummaryToEventTransformer(object):
|
|||||||
setattr(SummaryToEventTransformer, 'trains',
|
setattr(SummaryToEventTransformer, 'trains',
|
||||||
property(PatchSummaryToEventTransformer.trains_object))
|
property(PatchSummaryToEventTransformer.trains_object))
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
getLogger(TrainsFrameworkAdapter).debug(str(ex))
|
getLogger(TrainsFrameworkAdapter).warning(str(ex))
|
||||||
|
|
||||||
if 'torch' in sys.modules:
|
if 'torch' in sys.modules:
|
||||||
try:
|
try:
|
||||||
@ -596,7 +596,7 @@ class PatchSummaryToEventTransformer(object):
|
|||||||
# this is a new version of TensorflowX
|
# this is a new version of TensorflowX
|
||||||
pass
|
pass
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
getLogger(TrainsFrameworkAdapter).debug(str(ex))
|
getLogger(TrainsFrameworkAdapter).warning(str(ex))
|
||||||
|
|
||||||
if 'tensorboardX' in sys.modules:
|
if 'tensorboardX' in sys.modules:
|
||||||
try:
|
try:
|
||||||
@ -612,7 +612,7 @@ class PatchSummaryToEventTransformer(object):
|
|||||||
# this is a new version of TensorflowX
|
# this is a new version of TensorflowX
|
||||||
pass
|
pass
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
getLogger(TrainsFrameworkAdapter).debug(str(ex))
|
getLogger(TrainsFrameworkAdapter).warning(str(ex))
|
||||||
|
|
||||||
if PatchSummaryToEventTransformer.__original_getattributeX is None:
|
if PatchSummaryToEventTransformer.__original_getattributeX is None:
|
||||||
try:
|
try:
|
||||||
@ -626,7 +626,7 @@ class PatchSummaryToEventTransformer(object):
|
|||||||
# this is a new version of TensorflowX
|
# this is a new version of TensorflowX
|
||||||
pass
|
pass
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
getLogger(TrainsFrameworkAdapter).debug(str(ex))
|
getLogger(TrainsFrameworkAdapter).warning(str(ex))
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _patched_add_eventT(self, *args, **kwargs):
|
def _patched_add_eventT(self, *args, **kwargs):
|
||||||
@ -871,7 +871,7 @@ class PatchTensorFlowEager(object):
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
getLogger(TrainsFrameworkAdapter).debug(str(ex))
|
getLogger(TrainsFrameworkAdapter).warning(str(ex))
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_event_writer(writer):
|
def _get_event_writer(writer):
|
||||||
@ -1244,13 +1244,12 @@ class PatchTensorflowModelIO(object):
|
|||||||
try:
|
try:
|
||||||
# make sure we import the correct version of save
|
# make sure we import the correct version of save
|
||||||
import tensorflow
|
import tensorflow
|
||||||
from tensorflow.saved_model.experimental import save
|
from tf.saved_model import save
|
||||||
# actual import
|
# actual import
|
||||||
import tensorflow.saved_model.experimental as saved_model
|
import tensorflow.saved_model.experimental as saved_model
|
||||||
except ImportError:
|
except ImportError:
|
||||||
# noinspection PyBroadException
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
# TODO: we might want to reverse the order, so we do not get the deprecated warning
|
|
||||||
# make sure we import the correct version of save
|
# make sure we import the correct version of save
|
||||||
import tensorflow
|
import tensorflow
|
||||||
from tensorflow.saved_model import save
|
from tensorflow.saved_model import save
|
||||||
|
@ -64,11 +64,11 @@ def make_deterministic(seed=1337, cudnn_deterministic=False):
|
|||||||
|
|
||||||
if not eager_mode_bypass:
|
if not eager_mode_bypass:
|
||||||
try:
|
try:
|
||||||
tf.set_random_seed(seed)
|
tf.compat.v1.set_random_seed(seed)
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
try:
|
try:
|
||||||
tf.random.set_random_seed(seed)
|
tf.compat.v1.set_random_seed(seed)
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user