Add local HPO cli execution (--local --args)

allegroai 2023-03-27 13:40:58 +03:00
parent fa92c75ffc
commit 4a91843559


@@ -60,7 +60,7 @@ def setup_parser(parser):
         default=[],
         help="List of parameters to search optimal value of. Each parameter must be a JSON having the following format:\n"
         "{\n"
-        ' "name": str, # Name of the paramter you want to optimize\n'
+        ' "name": str, # Name of the parameter you want to optimize\n'
         ' "type": Union["LogUniformParameterRange", "UniformParameterRange", "UniformIntegerParameterRange", "DiscreteParameterRange"],\n'
         " # other fields depending on type\n"
         "}\n"
@@ -93,7 +93,7 @@ def setup_parser(parser):
         "Each parameter must be a JSON having the following format:\n"
         "{\n"
         ' "name": str, # name of the parameter to override\n'
-        ' "value": Any # value of the paramter being overriden\n'
+        ' "value": Any # value of the parameter being overriden\n'
         "}",
     )
     parser.add_argument(
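For --params-override, each entry is just a "name" and a "value", as the help text above states. A short illustration (the parameter names and values here are made up):

import json

# Illustrative --params-override values; "epochs" and "dataset" are made-up names.
params_override = [
    json.dumps({"name": "epochs", "value": 30}),
    json.dumps({"name": "dataset", "value": "cifar10"}),
]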
@@ -138,7 +138,8 @@ def setup_parser(parser):
         help="The maximum compute time in minutes. When time limit is exceeded, all jobs aborted",
     )
     parser.add_argument(
-        "--pool-period-min", type=float, default=None, help="The time between two consecutive pools (minutes)"
+        "--pool-period-min", type=float, default=0.2,
+        help="The time between two consecutive pools (minutes) default 0.2 min"
     )
     parser.add_argument(
         "--total-max-jobs",
@@ -173,6 +174,22 @@ def setup_parser(parser):
         help="Maximum execution time per single job in minutes. When time limit is exceeded job is aborted."
         " Default: no time limit",
     )
+    parser.add_argument(
+        "--max-number-of-concurrent-tasks",
+        type=int,
+        default=None,
+        help="The maximum number of concurrent Tasks (experiments) running at the same time.",
+    )
+    parser.add_argument(
+        '--args', default=None, nargs='*',
+        help='Arguments to pass to the remote execution, list of <argument>=<value> strings.'
+             'Currently only argparse/click/hydra/fire arguments are supported. '
+             'Example: --args lr=0.003 batch_size=64')
+    parser.add_argument(
+        "--local", action='store_true', default=False,
+        help="If set, run the experiments locally, Notice no new python environment will be created, "
+             "--script must point to a local file entrypoint and all arguments must be passed with --args",
+    )


 def eval_params_search(params_search, params_override):
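The three flags added above parse as a plain integer (--max-number-of-concurrent-tasks), a list of <argument>=<value> strings (--args), and a boolean switch (--local). A minimal, self-contained sketch of how they behave with argparse; the sample command line is made up and the real CLI defines many more arguments:

import argparse

# Stand-alone mirror of just the three new flags, for illustration only.
parser = argparse.ArgumentParser()
parser.add_argument("--max-number-of-concurrent-tasks", type=int, default=None)
parser.add_argument("--args", default=None, nargs="*")
parser.add_argument("--local", action="store_true", default=False)

ns = parser.parse_args(
    ["--local", "--max-number-of-concurrent-tasks", "2",
     "--args", "lr=0.003", "batch_size=64"]
)
print(ns.local)                           # True
print(ns.max_number_of_concurrent_tasks)  # 2
print(ns.args)                            # ['lr=0.003', 'batch_size=64']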
@@ -239,6 +256,7 @@ def build_opt_kwargs(args):
         "total_max_jobs",
         "min_iteration_per_job",
         "max_iteration_per_job",
+        "max_number_of_concurrent_tasks"
     ]
     for arg_name in optional_arg_names:
         arg_val = getattr(args, arg_name)
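The hunk above only shows part of build_opt_kwargs. The visible lines suggest it copies each optional CLI value into a kwargs dict only when the user actually set it, so HyperParameterOptimizer keeps its own defaults otherwise. A hedged reconstruction (the skip-on-None behaviour, the return value, and the shortened name list are assumptions):

def build_opt_kwargs(args):
    # Collect only the optional arguments the user actually passed, so that
    # HyperParameterOptimizer falls back to its own defaults for the rest.
    # (Sketch based on the visible diff lines; the real list is longer.)
    kwargs = {}
    optional_arg_names = [
        "total_max_jobs",
        "min_iteration_per_job",
        "max_iteration_per_job",
        "max_number_of_concurrent_tasks",
    ]
    for arg_name in optional_arg_names:
        arg_val = getattr(args, arg_name)
        if arg_val is not None:
            kwargs[arg_name] = arg_val
    return kwargs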
@@ -268,24 +286,35 @@ def cli():
     )

     task_id = args.task_id
     if not task_id:
+        if args.queue is None:
+            print("No queue supplied to run the script from")
+            exit(1)
         create_populate = CreateAndPopulate(script=args.script)
+        create_populate.update_task_args(args.args)
+        print('Creating new task')
         create_populate.create_task()
-        Task.enqueue(create_populate.task, queue_name=args.queue)
+        # update Task args
+        create_populate.update_task_args(args.args)
         task_id = create_populate.get_id()

     optimizer = HyperParameterOptimizer(
         base_task_id=task_id,
-        execution_queue=args.queue,
         hyper_parameters=eval_params_search(args.params_search, args.params_override),
         objective_metric_title=args.objective_metric_title,
         objective_metric_series=args.objective_metric_series,
         objective_metric_sign=args.objective_metric_sign,
         optimizer_class=eval_optimizer_class(args.optimizer_class),
         save_top_k_tasks_only=args.save_top_k_tasks_only,
-        **build_opt_kwargs(args),
+        **build_opt_kwargs(args)
     )
-    optimizer.start()
+    # make sure we sample every 30sec
+    optimizer.set_report_period(0.5)
+    print("Starting HPO process:")
+    if args.local:
+        optimizer.start_locally()
+    else:
+        optimizer.start()
     optimizer.wait()
     print("Optimization completed!")
     top_experiments_cnt = 10
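With the new flow above, the report period is set to 0.5 minutes (the 30-second sampling mentioned in the comment) and the run either stays in-process via start_locally() when --local is passed, or is dispatched through start() for queue/agent-based execution. A condensed sketch of that decision, using HyperParameterOptimizer methods that exist in clearml.automation; the helper function itself is hypothetical, not part of the commit:

from clearml.automation import HyperParameterOptimizer


def run_optimization(optimizer, local):
    # type: (HyperParameterOptimizer, bool) -> None
    """Hypothetical helper mirroring the new start logic in cli()."""
    optimizer.set_report_period(0.5)  # sample/report every 0.5 min (30 sec)
    if local:
        # --local: run the experiments on this machine; per the flag's help
        # text, no new python environment is created
        optimizer.start_locally()
    else:
        # default: clone and enqueue experiments for clearml-agent workers
        optimizer.start()
    optimizer.wait()  # block until the optimization budget is exhausted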