Fix Optuna and HPBandster optimizers to ignore extra kwargs in constructor. Use OptimizerOptuna as default optimizer for hyper-parameter optimization example.

This commit is contained in:
allegroai 2020-07-30 14:56:15 +03:00
parent 093477cb35
commit 7de064aaa0
3 changed files with 35 additions and 20 deletions

View File

@ -1,19 +1,31 @@
import logging
from trains import Task
from trains.automation import DiscreteParameterRange, HyperParameterOptimizer, RandomSearch, \
UniformIntegerParameterRange
from trains.automation import (
DiscreteParameterRange, HyperParameterOptimizer, RandomSearch,
UniformIntegerParameterRange)
try:
from trains.automation.hpbandster import OptimizerBOHB
Our_SearchStrategy = OptimizerBOHB
except ValueError:
aSearchStrategy = None
if not aSearchStrategy:
try:
from trains.automation.optuna import OptimizerOptuna
aSearchStrategy = OptimizerOptuna
except ImportError as ex:
pass
if not aSearchStrategy:
try:
from trains.automation.hpbandster import OptimizerBOHB
aSearchStrategy = OptimizerBOHB
except ImportError as ex:
pass
if not aSearchStrategy:
logging.getLogger().warning(
'Apologies, it seems you do not have \'hpbandster\' installed, '
'we will be using RandomSearch strategy instead\n'
'If you like to try ' '{{BOHB}: Robust and Efficient Hyperparameter Optimization at Scale},\n'
'run: pip install hpbandster')
Our_SearchStrategy = RandomSearch
'Apologies, it seems you do not have \'optuna\' or \'hpbandster\' installed, '
'we will be using RandomSearch strategy instead')
aSearchStrategy = RandomSearch
def job_complete_callback(
@ -69,7 +81,7 @@ an_optimizer = HyperParameterOptimizer(
# this is the optimizer class (actually doing the optimization)
# Currently, we can choose from GridSearch, RandomSearch or OptimizerBOHB (Bayesian optimization Hyper-Band)
# more are coming soon...
optimizer_class=Our_SearchStrategy,
optimizer_class=aSearchStrategy,
# Select an execution queue to schedule the experiments for execution
execution_queue='1xGPU',
# Optional: Limit the execution time of a single experiment, in minutes.

View File

@ -20,8 +20,8 @@ try:
Task.add_requirements('hpbandster')
except ImportError:
raise ValueError("OptimizerBOHB requires 'hpbandster' package, it was not found\n"
"install with: pip install hpbandster")
raise ImportError("OptimizerBOHB requires 'hpbandster' package, it was not found\n"
"install with: pip install hpbandster")
class _TrainsBandsterWorker(Worker):
@ -123,7 +123,7 @@ class OptimizerBOHB(SearchStrategy, RandomSeed):
pool_period_min=2., # type: float
time_limit_per_job=None, # type: Optional[float]
local_port=9090, # type: int
**bohb_kwargs, # type: Any
**bohb_kwargs # type: Any
):
# type: (...) -> None
"""
@ -181,7 +181,9 @@ class OptimizerBOHB(SearchStrategy, RandomSeed):
max_iteration_per_job=max_iteration_per_job, total_max_jobs=total_max_jobs)
self._max_iteration_per_job = max_iteration_per_job
self._min_iteration_per_job = min_iteration_per_job
self._bohb_kwargs = bohb_kwargs or {}
verified_bohb_kwargs = ['eta', 'min_budget', 'max_budget', 'min_points_in_model', 'top_n_percent',
'num_samples', 'random_fraction', 'bandwidth_factor', 'min_bandwidth']
self._bohb_kwargs = dict((k, v) for k, v in bohb_kwargs.items() if k in verified_bohb_kwargs)
self._param_iterator = None
self._namespace = None
self._bohb = None

View File

@ -11,8 +11,8 @@ try:
import optuna
Task.add_requirements('optuna')
except ImportError:
raise ValueError("OptimizerOptuna requires 'optuna' package, it was not found\n"
"install with: pip install optuna")
raise ImportError("OptimizerOptuna requires 'optuna' package, it was not found\n"
"install with: pip install optuna")
class OptunaObjective(object):
@ -92,7 +92,7 @@ class OptimizerOptuna(SearchStrategy):
optuna_sampler=None, # type: Optional[optuna.samplers.base]
optuna_pruner=None, # type: Optional[optuna.pruners.base]
continue_previous_study=None, # type: Optional[optuna.Study]
**optuna_kwargs, # type: Any
**optuna_kwargs # type: Any
):
# type: (...) -> None
"""
@ -126,7 +126,8 @@ class OptimizerOptuna(SearchStrategy):
max_iteration_per_job=max_iteration_per_job, total_max_jobs=total_max_jobs)
self._optuna_sampler = optuna_sampler
self._optuna_pruner = optuna_pruner
self._optuna_kwargs = optuna_kwargs or {}
verified_optuna_kwargs = []
self._optuna_kwargs = dict((k, v) for k, v in optuna_kwargs.items() if k in verified_optuna_kwargs)
self._param_iterator = None
self._objective = None
self._study = continue_previous_study if continue_previous_study else None