mirror of
https://github.com/clearml/clearml
synced 2025-01-31 09:07:00 +00:00
164 lines
4.7 KiB
Python
164 lines
4.7 KiB
Python
import six
|
|
import attr
|
|
from attr import validators
|
|
|
|
|
|
# Public API of this module; helpers such as _canonize_validator and _AttrsMeta
# are not exported.
__all__ = ['range_validator', 'param', 'percent_param', 'TaskParameters']
|
|
|
|
|
|
def _canonize_validator(current_validator):
    """
    Convert current_validator to a new list and return it.

    If current_validator is None (or otherwise falsy, e.g. an empty list), return an empty list.
    If current_validator is a list or tuple, return a list copy of it.
    If current_validator is a single validator (any callable), return a one-item list containing it.
    If current_validator is another kind of iterable (set, generator, ...), return a list of its items.
    Anything else (including strings) is wrapped in a one-item list.

    :param current_validator: None, a single validator, or a collection of validators.
    :return: A new list that the caller may freely append to.
    """
    if not current_validator:
        return []

    if isinstance(current_validator, (list, tuple)):
        # Always copy, so later appends never touch the caller's container.
        return list(current_validator)

    if callable(current_validator) or isinstance(current_validator, (str, bytes)):
        # Validators are callables: check this before iterability so a validator
        # object that also happens to be iterable is not unpacked. Strings are
        # never validator collections, so keep them as a single value.
        return [current_validator]

    try:
        items = iter(current_validator)
    except TypeError:
        # Not iterable: treat it as a single value.
        return [current_validator]
    return list(items)
|
|
|
|
|
|
def range_validator(min_value, max_value):
    """
    Build an attrs-style validator that enforces an inclusive range.

    :param min_value: Lower bound (inclusive), or None to leave the range open below.
    :param max_value: Upper bound (inclusive), or None to leave the range open above.
    :return: A validator callable ``(instance, attribute, value)`` that raises
        ValueError when ``value`` falls outside ``[min_value, max_value]``.
    """

    def _range_validator(instance, attribute, value):
        # The lower bound is tested first; the upper bound is only compared
        # when the value is not already below the minimum.
        if min_value is not None and value < min_value:
            out_of_bounds = True
        else:
            out_of_bounds = max_value is not None and value > max_value

        if out_of_bounds:
            message = "{} must be in range [{}, {}]".format(attribute.name, min_value, max_value)
            raise ValueError(message)

    return _range_validator
|
|
|
|
|
|
def param(
    validator=None,
    range=None,
    type=None,
    desc=None,
    metadata=None,
    *args,
    **kwargs
):
    """
    A parameter inside a TaskParameters class.

    See TaskParameters for more information.

    :param validator: A validator or validators list.
        Any validator from attr.validators is applicable.

    :param range: The legal values range of the parameter.
        A tuple (min_limit, max_limit), both inclusive; either limit may be None.
        None for no limitation. A value of None always passes the range check.

    :param type: The type of the parameter.
        Supported types are int, str and float. None to place no limit on the type.
        A value of None always passes the type check.

    :param desc: A string description of the parameter, for future use.
        Stored in the parameter metadata under the "desc" key.

    :param metadata: A dictionary metadata of the parameter, for future use.
        The dictionary is copied; the caller's dictionary is not modified.

    :param args: Additional arguments to pass to attr.attrib constructor.
    :param kwargs: Additional keyword arguments to pass to attr.attrib constructor.

    :return: An attr.attrib instance to use with TaskParameters class.

    Warning: Do not create an immutable param using args or kwargs. It will cause
    connect method of the TaskParameters class to fail.
    """
    # Copy instead of writing into the caller's dict: a metadata dict shared
    # between several params would otherwise have its "desc" overwritten by
    # each call.
    metadata = dict(metadata or {})
    metadata["desc"] = desc

    validator = _canonize_validator(validator)

    if type:
        validator.append(validators.optional(validators.instance_of(type)))

    if range:
        # Mirror the type check above: None is accepted, instead of crashing
        # the comparison inside the range check (None < number raises
        # TypeError on Python 3).
        validator.append(validators.optional(range_validator(*range)))

    return attr.ib(validator=validator, type=type, metadata=metadata, *args, **kwargs)
|
|
|
|
|
|
def percent_param(*args, **kwargs):
    """
    A float param restricted to the inclusive range [0, 1].

    All positional and keyword arguments are forwarded to ``param``; ``range``
    and ``type`` are fixed here and must not be passed again.
    """
    unit_interval = (0, 1)
    return param(*args, range=unit_interval, type=float, **kwargs)
|
|
|
|
|
|
class _AttrsMeta(type):
    """
    Metaclass that runs ``attr.s`` over every class it creates.

    Because a metaclass is inherited, every subclass of a class that uses it
    is passed through ``attr.s`` as well, without an explicit decorator.
    """

    def __new__(mcs, name, bases, dct):
        created = super(_AttrsMeta, mcs).__new__(mcs, name, bases, dct)
        # The object returned by attr.s is what the class statement binds.
        decorated = attr.s(created)
        return decorated
|
|
|
|
|
|
@six.add_metaclass(_AttrsMeta)
class TaskParameters(object):
    """
    Base class for task parameters.

    Inherit this class to create a parameter set to connect to a task.

    Usage Example:

        class MyParams(TaskParameters):
            iterations = param(
                type=int,
                desc="Number of iterations to run",
                range=(0, 100000),
            )

            target_accuracy = percent_param(
                desc="The target accuracy of the model",
            )
    """

    def to_dict(self):
        """
        :return: A new dictionary whose keys are the parameter names and whose
            values are the corresponding values.
        """
        return attr.asdict(self)

    def update_from_dict(self, source_dict):
        """
        Update the parameters using values from a dictionary.

        Only declared parameters (the attrs fields of this class, i.e. the keys
        produced by to_dict) can be updated. Every key is checked before any
        value is assigned, so an unknown key leaves the object unchanged.

        :param source_dict: A dictionary with an entry for each parameter to
            update.
        :raises ValueError: If source_dict contains a key that is not a
            parameter of this class.
        """
        # hasattr() would also accept methods (e.g. "connect", "to_dict") and
        # dunder names such as "__class__", letting a parameters dict
        # overwrite them on the instance.
        known_keys = set(field.name for field in attr.fields(type(self)))
        for key in source_dict:
            if key not in known_keys:
                raise ValueError("Unknown key {} in {} object".format(key, type(self).__name__))

        for key, value in source_dict.items():
            setattr(self, key, value)

    def connect(self, task):
        """
        Connect to a task.

        When running locally, the task will save the parameters from self.
        When running with a worker, self will be updated according to the task's
        saved parameters.

        :param task: The task to connect to.
        :type task: .Task
        :return: Whatever ``task.connect(self)`` returns.
        """
        return task.connect(self)
|
|
|
|
|