diff --git a/trains/backend_interface/metrics/events.py b/trains/backend_interface/metrics/events.py
index 0f08883e..d551d84a 100644
--- a/trains/backend_interface/metrics/events.py
+++ b/trains/backend_interface/metrics/events.py
@@ -91,10 +91,18 @@ class MetricsEventAdapter(object):
     def get_iteration(self):
         return self._iter
 
-    def update(self, task=None, **kwargs):
-        """ Update event properties """
+    def update(self, task=None, iter_offset=None, **kwargs):
+        """
+        Update event properties
+
+        :param task: Task identifier stored on the event (ignored when falsy)
+        :param iter_offset: Value added to the event iteration. Skipped when it is None or when the
+            event has no iteration. It is added again on every call, so pass it once per event.
+        """
         if task:
             self._task = task
+        if iter_offset is not None and self._iter is not None:
+            self._iter += iter_offset
 
     def _get_base_dict(self):
         """ Get a dict with the base attributes """
diff --git a/trains/backend_interface/metrics/interface.py b/trains/backend_interface/metrics/interface.py
index 69f2a96a..ad665aa5 100644
--- a/trains/backend_interface/metrics/interface.py
+++ b/trains/backend_interface/metrics/interface.py
@@ -44,9 +44,10 @@ class Metrics(InterfaceBase):
         finally:
             self._storage_lock.release()
 
-    def __init__(self, session, task_id, storage_uri, storage_uri_suffix='metrics', log=None):
+    def __init__(self, session, task_id, storage_uri, storage_uri_suffix='metrics', log=None, iteration_offset=0):
         super(Metrics, self).__init__(session, log=log)
         self._task_id = task_id
+        self._task_iteration_offset = iteration_offset
         self._storage_uri = storage_uri.rstrip('/') if storage_uri else None
         self._storage_key_prefix = storage_uri_suffix.strip('/') if storage_uri_suffix else None
         self._file_related_event_time = None
@@ -81,6 +82,14 @@ class Metrics(InterfaceBase):
             args=(events, storage_uri),
             callback=partial(self._callback_wrapper, callback))
 
+    def set_iteration_offset(self, offset):
+        """ Set the offset passed as iter_offset to every event updated from now on """
+        self._task_iteration_offset = offset
+
+    def get_iteration_offset(self):
+        """ Return the iteration offset currently passed to updated events """
+        return self._task_iteration_offset
+
     def _callback_wrapper(self, callback, res):
         """ A wrapper for the async callback for handling common errors """
         if not res:
@@ -129,7 +138,7 @@ class Metrics(InterfaceBase):
                     if not hasattr(entry.stream, 'read'):
                         raise ValueError('Invalid file object %s' % entry.stream)
                     entry.url = url
-            ev.update(task=self._task_id, **kwargs)
+            ev.update(task=self._task_id, iter_offset=self._task_iteration_offset, **kwargs)
             return entry
 
         # prepare event needing file upload
diff --git a/trains/backend_interface/metrics/reporter.py b/trains/backend_interface/metrics/reporter.py
index 4095903e..b1be0bc5 100644
--- a/trains/backend_interface/metrics/reporter.py
+++ b/trains/backend_interface/metrics/reporter.py
@@ -104,7 +104,8 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
     def _report(self, ev):
         ev_iteration = ev.get_iteration()
         if ev_iteration is not None:
-            self._max_iteration = max(self._max_iteration, ev_iteration)
+            # the iteration offset is only applied later, when Metrics updates the event, so add it here manually
+            self._max_iteration = max(self._max_iteration, ev_iteration + self._metrics.get_iteration_offset())
         self._events.append(ev)
         if len(self._events) >= self._flush_threshold:
             self.flush()
diff --git a/trains/backend_interface/task/task.py b/trains/backend_interface/task/task.py
index e2d67f7b..eebcb366 100644
--- a/trains/backend_interface/task/task.py
+++ b/trains/backend_interface/task/task.py
@@ -99,6 +99,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
         )
         self._app_server = None
         self._files_server = None
+        self._initial_iteration_offset = 0
 
         if not task_id:
             # generate a new task
@@ -398,7 +399,8 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
                 session=self.session,
                 task_id=self.id,
                 storage_uri=storage_uri,
-                storage_uri_suffix=self._get_output_destination_suffix('metrics')
+                storage_uri_suffix=self._get_output_destination_suffix('metrics'),
+                iteration_offset=self.get_initial_iteration()
             )
         return self._metrics_manager
 
@@ -821,6 +823,31 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
         self._set_task_property("comment", str(comment))
         self._edit(comment=comment)
 
+    def set_initial_iteration(self, offset=0):
+        """
+        Set initial iteration, instead of zero. Useful when continuing training from previous checkpoints
+
+        :param int offset: Initial iteration (at starting point)
+        :return: newly set initial offset
+        :raise ValueError: if offset is not an int
+        """
+        if not isinstance(offset, int):
+            raise ValueError("Initial iteration offset must be an integer")
+
+        self._initial_iteration_offset = offset
+        if self._metrics_manager:
+            self._metrics_manager.set_iteration_offset(self._initial_iteration_offset)
+        return self._initial_iteration_offset
+
+    def get_initial_iteration(self):
+        """
+        Return the initial iteration offset, default is 0.
+        Useful when continuing training from previous checkpoints.
+
+        :return int: initial iteration offset
+        """
+        return self._initial_iteration_offset
+
     def _get_default_report_storage_uri(self):
         if not self._files_server:
             self._files_server = Session.get_files_server_host()
diff --git a/trains/task.py b/trains/task.py
index 2aba754b..fffead93 100644
--- a/trains/task.py
+++ b/trains/task.py
@@ -78,7 +78,7 @@ class Task(_Task):
     NotSet = object()
 
     __create_protection = object()
-    __main_task = None
+    __main_task = None  # type: Task
     __exit_hook = None
     __forked_proc_main_pid = None
     __task_id_reuse_time_window_in_hours = float(config.get('development.task_reuse_time_window_in_hours', 24.0))
@@ -887,6 +887,25 @@ class Task(_Task):
             self.data.last_iteration = int(last_iteration)
             self._edit(last_iteration=self.data.last_iteration)
 
+    def set_initial_iteration(self, offset=0):
+        """
+        Set initial iteration, instead of zero. Useful when continuing training from previous checkpoints
+
+        :param int offset: Initial iteration (at starting point)
+        :return: newly set initial offset
+        :raise ValueError: if offset is not an int
+        """
+        return super(Task, self).set_initial_iteration(offset=offset)
+
+    def get_initial_iteration(self):
+        """
+        Return the initial iteration offset, default is 0.
+        Useful when continuing training from previous checkpoints.
+
+        :return int: initial iteration offset
+        """
+        return super(Task, self).get_initial_iteration()
+
     def get_last_scalar_metrics(self):
         """
         Extract the last scalar metrics, ordered by title & series in a nested dictionary