From 295b33857ccf0881e1e24309b7055cbdfe6e4cfd Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Tue, 29 Sep 2020 19:28:42 +0300 Subject: [PATCH] Fix optimizer monitor --- trains/automation/optimization.py | 10 +++++----- trains/automation/optuna/optuna.py | 6 +++--- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/trains/automation/optimization.py b/trains/automation/optimization.py index 2f931dbe..15564fd4 100644 --- a/trains/automation/optimization.py +++ b/trains/automation/optimization.py @@ -1026,17 +1026,17 @@ class HyperParameterOptimizer(object): self._thread_reporter.start() return True - def stop(self, timeout=None, flush_reporter=True): + def stop(self, timeout=None, wait_for_reporter=True): # type: (Optional[float], Optional[bool]) -> () """ Stop the HyperParameterOptimizer controller and the optimization thread. :param float timeout: Wait timeout for the optimization thread to exit (minutes). The default is ``None``, indicating do not wait terminate immediately. - :param flush_reporter: Wait for reporter to flush data. + :param wait_for_reporter: Wait for reporter to flush data. """ if not self._thread or not self._stop_event or not self.optimizer: - if self._thread_reporter and flush_reporter: + if self._thread_reporter and wait_for_reporter: self._thread_reporter.join() return @@ -1054,7 +1054,7 @@ class HyperParameterOptimizer(object): # clear thread self._thread = None - if flush_reporter: + if wait_for_reporter: # wait for reporter to flush self._thread_reporter.join() @@ -1311,7 +1311,7 @@ class HyperParameterOptimizer(object): # if we should leave, stop everything now. if timeout < 0: # we should leave - self.stop(flush_reporter=False) + self.stop(wait_for_reporter=False) return if task_logger and counter: counter += 1 diff --git a/trains/automation/optuna/optuna.py b/trains/automation/optuna/optuna.py index ab6565c9..b6e8f12e 100644 --- a/trains/automation/optuna/optuna.py +++ b/trains/automation/optuna/optuna.py @@ -45,12 +45,11 @@ class OptunaObjective(object): current_job.launch(self.queue_name) iteration_value = None is_pending = True - while not current_job.is_stopped(): + while True: if is_pending and not current_job.is_pending(): is_pending = False self.optimizer.budget.jobs.update(current_job.task_id(), 1.) if not is_pending: - self.optimizer.update_budget_per_job(current_job) # noinspection PyProtectedMember iteration_value = self.optimizer._objective_metric.get_current_raw_objective(current_job) @@ -69,7 +68,8 @@ class OptunaObjective(object): if self.max_iteration_per_job and iteration_value[0] >= self.max_iteration_per_job: current_job.abort() break - + if not self.optimizer.monitor_job(current_job): + break sleep(self.sleep_interval) # noinspection PyProtectedMember