Fix #1054, now retring a pipeline step will continue from the correct tf epoch

This commit is contained in:
Alex Burlacu 2023-07-14 14:04:38 +03:00
parent 44faf6ef7b
commit e7edcbb813

View File

@ -724,16 +724,7 @@ class EventTrainsWriter(object):
'Received event without step, assuming step = {}'.format(step))
else:
step = int(step)
# unlike other frameworks, tensorflow already accounts for the iteration number
# when continuing the training. we substract the smallest iteration such that we
# don't increment the step twice number
original_step = step
if EventTrainsWriter._current_task:
step -= EventTrainsWriter._current_task.get_initial_iteration()
# there can be a few metrics getting reported again, so the step can be negative
# for the first few reports
if step < 0 and original_step > 0:
step = 0
step = tweak_step(step)
self._max_step = max(self._max_step, step)
if value_dicts is None:
@ -1378,7 +1369,7 @@ class PatchTensorFlowEager(object):
plugin_type = plugin_type[next(i for i, c in enumerate(plugin_type) if c >= 'A'):]
if plugin_type.startswith('scalars'):
event_writer._add_scalar(tag=str(tag),
step=int(step.numpy()) if not isinstance(step, int) else step,
step=tweak_step(step),
scalar_data=tensor.numpy())
elif plugin_type.startswith('images'):
img_data_np = tensor.numpy()
@ -1386,19 +1377,19 @@ class PatchTensorFlowEager(object):
tag=tag, step=step, **kwargs)
elif plugin_type.startswith('histograms'):
event_writer._add_histogram(
tag=str(tag), step=int(step.numpy()) if not isinstance(step, int) else step,
tag=str(tag), step=tweak_step(step),
hist_data=tensor.numpy()
)
elif plugin_type.startswith('text'):
event_writer._add_text(
tag=str(tag), step=int(step.numpy()) if not isinstance(step, int) else step,
tag=str(tag), step=tweak_step(step),
tensor_bytes=tensor.numpy()
)
elif 'audio' in plugin_type:
audio_bytes_list = [a for a in tensor.numpy().flatten() if a]
for i, audio_bytes in enumerate(audio_bytes_list):
event_writer._add_audio(tag=str(tag) + ('/{}'.format(i) if len(audio_bytes_list) > 1 else ''),
step=int(step.numpy()) if not isinstance(step, int) else step,
step=tweak_step(step),
values=None, audio_data=audio_bytes)
else:
pass
@ -1416,7 +1407,7 @@ class PatchTensorFlowEager(object):
if event_writer and isinstance(step, int) or hasattr(step, 'numpy'):
try:
event_writer._add_scalar(tag=str(tag),
step=int(step.numpy()) if not isinstance(step, int) else step,
step=tweak_step(step),
scalar_data=value.numpy())
except Exception as ex:
LoggerRoot.get_base_logger(TensorflowBinding).warning(str(ex))
@ -1428,7 +1419,7 @@ class PatchTensorFlowEager(object):
str_tag = str_tag.decode() if isinstance(str_tag, bytes) else str(str_tag)
event_writer._add_scalar(
tag=str_tag,
step=int(a_step.numpy()) if not isinstance(a_step, int) else a_step,
step=tweak_step(step),
scalar_data=a_value.numpy())
except Exception as a_ex:
LoggerRoot.get_base_logger(TensorflowBinding).warning(
@ -1458,7 +1449,7 @@ class PatchTensorFlowEager(object):
if event_writer and isinstance(step, int) or hasattr(step, 'numpy'):
try:
event_writer._add_histogram(
tag=str(tag), step=int(step.numpy()) if not isinstance(step, int) else step,
tag=str(tag), step=tweak_step(step),
hist_data=values.numpy()
)
except Exception as ex:
@ -1471,7 +1462,7 @@ class PatchTensorFlowEager(object):
str_tag = str_tag.decode() if isinstance(str_tag, bytes) else str(str_tag)
event_writer._add_histogram(
tag=str_tag,
step=int(a_step.numpy()) if not isinstance(a_step, int) else a_step,
step=tweak_step(a_step),
hist_data=a_value.numpy()
)
except Exception as a_ex:
@ -1549,11 +1540,11 @@ class PatchTensorFlowEager(object):
'colorspace': 'RGB', 'encodedImageString': img_data_np[i]}
image_tag = str(tag) + '/sample_{}'.format(i - 2) if img_data_np.size > 3 else str(tag)
event_writer._add_image(tag=image_tag,
step=int(step.numpy()) if not isinstance(step, int) else step,
step=tweak_step(step),
img_data=img_data)
else:
event_writer._add_image_numpy(tag=str(tag),
step=int(step.numpy()) if not isinstance(step, int) else step,
step=tweak_step(step),
img_data_np=img_data_np,
max_keep_images=kwargs.get('max_images'))
@ -2299,3 +2290,15 @@ class PatchTensorflow2ModelIO(object):
except Exception:
pass
return model
def tweak_step(step):
# noinspection PyBroadException
try:
step = int(step.numpy()) if not isinstance(step, int) else step
# unlike other frameworks, tensorflow already accounts for the iteration number
# when continuing the training. we substract the smallest iteration such that we
# don't increment the step twice number
return step - EventTrainsWriter._current_task.get_initial_iteration()
except Exception:
return step