Improve HyperParameterOptimizer

This commit is contained in:
allegroai 2020-08-27 15:05:21 +03:00
parent b25ca9b384
commit f11c6f5f27
3 changed files with 124 additions and 36 deletions

View File

@ -122,6 +122,7 @@ class OptimizerBOHB(SearchStrategy, RandomSeed):
total_max_jobs, # type: Optional[int] total_max_jobs, # type: Optional[int]
pool_period_min=2., # type: float pool_period_min=2., # type: float
time_limit_per_job=None, # type: Optional[float] time_limit_per_job=None, # type: Optional[float]
compute_time_limit=None, # type: Optional[float]
local_port=9090, # type: int local_port=9090, # type: int
**bohb_kwargs # type: Any **bohb_kwargs # type: Any
): ):
@ -163,6 +164,8 @@ class OptimizerBOHB(SearchStrategy, RandomSeed):
:param float pool_period_min: time in minutes between two consecutive pools :param float pool_period_min: time in minutes between two consecutive pools
:param float time_limit_per_job: Optional, maximum execution time per single job in minutes, :param float time_limit_per_job: Optional, maximum execution time per single job in minutes,
when time limit is exceeded job is aborted when time limit is exceeded job is aborted
:param float compute_time_limit: The maximum compute time in minutes. When the time limit is exceeded,
all jobs are aborted. (Optional)
:param int local_port: default port 9090 tcp, this is a must for the BOHB workers to communicate, even locally. :param int local_port: default port 9090 tcp, this is a must for the BOHB workers to communicate, even locally.
:param bohb_kwargs: arguments passed directly to the BOHB object :param bohb_kwargs: arguments passed directly to the BOHB object
""" """
@ -178,8 +181,8 @@ class OptimizerBOHB(SearchStrategy, RandomSeed):
base_task_id=base_task_id, hyper_parameters=hyper_parameters, objective_metric=objective_metric, base_task_id=base_task_id, hyper_parameters=hyper_parameters, objective_metric=objective_metric,
execution_queue=execution_queue, num_concurrent_workers=num_concurrent_workers, execution_queue=execution_queue, num_concurrent_workers=num_concurrent_workers,
pool_period_min=pool_period_min, time_limit_per_job=time_limit_per_job, pool_period_min=pool_period_min, time_limit_per_job=time_limit_per_job,
max_iteration_per_job=max_iteration_per_job, min_iteration_per_job=min_iteration_per_job, compute_time_limit=compute_time_limit, max_iteration_per_job=max_iteration_per_job,
total_max_jobs=total_max_jobs) min_iteration_per_job=min_iteration_per_job, total_max_jobs=total_max_jobs)
self._max_iteration_per_job = max_iteration_per_job self._max_iteration_per_job = max_iteration_per_job
self._min_iteration_per_job = min_iteration_per_job self._min_iteration_per_job = min_iteration_per_job
verified_bohb_kwargs = ['eta', 'min_budget', 'max_budget', 'min_points_in_model', 'top_n_percent', verified_bohb_kwargs = ['eta', 'min_budget', 'max_budget', 'min_points_in_model', 'top_n_percent',

View File

@ -6,10 +6,12 @@ from itertools import product
from logging import getLogger from logging import getLogger
from threading import Thread, Event from threading import Thread, Event
from time import time from time import time
from typing import Union, Any, Sequence, Optional, Mapping, Callable from typing import Dict, Set, Tuple, Union, Any, Sequence, Optional, Mapping, Callable
from .job import TrainsJob from .job import TrainsJob
from .parameters import Parameter from .parameters import Parameter
from ..logger import Logger
from ..backend_api.services import workers as workers_service, tasks as tasks_services
from ..task import Task from ..task import Task
logger = getLogger('trains.automation.optimization') logger = getLogger('trains.automation.optimization')
@ -212,14 +214,11 @@ class Budget(object):
# returned dict is Mapping[Union['jobs', 'iterations', 'compute_time'], Mapping[Union['limit', 'used'], float]] # returned dict is Mapping[Union['jobs', 'iterations', 'compute_time'], Mapping[Union['limit', 'used'], float]]
current_budget = {} current_budget = {}
jobs = self.jobs.used jobs = self.jobs.used
if jobs: current_budget['jobs'] = {'limit': self.jobs.limit, 'used': jobs if jobs else 0}
current_budget['jobs'] = {'limit': self.jobs.limit, 'used': jobs}
iterations = self.iterations.used iterations = self.iterations.used
if iterations: current_budget['iterations'] = {'limit': self.iterations.limit, 'used': iterations if iterations else 0}
current_budget['iterations'] = {'limit': self.iterations.limit, 'used': iterations}
compute_time = self.compute_time.used compute_time = self.compute_time.used
if compute_time: current_budget['compute_time'] = {'limit': self.compute_time.limit, 'used': compute_time if compute_time else 0}
current_budget['compute_time'] = {'limit': self.compute_time.limit, 'used': compute_time}
return current_budget return current_budget
@ -239,6 +238,7 @@ class SearchStrategy(object):
num_concurrent_workers, # type: int num_concurrent_workers, # type: int
pool_period_min=2., # type: float pool_period_min=2., # type: float
time_limit_per_job=None, # type: Optional[float] time_limit_per_job=None, # type: Optional[float]
compute_time_limit=None, # type: Optional[float]
min_iteration_per_job=None, # type: Optional[int] min_iteration_per_job=None, # type: Optional[int]
max_iteration_per_job=None, # type: Optional[int] max_iteration_per_job=None, # type: Optional[int]
total_max_jobs=None, # type: Optional[int] total_max_jobs=None, # type: Optional[int]
@ -256,6 +256,8 @@ class SearchStrategy(object):
:param float pool_period_min: The time between two consecutive pools (minutes). :param float pool_period_min: The time between two consecutive pools (minutes).
:param float time_limit_per_job: The maximum execution time per single job in minutes. When time limit is :param float time_limit_per_job: The maximum execution time per single job in minutes. When time limit is
exceeded, the job is aborted. (Optional) exceeded, the job is aborted. (Optional)
:param float compute_time_limit: The maximum compute time in minutes. When the time limit is exceeded, :param float compute_time_limit: The maximum compute time in minutes. When the time limit is exceeded,
all jobs are aborted. (Optional) all jobs are aborted. (Optional)
:param int min_iteration_per_job: The minimum iterations (of the Objective metric) per single job (Optional) :param int min_iteration_per_job: The minimum iterations (of the Objective metric) per single job (Optional)
:param int max_iteration_per_job: The maximum iterations (of the Objective metric) per single job. :param int max_iteration_per_job: The maximum iterations (of the Objective metric) per single job.
When maximum iterations is exceeded, the job is aborted. (Optional) When maximum iterations is exceeded, the job is aborted. (Optional)
@ -270,6 +272,7 @@ class SearchStrategy(object):
self._num_concurrent_workers = num_concurrent_workers self._num_concurrent_workers = num_concurrent_workers
self.pool_period_minutes = pool_period_min self.pool_period_minutes = pool_period_min
self.time_limit_per_job = time_limit_per_job self.time_limit_per_job = time_limit_per_job
self.compute_time_limit = compute_time_limit
self.max_iteration_per_job = max_iteration_per_job self.max_iteration_per_job = max_iteration_per_job
self.min_iteration_per_job = min_iteration_per_job self.min_iteration_per_job = min_iteration_per_job
self.total_max_jobs = total_max_jobs self.total_max_jobs = total_max_jobs
@ -283,8 +286,7 @@ class SearchStrategy(object):
self._job_project = {} self._job_project = {}
self.budget = Budget( self.budget = Budget(
jobs_limit=self.total_max_jobs, jobs_limit=self.total_max_jobs,
compute_time_limit=self.total_max_jobs * self.time_limit_per_job if compute_time_limit=self.compute_time_limit if self.compute_time_limit else None,
self.time_limit_per_job and self.total_max_jobs else None,
iterations_limit=self.total_max_jobs * self.max_iteration_per_job if iterations_limit=self.total_max_jobs * self.max_iteration_per_job if
self.max_iteration_per_job and self.total_max_jobs else None self.max_iteration_per_job and self.total_max_jobs else None
) )
@ -402,6 +404,15 @@ class SearchStrategy(object):
if elapsed > self.time_limit_per_job: if elapsed > self.time_limit_per_job:
abort_job = True abort_job = True
if self.compute_time_limit:
if not self.time_limit_per_job:
elapsed = job.elapsed() / 60.
if elapsed > 0:
self.budget.compute_time.update(job.task_id(), elapsed)
self.budget.compute_time.update(job.task_id(), job.elapsed() / 60.)
if self.budget.compute_time.used and self.compute_time_limit < self.budget.compute_time.used:
abort_job = True
if self.max_iteration_per_job: if self.max_iteration_per_job:
iterations = self._get_job_iterations(job) iterations = self._get_job_iterations(job)
if iterations > 0: if iterations > 0:
@ -640,6 +651,7 @@ class GridSearch(SearchStrategy):
num_concurrent_workers, # type: int num_concurrent_workers, # type: int
pool_period_min=2., # type: float pool_period_min=2., # type: float
time_limit_per_job=None, # type: Optional[float] time_limit_per_job=None, # type: Optional[float]
compute_time_limit=None, # type: Optional[float]
max_iteration_per_job=None, # type: Optional[int] max_iteration_per_job=None, # type: Optional[int]
total_max_jobs=None, # type: Optional[int] total_max_jobs=None, # type: Optional[int]
**_ # type: Any **_ # type: Any
@ -656,6 +668,8 @@ class GridSearch(SearchStrategy):
:param float pool_period_min: The time between two consecutive pools (minutes). :param float pool_period_min: The time between two consecutive pools (minutes).
:param float time_limit_per_job: The maximum execution time per single job in minutes. When the time limit is :param float time_limit_per_job: The maximum execution time per single job in minutes. When the time limit is
exceeded job is aborted. (Optional) exceeded job is aborted. (Optional)
:param float compute_time_limit: The maximum compute time in minutes. When the time limit is exceeded, :param float compute_time_limit: The maximum compute time in minutes. When the time limit is exceeded,
all jobs are aborted. (Optional) all jobs are aborted. (Optional)
:param int max_iteration_per_job: The maximum iterations (of the Objective metric) :param int max_iteration_per_job: The maximum iterations (of the Objective metric)
per single job, When exceeded, the job is aborted. per single job, When exceeded, the job is aborted.
:param int total_max_jobs: The total maximum jobs for the optimization process. The default is ``None``, for :param int total_max_jobs: The total maximum jobs for the optimization process. The default is ``None``, for
@ -665,7 +679,8 @@ class GridSearch(SearchStrategy):
base_task_id=base_task_id, hyper_parameters=hyper_parameters, objective_metric=objective_metric, base_task_id=base_task_id, hyper_parameters=hyper_parameters, objective_metric=objective_metric,
execution_queue=execution_queue, num_concurrent_workers=num_concurrent_workers, execution_queue=execution_queue, num_concurrent_workers=num_concurrent_workers,
pool_period_min=pool_period_min, time_limit_per_job=time_limit_per_job, pool_period_min=pool_period_min, time_limit_per_job=time_limit_per_job,
max_iteration_per_job=max_iteration_per_job, total_max_jobs=total_max_jobs, **_) compute_time_limit=compute_time_limit, max_iteration_per_job=max_iteration_per_job,
total_max_jobs=total_max_jobs, **_)
self._param_iterator = None self._param_iterator = None
def create_job(self): def create_job(self):
@ -711,6 +726,7 @@ class RandomSearch(SearchStrategy):
num_concurrent_workers, # type: int num_concurrent_workers, # type: int
pool_period_min=2., # type: float pool_period_min=2., # type: float
time_limit_per_job=None, # type: Optional[float] time_limit_per_job=None, # type: Optional[float]
compute_time_limit=None, # type: Optional[float]
max_iteration_per_job=None, # type: Optional[int] max_iteration_per_job=None, # type: Optional[int]
total_max_jobs=None, # type: Optional[int] total_max_jobs=None, # type: Optional[int]
**_ # type: Any **_ # type: Any
@ -727,6 +743,8 @@ class RandomSearch(SearchStrategy):
:param float pool_period_min: The time between two consecutive pools (minutes). :param float pool_period_min: The time between two consecutive pools (minutes).
:param float time_limit_per_job: The maximum execution time per single job in minutes, :param float time_limit_per_job: The maximum execution time per single job in minutes,
when time limit is exceeded job is aborted. (Optional) when time limit is exceeded job is aborted. (Optional)
:param float compute_time_limit: The maximum compute time in minutes. When the time limit is exceeded, :param float compute_time_limit: The maximum compute time in minutes. When the time limit is exceeded,
all jobs are aborted. (Optional) all jobs are aborted. (Optional)
:param int max_iteration_per_job: The maximum iterations (of the Objective metric) :param int max_iteration_per_job: The maximum iterations (of the Objective metric)
per single job. When exceeded, the job is aborted. per single job. When exceeded, the job is aborted.
:param int total_max_jobs: The total maximum jobs for the optimization process. The default is ``None``, for :param int total_max_jobs: The total maximum jobs for the optimization process. The default is ``None``, for
@ -736,7 +754,8 @@ class RandomSearch(SearchStrategy):
base_task_id=base_task_id, hyper_parameters=hyper_parameters, objective_metric=objective_metric, base_task_id=base_task_id, hyper_parameters=hyper_parameters, objective_metric=objective_metric,
execution_queue=execution_queue, num_concurrent_workers=num_concurrent_workers, execution_queue=execution_queue, num_concurrent_workers=num_concurrent_workers,
pool_period_min=pool_period_min, time_limit_per_job=time_limit_per_job, pool_period_min=pool_period_min, time_limit_per_job=time_limit_per_job,
max_iteration_per_job=max_iteration_per_job, total_max_jobs=total_max_jobs, **_) compute_time_limit=compute_time_limit, max_iteration_per_job=max_iteration_per_job,
total_max_jobs=total_max_jobs, **_)
self._hyper_parameters_collection = set() self._hyper_parameters_collection = set()
def create_job(self): def create_job(self):
@ -787,6 +806,7 @@ class HyperParameterOptimizer(object):
max_number_of_concurrent_tasks=10, # type: int max_number_of_concurrent_tasks=10, # type: int
execution_queue='default', # type: str execution_queue='default', # type: str
optimization_time_limit=None, # type: Optional[float] optimization_time_limit=None, # type: Optional[float]
compute_time_limit=None, # type: Optional[float]
auto_connect_task=True, # type: bool auto_connect_task=True, # type: bool
always_create_task=False, # type: bool always_create_task=False, # type: bool
**optimizer_kwargs # type: Any **optimizer_kwargs # type: Any
@ -815,6 +835,8 @@ class HyperParameterOptimizer(object):
:param str execution_queue: The execution queue to use for launching Tasks (experiments). :param str execution_queue: The execution queue to use for launching Tasks (experiments).
:param float optimization_time_limit: The maximum time (minutes) for the entire optimization process. The :param float optimization_time_limit: The maximum time (minutes) for the entire optimization process. The
default is ``None``, indicating no time limit. default is ``None``, indicating no time limit.
:param float compute_time_limit: The maximum compute time in minutes. When the time limit is exceeded, :param float compute_time_limit: The maximum compute time in minutes. When the time limit is exceeded,
all jobs are aborted. (Optional) all jobs are aborted. (Optional)
:param bool auto_connect_task: Store optimization arguments and configuration in the Task? :param bool auto_connect_task: Store optimization arguments and configuration in the Task?
The values are: The values are:
@ -893,6 +915,7 @@ class HyperParameterOptimizer(object):
max_number_of_concurrent_tasks=max_number_of_concurrent_tasks, max_number_of_concurrent_tasks=max_number_of_concurrent_tasks,
execution_queue=execution_queue, execution_queue=execution_queue,
optimization_time_limit=optimization_time_limit, optimization_time_limit=optimization_time_limit,
compute_time_limit=compute_time_limit,
optimizer_kwargs=optimizer_kwargs) optimizer_kwargs=optimizer_kwargs)
# make sure all the created tasks are our children, as we are creating them # make sure all the created tasks are our children, as we are creating them
if self._task: if self._task:
@ -916,7 +939,8 @@ class HyperParameterOptimizer(object):
self.optimizer = optimizer_class( self.optimizer = optimizer_class(
base_task_id=opts['base_task_id'], hyper_parameters=hyper_parameters, base_task_id=opts['base_task_id'], hyper_parameters=hyper_parameters,
objective_metric=self.objective_metric, execution_queue=opts['execution_queue'], objective_metric=self.objective_metric, execution_queue=opts['execution_queue'],
num_concurrent_workers=opts['max_number_of_concurrent_tasks'], **opts.get('optimizer_kwargs', {})) num_concurrent_workers=opts['max_number_of_concurrent_tasks'],
compute_time_limit=opts['compute_time_limit'], **opts.get('optimizer_kwargs', {}))
self.optimizer.set_optimizer_task(self._task) self.optimizer.set_optimizer_task(self._task)
self.optimization_timeout = None self.optimization_timeout = None
self.optimization_start_time = None self.optimization_start_time = None
@ -1227,10 +1251,8 @@ class HyperParameterOptimizer(object):
def _report_daemon(self): def _report_daemon(self):
# type: () -> () # type: () -> ()
worker_to_series = {}
title, series = self.objective_metric.get_objective_metric() title, series = self.objective_metric.get_objective_metric()
title = '{}/{}'.format(title, series) title = '{}/{}'.format(title, series)
series = 'machine:'
counter = 0 counter = 0
completed_jobs = dict() completed_jobs = dict()
best_experiment = float('-inf'), None best_experiment = float('-inf'), None
@ -1256,20 +1278,6 @@ class HyperParameterOptimizer(object):
# do some reporting # do some reporting
# running objective, per machine
running_job_ids = set()
for j in self.optimizer.get_running_jobs():
worker = j.worker()
running_job_ids.add(j.task_id())
if worker not in worker_to_series:
worker_to_series[worker] = len(worker_to_series) + 1
machine_id = worker_to_series[worker]
value = self.objective_metric.get_objective(j)
if value is not None:
task_logger.report_scalar(
title=title, series='{}{}'.format(series, machine_id),
iteration=counter, value=value)
# noinspection PyBroadException # noinspection PyBroadException
try: try:
budget = self.optimizer.budget.to_dict() budget = self.optimizer.budget.to_dict()
@ -1289,8 +1297,10 @@ class HyperParameterOptimizer(object):
(self.optimization_timeout - self.optimization_start_time)), ndigits=1) (self.optimization_timeout - self.optimization_start_time)), ndigits=1)
) )
self._report_resources(task_logger, counter)
# collect a summary of all the jobs and their final objective values # collect a summary of all the jobs and their final objective values
cur_completed_jobs = set(self.optimizer.get_created_jobs_ids().keys()) - running_job_ids cur_completed_jobs = set(self.optimizer.get_created_jobs_ids().keys()) - \
{j.task_id() for j in self.optimizer.get_running_jobs()}
if cur_completed_jobs != set(completed_jobs.keys()): if cur_completed_jobs != set(completed_jobs.keys()):
pairs = [] pairs = []
labels = [] labels = []
@ -1315,6 +1325,8 @@ class HyperParameterOptimizer(object):
c = completed_jobs[job_id] c = completed_jobs[job_id]
self._experiment_completed_cb(job_id, c[0], c[1], c[2], best_experiment[1]) self._experiment_completed_cb(job_id, c[0], c[1], c[2], best_experiment[1])
self._report_completed_tasks_best_results(completed_jobs, task_logger, title, counter)
if pairs: if pairs:
print('Updating job performance summary plot/table') print('Updating job performance summary plot/table')
@ -1345,3 +1357,74 @@ class HyperParameterOptimizer(object):
# we should leave # we should leave
self.stop() self.stop()
return return
def _report_completed_tasks_best_results(self, completed_jobs, task_logger, title, counter):
# type: (Dict[str, Tuple[float, int, Dict[str, int]]], Logger, str, int) -> ()
if completed_jobs:
value_func, series_name = (max, "max") if self.objective_metric.get_objective_sign() > 0 else \
(min, "min")
task_logger.report_scalar(
title=title,
series=series_name,
iteration=counter,
value=value_func([val[0] for val in completed_jobs.values()]))
latest_completed = self._get_latest_completed_task_value(set(completed_jobs.keys()))
if latest_completed:
task_logger.report_scalar(
title=title,
series="last reported",
iteration=counter,
value=latest_completed)
def _report_resources(self, task_logger, iteration):
# type: (Logger, int) -> ()
self._report_active_workers(task_logger, iteration)
self._report_tasks_status(task_logger, iteration)
def _report_active_workers(self, task_logger, iteration):
# type: (Logger, int) -> ()
cur_task = self._task or Task.current_task()
res = cur_task.send(workers_service.GetAllRequest())
response = res.wait()
if response.ok():
all_workers = response
queue_workers = len(
[
worker.get("id")
for worker in all_workers.response_data.get("workers")
for q in worker.get("queues")
if q.get("name") == self.execution_queue
]
)
task_logger.report_scalar(title="resources", series="queue workers", iteration=iteration, value=queue_workers)
def _report_tasks_status(self, task_logger, iteration):
# type: (Logger, int) -> ()
tasks_status = {"running tasks": 0, "pending tasks": 0}
for job in self.optimizer.get_running_jobs():
if job.is_running():
tasks_status["running tasks"] += 1
else:
tasks_status["pending tasks"] += 1
for series, val in tasks_status.items():
task_logger.report_scalar(
title="resources", series=series,
iteration=iteration, value=val)
def _get_latest_completed_task_value(self, cur_completed_jobs):
# type: (Set[str]) -> float
completed_value = None
latest_completed = None
cur_task = self._task or Task.current_task()
for j in cur_completed_jobs:
res = cur_task.send(tasks_services.GetByIdRequest(task=j))
response = res.wait()
if not response.ok() or response.response_data["task"].get("status") != Task.TaskStatusEnum.completed:
continue
completed_time = datetime.strptime(response.response_data["task"]["completed"].partition("+")[0],
"%Y-%m-%dT%H:%M:%S.%f")
completed_time = completed_time.timestamp()
if not latest_completed or completed_time > latest_completed:
latest_completed = completed_time
completed_value = self.objective_metric.get_objective(j)
return completed_value

View File

@ -45,11 +45,10 @@ class OptunaObjective(object):
current_job.launch(self.queue_name) current_job.launch(self.queue_name)
iteration_value = None iteration_value = None
is_pending = True is_pending = True
while self.optimizer.monitor_job(current_job):
while not current_job.is_stopped():
if is_pending and not current_job.is_pending(): if is_pending and not current_job.is_pending():
is_pending = False is_pending = False
self.optimizer.budget.jobs.update(current_job.task_id(), 1.)
if not is_pending: if not is_pending:
# noinspection PyProtectedMember # noinspection PyProtectedMember
iteration_value = self.optimizer._objective_metric.get_current_raw_objective(current_job) iteration_value = self.optimizer._objective_metric.get_current_raw_objective(current_job)
@ -93,6 +92,7 @@ class OptimizerOptuna(SearchStrategy):
pool_period_min=2., # type: float pool_period_min=2., # type: float
min_iteration_per_job=None, # type: Optional[int] min_iteration_per_job=None, # type: Optional[int]
time_limit_per_job=None, # type: Optional[float] time_limit_per_job=None, # type: Optional[float]
compute_time_limit=None, # type: Optional[float]
optuna_sampler=None, # type: Optional[optuna.samplers.base] optuna_sampler=None, # type: Optional[optuna.samplers.base]
optuna_pruner=None, # type: Optional[optuna.pruners.base] optuna_pruner=None, # type: Optional[optuna.pruners.base]
continue_previous_study=None, # type: Optional[optuna.Study] continue_previous_study=None, # type: Optional[optuna.Study]
@ -123,14 +123,16 @@ class OptimizerOptuna(SearchStrategy):
before early stopping the Job. (Optional) before early stopping the Job. (Optional)
:param float time_limit_per_job: Optional, maximum execution time per single job in minutes, :param float time_limit_per_job: Optional, maximum execution time per single job in minutes,
when time limit is exceeded job is aborted when time limit is exceeded job is aborted
:param float compute_time_limit: The maximum compute time in minutes. When the time limit is exceeded, :param float compute_time_limit: The maximum compute time in minutes. When the time limit is exceeded,
all jobs are aborted. (Optional) all jobs are aborted. (Optional)
:param optuna_kwargs: arguments passed directly to the Optuna object :param optuna_kwargs: arguments passed directly to the Optuna object
""" """
super(OptimizerOptuna, self).__init__( super(OptimizerOptuna, self).__init__(
base_task_id=base_task_id, hyper_parameters=hyper_parameters, objective_metric=objective_metric, base_task_id=base_task_id, hyper_parameters=hyper_parameters, objective_metric=objective_metric,
execution_queue=execution_queue, num_concurrent_workers=num_concurrent_workers, execution_queue=execution_queue, num_concurrent_workers=num_concurrent_workers,
pool_period_min=pool_period_min, time_limit_per_job=time_limit_per_job, pool_period_min=pool_period_min, time_limit_per_job=time_limit_per_job,
max_iteration_per_job=max_iteration_per_job, min_iteration_per_job=min_iteration_per_job, compute_time_limit=compute_time_limit, max_iteration_per_job=max_iteration_per_job,
total_max_jobs=total_max_jobs) min_iteration_per_job=min_iteration_per_job, total_max_jobs=total_max_jobs)
self._optuna_sampler = optuna_sampler self._optuna_sampler = optuna_sampler
self._optuna_pruner = optuna_pruner self._optuna_pruner = optuna_pruner
verified_optuna_kwargs = [] verified_optuna_kwargs = []