clearml/trains/task_parameters.py
2020-06-11 21:15:40 +03:00

162 lines
4.7 KiB
Python

import six
import attr
from attr import validators
# Public names exported by a star-import of this module.
__all__ = ['range_validator', 'param', 'percent_param', 'TaskParameters']
def _canonize_validator(current_validator):
    """
    Normalize a validator specification into a freshly created list.

    :param current_validator: None (or any other falsy value), a single
        validator, or a list/tuple of validators.
    :return: An empty list for a falsy input, a new list holding the items of
        a list/tuple input, or a one-element list wrapping any other value
        (other iterable types are wrapped, not unpacked).
        The result is always a new list, so appending to it does not modify
        the object that was passed in.
    """
    if not current_validator:
        return []

    if isinstance(current_validator, (list, tuple)):
        return list(current_validator)

    # Anything else is treated as a single validator.
    return [current_validator]
def range_validator(min_value, max_value):
    """
    A parameter validator that checks range constraint on a parameter.

    :param min_value: The minimum limit of the range, inclusive. None for no minimum limit.
    :param max_value: The maximum limit of the range, inclusive. None for no maximum limit.
    :return: A new range validator: a callable taking (instance, attribute, value)
        that raises ValueError when value lies outside the range. A value of
        None is accepted without being compared to the limits.
    """
    def _range_validator(instance, attribute, value):
        # param() wraps its type check in validators.optional(), so None is a
        # legal parameter value. Comparing None with a number raises TypeError
        # on Python 3, so an unset value is skipped rather than range-checked.
        if value is None:
            return

        too_small = (min_value is not None) and (value < min_value)
        if too_small or ((max_value is not None) and (value > max_value)):
            raise ValueError("{} must be in range [{}, {}]".format(attribute.name, min_value, max_value))

    return _range_validator
def param(
        validator=None,
        range=None,
        type=None,
        desc=None,
        metadata=None,
        *args,
        **kwargs
):
    """
    A parameter inside a TaskParameters class.

    See TaskParameters for more information.

    :param validator: A validator or validators list.
        Any validator from attr.validators is applicable.
    :param range: The legal values range of the parameter.
        A tuple (min_limit, max_limit). None for no limitation.
    :param type: The type of the parameter.
        Supported types are int, str and float. None to place no limit of the type.
        The generated type check also accepts None as a value.
    :param desc: A string description of the parameter, for future use.
        It is stored in the attribute metadata under the "desc" key.
    :param metadata: A dictionary metadata of the parameter, for future use.
        The dictionary is copied; the caller's object is never modified.
    :param args: Additional arguments to pass to attr.attrib constructor.
    :param kwargs: Additional keyword arguments to pass to attr.attrib constructor.
    :return: An attr.attrib instance to use with TaskParameters class.

    Warning: Do not create an immutable param using args or kwargs. It will cause
    connect method of the TaskParameters class to fail.
    """
    # Work on a copy: writing "desc" into the caller's dict would leak the
    # description into every other param() call sharing that same dict.
    metadata = dict(metadata) if metadata else {}
    metadata["desc"] = desc

    # Always a fresh list, so the appends below never touch the caller's input.
    validator = _canonize_validator(validator)

    if type:
        # optional(): None passes the type check as "no value".
        validator.append(validators.optional(validators.instance_of(type)))

    if range:
        validator.append(range_validator(*range))

    return attr.ib(validator=validator, type=type, metadata=metadata, *args, **kwargs)
def percent_param(*args, **kwargs):
    """
    Create a float parameter limited to the range (0, 1), limits inclusive.

    :param args: Positional arguments forwarded to param().
    :param kwargs: Keyword arguments forwarded to param(). The range and type
        arguments are already supplied here.
    :return: The result of param() with range=(0, 1) and type=float.
    """
    percent_range = (0, 1)
    return param(range=percent_range, type=float, *args, **kwargs)
class _AttrsMeta(type):
    """
    Metaclass that passes every class it creates through attr.s.
    """

    def __new__(mcs, name, bases, dct):
        # Build the class the regular way, then hand back whatever attr.s
        # returns for it.
        return attr.s(super(_AttrsMeta, mcs).__new__(mcs, name, bases, dct))
@six.add_metaclass(_AttrsMeta)
class TaskParameters(object):
    """
    Base class for task parameters.

    Inherit this class to create a parameter set to connect to a task.

    Usage Example:

        class MyParams(TaskParameters):
            iterations = param(
                type=int,
                desc="Number of iterations to run",
                range=(0, 100000),
            )
            target_accuracy = percent_param(
                desc="The target accuracy of the model",
            )
    """

    def to_dict(self):
        """
        :return: A new dictionary with keys are the parameters names and values
            are the corresponding values.
        """
        return attr.asdict(self)

    def update_from_dict(self, source_dict):
        """
        Update the parameters using values from a dictionary.

        All keys are checked before any value is assigned, so an unknown key
        leaves the object unchanged instead of partially updated.

        :param source_dict: A dictionary with an entry for each parameter to
            update.
        :raises ValueError: If a key does not name an existing attribute of
            this object.
        """
        items = list(source_dict.items())

        # First pass: reject unknown keys before touching any attribute.
        for key, _ in items:
            if not hasattr(self, key):
                raise ValueError("Unknown key {} in {} object".format(key, type(self).__name__))

        # Second pass: every key is known, apply the new values.
        # NOTE(review): plain setattr may not re-run the validators declared
        # through param(); confirm whether attr.validate(self) should follow.
        for key, value in items:
            setattr(self, key, value)

    def connect(self, task):
        """
        Connect to a task.

        When running locally, the task will save the parameters from self.
        When running with a worker, self will be updated according to the task's
        saved parameters.

        :param task: The task to connect to.
        :type task: .Task
        :return: The value returned by task.connect(self).
        """
        return task.connect(self)