mirror of
https://github.com/clearml/clearml
synced 2025-02-07 13:23:40 +00:00
Fix Optuna HPO parameter serializing (issue #254)
This commit is contained in:
parent
f11a36f3c3
commit
28d7527537
@ -1,5 +1,6 @@
|
||||
import hashlib
|
||||
import json
|
||||
import six
|
||||
from copy import copy
|
||||
from datetime import datetime
|
||||
from itertools import product
|
||||
@ -1229,6 +1230,18 @@ class HyperParameterOptimizer(object):
|
||||
configuration_dict = {'parameter_optimization_space': [
|
||||
Parameter.from_dict(c) for c in configuration_dict['parameter_optimization_space']]}
|
||||
|
||||
complex_optimizer_kwargs = None
|
||||
if 'optimizer_kwargs' in kwargs:
|
||||
# do not store complex optimizer kwargs:
|
||||
optimizer_kwargs = kwargs.pop('optimizer_kwargs', {})
|
||||
complex_optimizer_kwargs = {
|
||||
k: v for k, v in optimizer_kwargs.items()
|
||||
if not isinstance(v, six.string_types + six.integer_types +
|
||||
(six.text_type, float, list, tuple, dict, type(None)))}
|
||||
kwargs['optimizer_kwargs'] = {
|
||||
k: v for k, v in optimizer_kwargs.items() if k not in complex_optimizer_kwargs}
|
||||
|
||||
# skip non basic types:
|
||||
arguments = {'opt': kwargs}
|
||||
if type(optimizer_class) != type:
|
||||
logger.warning('Auto Connect optimizer_class disabled, {} is already instantiated'.format(optimizer_class))
|
||||
@ -1255,6 +1268,12 @@ class HyperParameterOptimizer(object):
|
||||
optimizer_class, original_class))
|
||||
optimizer_class = original_class
|
||||
|
||||
if complex_optimizer_kwargs:
|
||||
if 'optimizer_kwargs' not in arguments['opt']:
|
||||
arguments['opt']['optimizer_kwargs'] = complex_optimizer_kwargs
|
||||
else:
|
||||
arguments['opt']['optimizer_kwargs'].update(complex_optimizer_kwargs)
|
||||
|
||||
return optimizer_class, configuration_dict['parameter_optimization_space'], arguments['opt']
|
||||
|
||||
def _daemon(self):
|
||||
@ -1335,11 +1354,11 @@ class HyperParameterOptimizer(object):
|
||||
completed_jobs[job_id] = (
|
||||
value,
|
||||
iteration_value[0] if iteration_value else -1,
|
||||
copy(dict(**params, **{"status": id_status.get(job_id)})))
|
||||
copy(dict(**params, **{"status": id_status.get(job_id)}))) # noqa
|
||||
elif completed_jobs.get(job_id):
|
||||
completed_jobs[job_id] = (completed_jobs[job_id][0],
|
||||
completed_jobs[job_id][1],
|
||||
copy(dict(**params, **{"status": id_status.get(job_id)})))
|
||||
copy(dict(**params, **{"status": id_status.get(job_id)}))) # noqa
|
||||
pairs.append((i, completed_jobs[job_id][0]))
|
||||
labels.append(str(completed_jobs[job_id][2])[1:-1])
|
||||
else:
|
||||
@ -1350,7 +1369,7 @@ class HyperParameterOptimizer(object):
|
||||
completed_jobs[job_id] = (
|
||||
value,
|
||||
iteration_value[0] if iteration_value else -1,
|
||||
copy(dict(**params, **{"status": id_status.get(job_id)})))
|
||||
copy(dict(**params, **{"status": id_status.get(job_id)}))) # noqa
|
||||
# callback new experiment completed
|
||||
if self._experiment_completed_cb:
|
||||
normalized_value = self.objective_metric.get_normalized_objective(job_id)
|
||||
|
Loading…
Reference in New Issue
Block a user