Fix type-hints and docstrings

This commit is contained in:
allegroai 2020-06-01 11:00:42 +03:00
parent 43dac458df
commit 2066d9ff9d
4 changed files with 17 additions and 11 deletions

View File

@ -329,8 +329,6 @@ class OptimizerBOHB(SearchStrategy, RandomSeed):
print('Best found configuration:', id2config[incumbent]['config'])
print('A total of {} unique configurations where sampled.'.format(len(id2config.keys())))
print('A total of {} runs where executed.'.format(len(self._res.get_all_runs())))
print('Total budget corresponds to {:.1f} full function evaluations.'.format(
sum([r.budget for r in all_runs]) / self._bohb_kwargs.get('max_budget', 1.0)))
print('Total budget corresponds to {:.1f} full function evaluations.'.format(
sum([r.budget for r in all_runs]) / self._bohb_kwargs.get('max_budget', 1.0)))
print('The run took {:.1f} seconds to complete.'.format(

View File

@ -33,7 +33,7 @@ class Objective(object):
"""
def __init__(self, title, series, order='max', extremum=False):
# type: (str, str, Union['max', 'min'], bool) -> ()
# type: (str, str, str, bool) -> ()
"""
Construct objective object that will return the scalar value for a specific task ID
@ -190,7 +190,9 @@ class Budget(object):
self.compute_time = self.Field(compute_time_limit)
def to_dict(self):
# type: () -> (Mapping[Union['jobs', 'iterations', 'compute_time'], Mapping[Union['limit', 'used'], float]])
# type: () -> (Mapping[str, Mapping[str, float]])
# returned dict is Mapping[Union['jobs', 'iterations', 'compute_time'], Mapping[Union['limit', 'used'], float]]
current_budget = {}
jobs = self.jobs.used
if jobs:
@ -358,7 +360,7 @@ class SearchStrategy(object):
if return False, the job was aborted / completed, and should be taken off the current job list
If there is a budget limitation,
this call should update self.budget.time.update() / self.budget.iterations.update()
this call should update self.budget.compute_time.update() / self.budget.iterations.update()
:param TrainsJob job: a TrainsJob object to monitor
:return bool: If False, job is no longer relevant
@ -730,7 +732,7 @@ class HyperParameterOptimizer(object):
hyper_parameters, # type: Sequence[Parameter]
objective_metric_title, # type: str
objective_metric_series, # type: str
objective_metric_sign='min', # type: Union['min', 'max', 'min_global', 'max_global']
objective_metric_sign='min', # type: str
optimizer_class=RandomSearch, # type: type(SearchStrategy)
max_number_of_concurrent_tasks=10, # type: int
execution_queue='default', # type: str
@ -747,7 +749,8 @@ class HyperParameterOptimizer(object):
:param list hyper_parameters: list of Parameter objects to optimize over
:param str objective_metric_title: Objective metric title to maximize / minimize (example: 'validation')
:param str objective_metric_series: Objective metric series to maximize / minimize (example: 'loss')
:param str objective_metric_sign: Objective to maximize / minimize. Valid options:
:param str objective_metric_sign: Objective to maximize / minimize.
Valid options: ['min', 'max', 'min_global', 'max_global']
'min'/'max': Minimize/Maximize the last reported value for the specified title/series scalar
'min_global'/'max_global': Minimize/Maximize the min/max value
of *all* reported values for the specific title/series scalar

View File

@ -1054,7 +1054,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
def get_reported_scalars(
self,
max_samples=0, # type: int
x_axis='iter' # type: Union['iter', 'timestamp', 'iso_time']
x_axis='iter' # type: str
):
# type: (...) -> Mapping[str, Mapping[str, Mapping[str, Sequence[float]]]]
"""
@ -1133,7 +1133,8 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
cls._force_requirements[package_name] = package_version
def _get_models(self, model_type='output'):
# type: (Union['output', 'input']) -> Sequence[Model]
# type: (str) -> Sequence[Model]
# model_type is either 'output' or 'input'
model_type = model_type.lower().strip()
assert model_type == 'output' or model_type == 'input'

View File

@ -57,7 +57,9 @@ class WeightsFileHandler(object):
@staticmethod
def add_pre_callback(callback_function):
# type: (Callable[[Union['load', 'save'], str, str, Task], str]) -> int
# type: (Callable[[str, str, str, Task], str]) -> int
# callback is Callable[[Union['load', 'save'], str, str, Task], str]
if callback_function in WeightsFileHandler._model_pre_callbacks.values():
return [k for k, v in WeightsFileHandler._model_pre_callbacks.items() if v == callback_function][0]
@ -70,7 +72,9 @@ class WeightsFileHandler(object):
@staticmethod
def add_post_callback(callback_function):
# type: (Callable[[Union['load', 'save'], Model, str, str, str, Task], Model]) -> int
# type: (Callable[[str, Model, str, str, str, Task], Model]) -> int
# callback is Callable[[Union['load', 'save'], Model, str, str, str, Task], Model]
if callback_function in WeightsFileHandler._model_post_callbacks.values():
return [k for k, v in WeightsFileHandler._model_post_callbacks.items() if v == callback_function][0]