Fix #1054, now retring a pipeline step will continue from the correct tf epoch

This commit is contained in:
Alex Burlacu 2023-07-14 14:04:38 +03:00
parent 44faf6ef7b
commit e7edcbb813

View File

@ -724,16 +724,7 @@ class EventTrainsWriter(object):
'Received event without step, assuming step = {}'.format(step)) 'Received event without step, assuming step = {}'.format(step))
else: else:
step = int(step) step = int(step)
# unlike other frameworks, tensorflow already accounts for the iteration number step = tweak_step(step)
# when continuing the training. we substract the smallest iteration such that we
# don't increment the step twice number
original_step = step
if EventTrainsWriter._current_task:
step -= EventTrainsWriter._current_task.get_initial_iteration()
# there can be a few metrics getting reported again, so the step can be negative
# for the first few reports
if step < 0 and original_step > 0:
step = 0
self._max_step = max(self._max_step, step) self._max_step = max(self._max_step, step)
if value_dicts is None: if value_dicts is None:
@ -1378,7 +1369,7 @@ class PatchTensorFlowEager(object):
plugin_type = plugin_type[next(i for i, c in enumerate(plugin_type) if c >= 'A'):] plugin_type = plugin_type[next(i for i, c in enumerate(plugin_type) if c >= 'A'):]
if plugin_type.startswith('scalars'): if plugin_type.startswith('scalars'):
event_writer._add_scalar(tag=str(tag), event_writer._add_scalar(tag=str(tag),
step=int(step.numpy()) if not isinstance(step, int) else step, step=tweak_step(step),
scalar_data=tensor.numpy()) scalar_data=tensor.numpy())
elif plugin_type.startswith('images'): elif plugin_type.startswith('images'):
img_data_np = tensor.numpy() img_data_np = tensor.numpy()
@ -1386,19 +1377,19 @@ class PatchTensorFlowEager(object):
tag=tag, step=step, **kwargs) tag=tag, step=step, **kwargs)
elif plugin_type.startswith('histograms'): elif plugin_type.startswith('histograms'):
event_writer._add_histogram( event_writer._add_histogram(
tag=str(tag), step=int(step.numpy()) if not isinstance(step, int) else step, tag=str(tag), step=tweak_step(step),
hist_data=tensor.numpy() hist_data=tensor.numpy()
) )
elif plugin_type.startswith('text'): elif plugin_type.startswith('text'):
event_writer._add_text( event_writer._add_text(
tag=str(tag), step=int(step.numpy()) if not isinstance(step, int) else step, tag=str(tag), step=tweak_step(step),
tensor_bytes=tensor.numpy() tensor_bytes=tensor.numpy()
) )
elif 'audio' in plugin_type: elif 'audio' in plugin_type:
audio_bytes_list = [a for a in tensor.numpy().flatten() if a] audio_bytes_list = [a for a in tensor.numpy().flatten() if a]
for i, audio_bytes in enumerate(audio_bytes_list): for i, audio_bytes in enumerate(audio_bytes_list):
event_writer._add_audio(tag=str(tag) + ('/{}'.format(i) if len(audio_bytes_list) > 1 else ''), event_writer._add_audio(tag=str(tag) + ('/{}'.format(i) if len(audio_bytes_list) > 1 else ''),
step=int(step.numpy()) if not isinstance(step, int) else step, step=tweak_step(step),
values=None, audio_data=audio_bytes) values=None, audio_data=audio_bytes)
else: else:
pass pass
@ -1416,7 +1407,7 @@ class PatchTensorFlowEager(object):
if event_writer and isinstance(step, int) or hasattr(step, 'numpy'): if event_writer and isinstance(step, int) or hasattr(step, 'numpy'):
try: try:
event_writer._add_scalar(tag=str(tag), event_writer._add_scalar(tag=str(tag),
step=int(step.numpy()) if not isinstance(step, int) else step, step=tweak_step(step),
scalar_data=value.numpy()) scalar_data=value.numpy())
except Exception as ex: except Exception as ex:
LoggerRoot.get_base_logger(TensorflowBinding).warning(str(ex)) LoggerRoot.get_base_logger(TensorflowBinding).warning(str(ex))
@ -1428,7 +1419,7 @@ class PatchTensorFlowEager(object):
str_tag = str_tag.decode() if isinstance(str_tag, bytes) else str(str_tag) str_tag = str_tag.decode() if isinstance(str_tag, bytes) else str(str_tag)
event_writer._add_scalar( event_writer._add_scalar(
tag=str_tag, tag=str_tag,
step=int(a_step.numpy()) if not isinstance(a_step, int) else a_step, step=tweak_step(step),
scalar_data=a_value.numpy()) scalar_data=a_value.numpy())
except Exception as a_ex: except Exception as a_ex:
LoggerRoot.get_base_logger(TensorflowBinding).warning( LoggerRoot.get_base_logger(TensorflowBinding).warning(
@ -1458,7 +1449,7 @@ class PatchTensorFlowEager(object):
if event_writer and isinstance(step, int) or hasattr(step, 'numpy'): if event_writer and isinstance(step, int) or hasattr(step, 'numpy'):
try: try:
event_writer._add_histogram( event_writer._add_histogram(
tag=str(tag), step=int(step.numpy()) if not isinstance(step, int) else step, tag=str(tag), step=tweak_step(step),
hist_data=values.numpy() hist_data=values.numpy()
) )
except Exception as ex: except Exception as ex:
@ -1471,7 +1462,7 @@ class PatchTensorFlowEager(object):
str_tag = str_tag.decode() if isinstance(str_tag, bytes) else str(str_tag) str_tag = str_tag.decode() if isinstance(str_tag, bytes) else str(str_tag)
event_writer._add_histogram( event_writer._add_histogram(
tag=str_tag, tag=str_tag,
step=int(a_step.numpy()) if not isinstance(a_step, int) else a_step, step=tweak_step(a_step),
hist_data=a_value.numpy() hist_data=a_value.numpy()
) )
except Exception as a_ex: except Exception as a_ex:
@ -1549,11 +1540,11 @@ class PatchTensorFlowEager(object):
'colorspace': 'RGB', 'encodedImageString': img_data_np[i]} 'colorspace': 'RGB', 'encodedImageString': img_data_np[i]}
image_tag = str(tag) + '/sample_{}'.format(i - 2) if img_data_np.size > 3 else str(tag) image_tag = str(tag) + '/sample_{}'.format(i - 2) if img_data_np.size > 3 else str(tag)
event_writer._add_image(tag=image_tag, event_writer._add_image(tag=image_tag,
step=int(step.numpy()) if not isinstance(step, int) else step, step=tweak_step(step),
img_data=img_data) img_data=img_data)
else: else:
event_writer._add_image_numpy(tag=str(tag), event_writer._add_image_numpy(tag=str(tag),
step=int(step.numpy()) if not isinstance(step, int) else step, step=tweak_step(step),
img_data_np=img_data_np, img_data_np=img_data_np,
max_keep_images=kwargs.get('max_images')) max_keep_images=kwargs.get('max_images'))
@ -2299,3 +2290,15 @@ class PatchTensorflow2ModelIO(object):
except Exception: except Exception:
pass pass
return model return model
def tweak_step(step):
# noinspection PyBroadException
try:
step = int(step.numpy()) if not isinstance(step, int) else step
# unlike other frameworks, tensorflow already accounts for the iteration number
# when continuing the training. we substract the smallest iteration such that we
# don't increment the step twice number
return step - EventTrainsWriter._current_task.get_initial_iteration()
except Exception:
return step