Mirror of https://github.com/clearml/clearml, synced 2025-06-16 03:19:48 +00:00
Fix #1054, now retrying a pipeline step will continue from the correct TF epoch
This commit is contained in:
parent 44faf6ef7b
commit e7edcbb813
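The change routes every step value reported by the TensorFlow bindings through a new tweak_step() helper (added at the bottom of the diff) instead of the ad-hoc int(step.numpy()) conversions. The removed comments spell out the reasoning: TensorFlow's global step already continues from the checkpointed iteration when training resumes, and the task's iterations are also offset by Task.get_initial_iteration(), so a retried pipeline step would otherwise count that offset twice. A minimal sketch of the arithmetic, with made-up numbers, assuming the reporter adds the initial-iteration offset on top of whatever step it is handed (which is how the comment reads):

# Illustrative values only: pretend the retried step resumes from iteration 1000
# and TensorFlow's checkpoint-continued global step is currently 1005.
initial_iteration = 1000   # stand-in for Task.get_initial_iteration()
tf_global_step = 1005      # already includes the 1000 resumed iterations

# If the task offset is applied on top of the raw TF step (the double count
# the comment warns about), the scalar would show up at 2005.
double_counted = tf_global_step + initial_iteration

# Subtracting the offset first, as tweak_step() does, lets the task's own
# offset restore the true step, so the scalar shows at 1005 again.
reported = (tf_global_step - initial_iteration) + initial_iteration

assert double_counted == 2005 and reported == 1005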
@@ -724,16 +724,7 @@ class EventTrainsWriter(object):
                     'Received event without step, assuming step = {}'.format(step))
             else:
                 step = int(step)
-                # unlike other frameworks, tensorflow already accounts for the iteration number
-                # when continuing the training. we substract the smallest iteration such that we
-                # don't increment the step twice number
-                original_step = step
-                if EventTrainsWriter._current_task:
-                    step -= EventTrainsWriter._current_task.get_initial_iteration()
-                    # there can be a few metrics getting reported again, so the step can be negative
-                    # for the first few reports
-                    if step < 0 and original_step > 0:
-                        step = 0
+                step = tweak_step(step)
 
             self._max_step = max(self._max_step, step)
             if value_dicts is None:
@@ -1378,7 +1369,7 @@ class PatchTensorFlowEager(object):
                 plugin_type = plugin_type[next(i for i, c in enumerate(plugin_type) if c >= 'A'):]
                 if plugin_type.startswith('scalars'):
                     event_writer._add_scalar(tag=str(tag),
-                                             step=int(step.numpy()) if not isinstance(step, int) else step,
+                                             step=tweak_step(step),
                                              scalar_data=tensor.numpy())
                 elif plugin_type.startswith('images'):
                     img_data_np = tensor.numpy()
@@ -1386,19 +1377,19 @@ class PatchTensorFlowEager(object):
                                                 tag=tag, step=step, **kwargs)
                 elif plugin_type.startswith('histograms'):
                     event_writer._add_histogram(
-                        tag=str(tag), step=int(step.numpy()) if not isinstance(step, int) else step,
+                        tag=str(tag), step=tweak_step(step),
                         hist_data=tensor.numpy()
                     )
                 elif plugin_type.startswith('text'):
                     event_writer._add_text(
-                        tag=str(tag), step=int(step.numpy()) if not isinstance(step, int) else step,
+                        tag=str(tag), step=tweak_step(step),
                         tensor_bytes=tensor.numpy()
                     )
                 elif 'audio' in plugin_type:
                     audio_bytes_list = [a for a in tensor.numpy().flatten() if a]
                     for i, audio_bytes in enumerate(audio_bytes_list):
                         event_writer._add_audio(tag=str(tag) + ('/{}'.format(i) if len(audio_bytes_list) > 1 else ''),
-                                                step=int(step.numpy()) if not isinstance(step, int) else step,
+                                                step=tweak_step(step),
                                                 values=None, audio_data=audio_bytes)
                 else:
                     pass
@@ -1416,7 +1407,7 @@ class PatchTensorFlowEager(object):
         if event_writer and isinstance(step, int) or hasattr(step, 'numpy'):
             try:
                 event_writer._add_scalar(tag=str(tag),
-                                         step=int(step.numpy()) if not isinstance(step, int) else step,
+                                         step=tweak_step(step),
                                          scalar_data=value.numpy())
             except Exception as ex:
                 LoggerRoot.get_base_logger(TensorflowBinding).warning(str(ex))
@@ -1428,7 +1419,7 @@ class PatchTensorFlowEager(object):
                     str_tag = str_tag.decode() if isinstance(str_tag, bytes) else str(str_tag)
                     event_writer._add_scalar(
                         tag=str_tag,
-                        step=int(a_step.numpy()) if not isinstance(a_step, int) else a_step,
+                        step=tweak_step(step),
                         scalar_data=a_value.numpy())
                 except Exception as a_ex:
                     LoggerRoot.get_base_logger(TensorflowBinding).warning(
@@ -1458,7 +1449,7 @@ class PatchTensorFlowEager(object):
         if event_writer and isinstance(step, int) or hasattr(step, 'numpy'):
             try:
                 event_writer._add_histogram(
-                    tag=str(tag), step=int(step.numpy()) if not isinstance(step, int) else step,
+                    tag=str(tag), step=tweak_step(step),
                     hist_data=values.numpy()
                 )
             except Exception as ex:
@@ -1471,7 +1462,7 @@ class PatchTensorFlowEager(object):
                     str_tag = str_tag.decode() if isinstance(str_tag, bytes) else str(str_tag)
                     event_writer._add_histogram(
                         tag=str_tag,
-                        step=int(a_step.numpy()) if not isinstance(a_step, int) else a_step,
+                        step=tweak_step(a_step),
                         hist_data=a_value.numpy()
                     )
                 except Exception as a_ex:
@@ -1549,11 +1540,11 @@ class PatchTensorFlowEager(object):
                                 'colorspace': 'RGB', 'encodedImageString': img_data_np[i]}
                     image_tag = str(tag) + '/sample_{}'.format(i - 2) if img_data_np.size > 3 else str(tag)
                     event_writer._add_image(tag=image_tag,
-                                            step=int(step.numpy()) if not isinstance(step, int) else step,
+                                            step=tweak_step(step),
                                             img_data=img_data)
             else:
                 event_writer._add_image_numpy(tag=str(tag),
-                                              step=int(step.numpy()) if not isinstance(step, int) else step,
+                                              step=tweak_step(step),
                                               img_data_np=img_data_np,
                                               max_keep_images=kwargs.get('max_images'))
 
@@ -2299,3 +2290,15 @@ class PatchTensorflow2ModelIO(object):
         except Exception:
             pass
         return model
+
+
+def tweak_step(step):
+    # noinspection PyBroadException
+    try:
+        step = int(step.numpy()) if not isinstance(step, int) else step
+        # unlike other frameworks, tensorflow already accounts for the iteration number
+        # when continuing the training. we substract the smallest iteration such that we
+        # don't increment the step twice number
+        return step - EventTrainsWriter._current_task.get_initial_iteration()
+    except Exception:
+        return step
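Read on its own, the new helper is deliberately defensive: non-int steps are converted through step.numpy(), and any failure (no current task, a step that is neither an int nor a tensor, an offline run) falls back to returning the value untouched. A stand-alone approximation with a hypothetical stub in place of the real task object, for illustration only:

# Hypothetical stand-ins; only the arithmetic mirrors the helper above.
class _StubTask(object):
    def get_initial_iteration(self):
        return 1000  # pretend the pipeline step was retried from iteration 1000


class _StubWriter(object):
    _current_task = _StubTask()


def tweak_step_sketch(step):
    # noinspection PyBroadException
    try:
        step = int(step.numpy()) if not isinstance(step, int) else step
        return step - _StubWriter._current_task.get_initial_iteration()
    except Exception:
        return step


print(tweak_step_sketch(1005))   # -> 5; the task's own offset brings it back to 1005
print(tweak_step_sketch('n/a'))  # -> 'n/a'; any error falls back to the raw value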