Support setting task initial iteration for continuing previous runs

This commit is contained in:
allegroai 2020-03-12 17:40:29 +02:00
parent f3531c1af2
commit b3dff9a4eb
5 changed files with 60 additions and 6 deletions

View File

@@ -91,10 +91,12 @@ class MetricsEventAdapter(object):
def get_iteration(self):
    """Return the iteration number recorded on this event (may be None if never set)."""
    return self._iter
def update(self, task=None, **kwargs):
def update(self, task=None, iter_offset=None, **kwargs):
""" Update event properties """
if task:
self._task = task
if iter_offset is not None and self._iter is not None:
self._iter += iter_offset
def _get_base_dict(self):
""" Get a dict with the base attributes """

View File

@@ -44,9 +44,10 @@ class Metrics(InterfaceBase):
finally:
self._storage_lock.release()
def __init__(self, session, task_id, storage_uri, storage_uri_suffix='metrics', log=None):
def __init__(self, session, task_id, storage_uri, storage_uri_suffix='metrics', iteration_offset=0, log=None):
super(Metrics, self).__init__(session, log=log)
self._task_id = task_id
self._task_iteration_offset = iteration_offset
self._storage_uri = storage_uri.rstrip('/') if storage_uri else None
self._storage_key_prefix = storage_uri_suffix.strip('/') if storage_uri_suffix else None
self._file_related_event_time = None
@@ -81,6 +82,12 @@ class Metrics(InterfaceBase):
args=(events, storage_uri),
callback=partial(self._callback_wrapper, callback))
def set_iteration_offset(self, offset):
    """
    Set the iteration offset added to every event's iteration when it is sent
    (see the ``iter_offset`` argument passed to ``ev.update`` in ``_process_events``).

    :param offset: Offset added to each event's iteration number
    """
    self._task_iteration_offset = offset
def get_iteration_offset(self):
    """Return the current task iteration offset (0 unless explicitly set)."""
    return self._task_iteration_offset
def _callback_wrapper(self, callback, res):
""" A wrapper for the async callback for handling common errors """
if not res:
@@ -129,7 +136,7 @@ class Metrics(InterfaceBase):
if not hasattr(entry.stream, 'read'):
raise ValueError('Invalid file object %s' % entry.stream)
entry.url = url
ev.update(task=self._task_id, **kwargs)
ev.update(task=self._task_id, iter_offset=self._task_iteration_offset, **kwargs)
return entry
# prepare event needing file upload

View File

@@ -104,7 +104,8 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncManagerMixin):
def _report(self, ev):
ev_iteration = ev.get_iteration()
if ev_iteration is not None:
self._max_iteration = max(self._max_iteration, ev_iteration)
# we have to manually add get_iteration_offset() because event hasn't reached the Metric manager
self._max_iteration = max(self._max_iteration, ev_iteration + self._metrics.get_iteration_offset())
self._events.append(ev)
if len(self._events) >= self._flush_threshold:
self.flush()

View File

@@ -99,6 +99,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
)
self._app_server = None
self._files_server = None
self._initial_iteration_offset = 0
if not task_id:
# generate a new task
@@ -398,7 +399,8 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
session=self.session,
task_id=self.id,
storage_uri=storage_uri,
storage_uri_suffix=self._get_output_destination_suffix('metrics')
storage_uri_suffix=self._get_output_destination_suffix('metrics'),
iteration_offset=self.get_initial_iteration()
)
return self._metrics_manager
@@ -821,6 +823,30 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
self._set_task_property("comment", str(comment))
self._edit(comment=comment)
def set_initial_iteration(self, offset=0):
    """
    Set initial iteration, instead of zero. Useful when continuing training from previous checkpoints.

    :param int offset: Initial iteration (at starting point)
    :return: Newly set initial offset
    :raises ValueError: If ``offset`` is not an integer
    """
    if not isinstance(offset, int):
        raise ValueError("Initial iteration offset must be an integer")
    self._initial_iteration_offset = offset
    # Propagate to the metrics manager only if it already exists; otherwise the
    # offset is picked up when the manager is lazily constructed (it is passed
    # as ``iteration_offset`` via get_initial_iteration() at creation time).
    if self._metrics_manager:
        self._metrics_manager.set_iteration_offset(self._initial_iteration_offset)
    return self._initial_iteration_offset
def get_initial_iteration(self):
    """
    Return the initial iteration offset, default is 0.
    Useful when continuing training from previous checkpoints.

    :return int: Initial iteration offset
    """
    return self._initial_iteration_offset
def _get_default_report_storage_uri(self):
if not self._files_server:
self._files_server = Session.get_files_server_host()

View File

@@ -78,7 +78,7 @@ class Task(_Task):
NotSet = object()
__create_protection = object()
__main_task = None
__main_task = None # type: Task
__exit_hook = None
__forked_proc_main_pid = None
__task_id_reuse_time_window_in_hours = float(config.get('development.task_reuse_time_window_in_hours', 24.0))
@@ -887,6 +887,24 @@ class Task(_Task):
self.data.last_iteration = int(last_iteration)
self._edit(last_iteration=self.data.last_iteration)
def set_initial_iteration(self, offset=0):
    """
    Set initial iteration, instead of zero. Useful when continuing training from previous checkpoints.

    :param int offset: Initial iteration (at starting point)
    :return: Newly set initial offset
    """
    # Thin public wrapper: delegates to the backend-interface base class,
    # which validates the offset and updates the metrics manager.
    return super(Task, self).set_initial_iteration(offset=offset)
def get_initial_iteration(self):
    """
    Return the initial iteration offset, default is 0.
    Useful when continuing training from previous checkpoints.

    :return int: Initial iteration offset
    """
    # Thin public wrapper: delegates to the backend-interface base class.
    return super(Task, self).get_initial_iteration()
def get_last_scalar_metrics(self):
"""
Extract the last scalar metrics, ordered by title & series in a nested dictionary