Mirror of https://github.com/clearml/clearml, synced 2025-05-02 12:02:10 +00:00
Fix the Optuna and HPBandster optimizers to ignore extra kwargs in the constructor. Use OptimizerOptuna as the default optimizer in the hyper-parameter optimization example.
This commit is contained in:
parent 093477cb35
commit 7de064aaa0
@@ -1,19 +1,31 @@
 import logging
 
 from trains import Task
-from trains.automation import DiscreteParameterRange, HyperParameterOptimizer, RandomSearch, \
-    UniformIntegerParameterRange
+from trains.automation import (
+    DiscreteParameterRange, HyperParameterOptimizer, RandomSearch,
+    UniformIntegerParameterRange)
 
-try:
-    from trains.automation.hpbandster import OptimizerBOHB
-    Our_SearchStrategy = OptimizerBOHB
-except ValueError:
-    logging.getLogger().warning(
-        'Apologies, it seems you do not have \'hpbandster\' installed, '
-        'we will be using RandomSearch strategy instead\n'
-        'If you like to try ' '{{BOHB}: Robust and Efficient Hyperparameter Optimization at Scale},\n'
-        'run: pip install hpbandster')
-    Our_SearchStrategy = RandomSearch
+aSearchStrategy = None
+
+if not aSearchStrategy:
+    try:
+        from trains.automation.optuna import OptimizerOptuna
+        aSearchStrategy = OptimizerOptuna
+    except ImportError as ex:
+        pass
+
+if not aSearchStrategy:
+    try:
+        from trains.automation.hpbandster import OptimizerBOHB
+        aSearchStrategy = OptimizerBOHB
+    except ImportError as ex:
+        pass
+
+if not aSearchStrategy:
+    logging.getLogger().warning(
+        'Apologies, it seems you do not have \'optuna\' or \'hpbandster\' installed, '
+        'we will be using RandomSearch strategy instead')
+    aSearchStrategy = RandomSearch
 
 
 def job_complete_callback(
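In the updated example the search strategy is chosen by a chain of guarded imports: OptimizerOptuna first, then OptimizerBOHB, and finally RandomSearch when neither optional backend is installed. The chain only works because the backend modules now re-raise ImportError (see the changes further down in this commit) instead of ValueError. A minimal sketch, assuming the snippet above has already run, of reporting which backend was actually selected:

import logging

# aSearchStrategy is whichever class the fallback chain above settled on:
# OptimizerOptuna, OptimizerBOHB, or RandomSearch.
logging.getLogger().info('hyper-parameter search will use: %s', aSearchStrategy.__name__)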
@@ -69,7 +81,7 @@ an_optimizer = HyperParameterOptimizer(
     # this is the optimizer class (actually doing the optimization)
     # Currently, we can choose from GridSearch, RandomSearch or OptimizerBOHB (Bayesian optimization Hyper-Band)
     # more are coming soon...
-    optimizer_class=Our_SearchStrategy,
+    optimizer_class=aSearchStrategy,
     # Select an execution queue to schedule the experiments for execution
     execution_queue='1xGPU',
     # Optional: Limit the execution time of a single experiment, in minutes.
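The only change in the example's optimizer construction is the renamed variable; everything else stays as-is. A trimmed sketch of how the selected class plugs into HyperParameterOptimizer follows; the base task id, parameter ranges, and metric names are placeholders, not the example's real values:

from trains.automation import (
    DiscreteParameterRange, HyperParameterOptimizer, UniformIntegerParameterRange)

an_optimizer = HyperParameterOptimizer(
    base_task_id='<template-task-id>',   # placeholder: task whose hyper-parameters get tuned
    hyper_parameters=[
        UniformIntegerParameterRange('layer_1', min_value=128, max_value=512, step_size=128),
        DiscreteParameterRange('batch_size', values=[64, 96, 128]),
    ],
    objective_metric_title='accuracy',   # placeholder metric title/series
    objective_metric_series='total',
    objective_metric_sign='max',
    optimizer_class=aSearchStrategy,     # the class chosen by the fallback chain above
    execution_queue='1xGPU',             # queue name shown in the diff context
)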
@@ -20,8 +20,8 @@ try:
 
     Task.add_requirements('hpbandster')
 except ImportError:
-    raise ValueError("OptimizerBOHB requires 'hpbandster' package, it was not found\n"
-                     "install with: pip install hpbandster")
+    raise ImportError("OptimizerBOHB requires 'hpbandster' package, it was not found\n"
+                      "install with: pip install hpbandster")
 
 
 class _TrainsBandsterWorker(Worker):
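Switching the re-raised exception from ValueError to ImportError matters because callers that treat hpbandster as optional guard the import with except ImportError; with a ValueError the example's fallback chain would crash instead of degrading gracefully. A small sketch of that caller-side pattern:

try:
    # re-raises ImportError if the optional 'hpbandster' package is missing
    from trains.automation.hpbandster import OptimizerBOHB
    strategy = OptimizerBOHB
except ImportError:
    strategy = None  # caller falls back to another search strategy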
@@ -123,7 +123,7 @@ class OptimizerBOHB(SearchStrategy, RandomSeed):
             pool_period_min=2.,  # type: float
             time_limit_per_job=None,  # type: Optional[float]
             local_port=9090,  # type: int
-            **bohb_kwargs,  # type: Any
+            **bohb_kwargs  # type: Any
     ):
         # type: (...) -> None
         """
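The only change here is dropping the trailing comma after **bohb_kwargs. The commit does not state the motivation, but a trailing comma after **kwargs in a def is a SyntaxError on Python 2 (and only newer Python 3 grammars accept it), so removing it presumably keeps the signature importable everywhere. A tiny illustration, with make_optimizer as a made-up name:

# Rejected by Python 2 (SyntaxError):
#   def make_optimizer(**bohb_kwargs,):
#       ...

# Portable form matching the change above:
def make_optimizer(**bohb_kwargs):
    return bohb_kwargs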
@@ -181,7 +181,9 @@ class OptimizerBOHB(SearchStrategy, RandomSeed):
             max_iteration_per_job=max_iteration_per_job, total_max_jobs=total_max_jobs)
         self._max_iteration_per_job = max_iteration_per_job
         self._min_iteration_per_job = min_iteration_per_job
-        self._bohb_kwargs = bohb_kwargs or {}
+        verified_bohb_kwargs = ['eta', 'min_budget', 'max_budget', 'min_points_in_model', 'top_n_percent',
+                                'num_samples', 'random_fraction', 'bandwidth_factor', 'min_bandwidth']
+        self._bohb_kwargs = dict((k, v) for k, v in bohb_kwargs.items() if k in verified_bohb_kwargs)
         self._param_iterator = None
         self._namespace = None
         self._bohb = None
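Instead of forwarding whatever the caller passed, the constructor now whitelists the keyword arguments it keeps for BOHB. A minimal, self-contained sketch of the same filtering idiom (the values are made up):

verified_bohb_kwargs = ['eta', 'min_budget', 'max_budget']
bohb_kwargs = {'eta': 3, 'min_budget': 1, 'not_a_bohb_option': 42}

# Keep only recognized keys; anything else is silently dropped.
filtered = dict((k, v) for k, v in bohb_kwargs.items() if k in verified_bohb_kwargs)
print(filtered)  # {'eta': 3, 'min_budget': 1}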
@@ -11,8 +11,8 @@ try:
     import optuna
     Task.add_requirements('optuna')
 except ImportError:
-    raise ValueError("OptimizerOptuna requires 'optuna' package, it was not found\n"
-                     "install with: pip install optuna")
+    raise ImportError("OptimizerOptuna requires 'optuna' package, it was not found\n"
+                      "install with: pip install optuna")
 
 
 class OptunaObjective(object):
@@ -92,7 +92,7 @@ class OptimizerOptuna(SearchStrategy):
             optuna_sampler=None,  # type: Optional[optuna.samplers.base]
             optuna_pruner=None,  # type: Optional[optuna.pruners.base]
             continue_previous_study=None,  # type: Optional[optuna.Study]
-            **optuna_kwargs,  # type: Any
+            **optuna_kwargs  # type: Any
     ):
         # type: (...) -> None
         """
@@ -126,7 +126,8 @@ class OptimizerOptuna(SearchStrategy):
             max_iteration_per_job=max_iteration_per_job, total_max_jobs=total_max_jobs)
         self._optuna_sampler = optuna_sampler
         self._optuna_pruner = optuna_pruner
-        self._optuna_kwargs = optuna_kwargs or {}
+        verified_optuna_kwargs = []
+        self._optuna_kwargs = dict((k, v) for k, v in optuna_kwargs.items() if k in verified_optuna_kwargs)
         self._param_iterator = None
         self._objective = None
         self._study = continue_previous_study if continue_previous_study else None
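Note that verified_optuna_kwargs is an empty list here, so the same filtering idiom discards every extra keyword passed to OptimizerOptuna for now, which matches the commit message's intent of ignoring extra kwargs rather than forwarding them to Optuna.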