mirror of
https://github.com/clearml/clearml
synced 2025-04-27 01:39:17 +00:00
Fix #1054, now retring a pipeline step will continue from the correct tf epoch
This commit is contained in:
parent
44faf6ef7b
commit
e7edcbb813
@ -724,16 +724,7 @@ class EventTrainsWriter(object):
|
||||
'Received event without step, assuming step = {}'.format(step))
|
||||
else:
|
||||
step = int(step)
|
||||
# unlike other frameworks, tensorflow already accounts for the iteration number
|
||||
# when continuing the training. we substract the smallest iteration such that we
|
||||
# don't increment the step twice number
|
||||
original_step = step
|
||||
if EventTrainsWriter._current_task:
|
||||
step -= EventTrainsWriter._current_task.get_initial_iteration()
|
||||
# there can be a few metrics getting reported again, so the step can be negative
|
||||
# for the first few reports
|
||||
if step < 0 and original_step > 0:
|
||||
step = 0
|
||||
step = tweak_step(step)
|
||||
|
||||
self._max_step = max(self._max_step, step)
|
||||
if value_dicts is None:
|
||||
@ -1378,7 +1369,7 @@ class PatchTensorFlowEager(object):
|
||||
plugin_type = plugin_type[next(i for i, c in enumerate(plugin_type) if c >= 'A'):]
|
||||
if plugin_type.startswith('scalars'):
|
||||
event_writer._add_scalar(tag=str(tag),
|
||||
step=int(step.numpy()) if not isinstance(step, int) else step,
|
||||
step=tweak_step(step),
|
||||
scalar_data=tensor.numpy())
|
||||
elif plugin_type.startswith('images'):
|
||||
img_data_np = tensor.numpy()
|
||||
@ -1386,19 +1377,19 @@ class PatchTensorFlowEager(object):
|
||||
tag=tag, step=step, **kwargs)
|
||||
elif plugin_type.startswith('histograms'):
|
||||
event_writer._add_histogram(
|
||||
tag=str(tag), step=int(step.numpy()) if not isinstance(step, int) else step,
|
||||
tag=str(tag), step=tweak_step(step),
|
||||
hist_data=tensor.numpy()
|
||||
)
|
||||
elif plugin_type.startswith('text'):
|
||||
event_writer._add_text(
|
||||
tag=str(tag), step=int(step.numpy()) if not isinstance(step, int) else step,
|
||||
tag=str(tag), step=tweak_step(step),
|
||||
tensor_bytes=tensor.numpy()
|
||||
)
|
||||
elif 'audio' in plugin_type:
|
||||
audio_bytes_list = [a for a in tensor.numpy().flatten() if a]
|
||||
for i, audio_bytes in enumerate(audio_bytes_list):
|
||||
event_writer._add_audio(tag=str(tag) + ('/{}'.format(i) if len(audio_bytes_list) > 1 else ''),
|
||||
step=int(step.numpy()) if not isinstance(step, int) else step,
|
||||
step=tweak_step(step),
|
||||
values=None, audio_data=audio_bytes)
|
||||
else:
|
||||
pass
|
||||
@ -1416,7 +1407,7 @@ class PatchTensorFlowEager(object):
|
||||
if event_writer and isinstance(step, int) or hasattr(step, 'numpy'):
|
||||
try:
|
||||
event_writer._add_scalar(tag=str(tag),
|
||||
step=int(step.numpy()) if not isinstance(step, int) else step,
|
||||
step=tweak_step(step),
|
||||
scalar_data=value.numpy())
|
||||
except Exception as ex:
|
||||
LoggerRoot.get_base_logger(TensorflowBinding).warning(str(ex))
|
||||
@ -1428,7 +1419,7 @@ class PatchTensorFlowEager(object):
|
||||
str_tag = str_tag.decode() if isinstance(str_tag, bytes) else str(str_tag)
|
||||
event_writer._add_scalar(
|
||||
tag=str_tag,
|
||||
step=int(a_step.numpy()) if not isinstance(a_step, int) else a_step,
|
||||
step=tweak_step(step),
|
||||
scalar_data=a_value.numpy())
|
||||
except Exception as a_ex:
|
||||
LoggerRoot.get_base_logger(TensorflowBinding).warning(
|
||||
@ -1458,7 +1449,7 @@ class PatchTensorFlowEager(object):
|
||||
if event_writer and isinstance(step, int) or hasattr(step, 'numpy'):
|
||||
try:
|
||||
event_writer._add_histogram(
|
||||
tag=str(tag), step=int(step.numpy()) if not isinstance(step, int) else step,
|
||||
tag=str(tag), step=tweak_step(step),
|
||||
hist_data=values.numpy()
|
||||
)
|
||||
except Exception as ex:
|
||||
@ -1471,7 +1462,7 @@ class PatchTensorFlowEager(object):
|
||||
str_tag = str_tag.decode() if isinstance(str_tag, bytes) else str(str_tag)
|
||||
event_writer._add_histogram(
|
||||
tag=str_tag,
|
||||
step=int(a_step.numpy()) if not isinstance(a_step, int) else a_step,
|
||||
step=tweak_step(a_step),
|
||||
hist_data=a_value.numpy()
|
||||
)
|
||||
except Exception as a_ex:
|
||||
@ -1549,11 +1540,11 @@ class PatchTensorFlowEager(object):
|
||||
'colorspace': 'RGB', 'encodedImageString': img_data_np[i]}
|
||||
image_tag = str(tag) + '/sample_{}'.format(i - 2) if img_data_np.size > 3 else str(tag)
|
||||
event_writer._add_image(tag=image_tag,
|
||||
step=int(step.numpy()) if not isinstance(step, int) else step,
|
||||
step=tweak_step(step),
|
||||
img_data=img_data)
|
||||
else:
|
||||
event_writer._add_image_numpy(tag=str(tag),
|
||||
step=int(step.numpy()) if not isinstance(step, int) else step,
|
||||
step=tweak_step(step),
|
||||
img_data_np=img_data_np,
|
||||
max_keep_images=kwargs.get('max_images'))
|
||||
|
||||
@ -2299,3 +2290,15 @@ class PatchTensorflow2ModelIO(object):
|
||||
except Exception:
|
||||
pass
|
||||
return model
|
||||
|
||||
|
||||
def tweak_step(step):
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
step = int(step.numpy()) if not isinstance(step, int) else step
|
||||
# unlike other frameworks, tensorflow already accounts for the iteration number
|
||||
# when continuing the training. we substract the smallest iteration such that we
|
||||
# don't increment the step twice number
|
||||
return step - EventTrainsWriter._current_task.get_initial_iteration()
|
||||
except Exception:
|
||||
return step
|
||||
|
Loading…
Reference in New Issue
Block a user