diff --git a/clearml/binding/frameworks/tensorflow_bind.py b/clearml/binding/frameworks/tensorflow_bind.py index fa0ac995..65b9e06d 100644 --- a/clearml/binding/frameworks/tensorflow_bind.py +++ b/clearml/binding/frameworks/tensorflow_bind.py @@ -724,16 +724,7 @@ class EventTrainsWriter(object): 'Received event without step, assuming step = {}'.format(step)) else: step = int(step) - # unlike other frameworks, tensorflow already accounts for the iteration number - # when continuing the training. we substract the smallest iteration such that we - # don't increment the step twice number - original_step = step - if EventTrainsWriter._current_task: - step -= EventTrainsWriter._current_task.get_initial_iteration() - # there can be a few metrics getting reported again, so the step can be negative - # for the first few reports - if step < 0 and original_step > 0: - step = 0 + step = tweak_step(step) self._max_step = max(self._max_step, step) if value_dicts is None: @@ -1378,7 +1369,7 @@ class PatchTensorFlowEager(object): plugin_type = plugin_type[next(i for i, c in enumerate(plugin_type) if c >= 'A'):] if plugin_type.startswith('scalars'): event_writer._add_scalar(tag=str(tag), - step=int(step.numpy()) if not isinstance(step, int) else step, + step=tweak_step(step), scalar_data=tensor.numpy()) elif plugin_type.startswith('images'): img_data_np = tensor.numpy() @@ -1386,19 +1377,19 @@ class PatchTensorFlowEager(object): tag=tag, step=step, **kwargs) elif plugin_type.startswith('histograms'): event_writer._add_histogram( - tag=str(tag), step=int(step.numpy()) if not isinstance(step, int) else step, + tag=str(tag), step=tweak_step(step), hist_data=tensor.numpy() ) elif plugin_type.startswith('text'): event_writer._add_text( - tag=str(tag), step=int(step.numpy()) if not isinstance(step, int) else step, + tag=str(tag), step=tweak_step(step), tensor_bytes=tensor.numpy() ) elif 'audio' in plugin_type: audio_bytes_list = [a for a in tensor.numpy().flatten() if a] for i, audio_bytes in enumerate(audio_bytes_list): event_writer._add_audio(tag=str(tag) + ('/{}'.format(i) if len(audio_bytes_list) > 1 else ''), - step=int(step.numpy()) if not isinstance(step, int) else step, + step=tweak_step(step), values=None, audio_data=audio_bytes) else: pass @@ -1416,7 +1407,7 @@ class PatchTensorFlowEager(object): if event_writer and isinstance(step, int) or hasattr(step, 'numpy'): try: event_writer._add_scalar(tag=str(tag), - step=int(step.numpy()) if not isinstance(step, int) else step, + step=tweak_step(step), scalar_data=value.numpy()) except Exception as ex: LoggerRoot.get_base_logger(TensorflowBinding).warning(str(ex)) @@ -1428,7 +1419,7 @@ class PatchTensorFlowEager(object): str_tag = str_tag.decode() if isinstance(str_tag, bytes) else str(str_tag) event_writer._add_scalar( tag=str_tag, - step=int(a_step.numpy()) if not isinstance(a_step, int) else a_step, + step=tweak_step(step), scalar_data=a_value.numpy()) except Exception as a_ex: LoggerRoot.get_base_logger(TensorflowBinding).warning( @@ -1458,7 +1449,7 @@ class PatchTensorFlowEager(object): if event_writer and isinstance(step, int) or hasattr(step, 'numpy'): try: event_writer._add_histogram( - tag=str(tag), step=int(step.numpy()) if not isinstance(step, int) else step, + tag=str(tag), step=tweak_step(step), hist_data=values.numpy() ) except Exception as ex: @@ -1471,7 +1462,7 @@ class PatchTensorFlowEager(object): str_tag = str_tag.decode() if isinstance(str_tag, bytes) else str(str_tag) event_writer._add_histogram( tag=str_tag, - step=int(a_step.numpy()) if not isinstance(a_step, int) else a_step, + step=tweak_step(a_step), hist_data=a_value.numpy() ) except Exception as a_ex: @@ -1549,11 +1540,11 @@ class PatchTensorFlowEager(object): 'colorspace': 'RGB', 'encodedImageString': img_data_np[i]} image_tag = str(tag) + '/sample_{}'.format(i - 2) if img_data_np.size > 3 else str(tag) event_writer._add_image(tag=image_tag, - step=int(step.numpy()) if not isinstance(step, int) else step, + step=tweak_step(step), img_data=img_data) else: event_writer._add_image_numpy(tag=str(tag), - step=int(step.numpy()) if not isinstance(step, int) else step, + step=tweak_step(step), img_data_np=img_data_np, max_keep_images=kwargs.get('max_images')) @@ -2299,3 +2290,15 @@ class PatchTensorflow2ModelIO(object): except Exception: pass return model + + +def tweak_step(step): + # noinspection PyBroadException + try: + step = int(step.numpy()) if not isinstance(step, int) else step + # unlike other frameworks, tensorflow already accounts for the iteration number + # when continuing the training. we substract the smallest iteration such that we + # don't increment the step twice number + return step - EventTrainsWriter._current_task.get_initial_iteration() + except Exception: + return step