import six
import attr
from attr import validators


__all__ = ['range_validator', 'param', 'percent_param', 'TaskParameters']


def _canonize_validator(current_validator):
    """
    Convert current_validator to a new list and return it.

    If current_validator is None return an empty list.
    If current_validator is a list, return a copy of it.
    If current_validator is another type of  iterable, return a list version of it.
    If current_validator is a single value, return a one-list containing it.
    """

    if not current_validator:
        return []

    if isinstance(current_validator, (list, tuple)):
        current_validator = list(current_validator)
    else:
        current_validator = [current_validator]

    return current_validator


def range_validator(min_value, max_value):
    """
    A parameter validator that checks range constraint on a parameter.

    :param min_value: The minimum limit of the range, inclusive. None for no minimum limit.
    :param max_value: The maximum limit of the range, inclusive. None for no maximum limit.
    :return: A new range validator.
    """
    def _range_validator(instance, attribute, value):
        if ((min_value is not None) and (value < min_value)) or \
                ((max_value is not None) and (value > max_value)):
            raise ValueError("{} must be in range [{}, {}]".format(attribute.name, min_value, max_value))

    return _range_validator


def param(
        validator=None,
        range=None,
        type=None,
        desc=None,
        metadata=None,
        *args,
        **kwargs
):
    """
    A parameter inside a TaskParameters class.

    See TaskParameters for more information.

    :param validator: A validator or validators list.
        Any validator from attr.validators is applicable.

    :param range: The legal values range of the parameter.
        A tuple (min_limit, max_limit). None for no limitation.

    :param type: The type of the parameter.
        Supported types are int, str and float. None to place no limit of the type

    :param desc: A string description of the parameter, for future use.

    :param metadata: A dictionary metadata of the parameter, for future use.

    :param args: Additional arguments to pass to attr.attrib constructor.
    :param kwargs: Additional keyword arguments to pass to attr.attrib constructor.

    :return: An attr.attrib instance to use with TaskParameters class.

    Warning: Do not create an immutable param using args or kwargs. It will cause
    connect method of the TaskParameters class to fail.
    """

    metadata = metadata or {}
    metadata["desc"] = desc

    validator = _canonize_validator(validator)

    if type:
        validator.append(validators.optional(validators.instance_of(type)))

    if range:
        validator.append(range_validator(*range))

    return attr.ib(validator=validator, type=type, metadata=metadata, *args, **kwargs)


def percent_param(*args, **kwargs):
    """
    A param with type float and range limit (0, 1).
    """
    return param(range=(0, 1), type=float, *args, **kwargs)


class _AttrsMeta(type):
    def __new__(mcs, name, bases, dct):
        new_class = super(_AttrsMeta, mcs).__new__(mcs, name, bases, dct)
        return attr.s(new_class)


@six.add_metaclass(_AttrsMeta)
class TaskParameters(object):
    """
    Base class for task parameters.

    Inherit this class to create a parameter set to connect to a task.

    Usage Example:
    class MyParams(TaskParameters):
        iterations = param(
            type=int,
            desc="Number of iterations to run",
            range=(0, 100000),
        )

        target_accuracy = percent_param(
            desc="The target accuracy of the model",
        )
    """

    def to_dict(self):
        """
        :return: A new dictionary with keys are the parameters names and values
            are the corresponding values.
        """
        return attr.asdict(self)

    def update_from_dict(self, source_dict):
        """
        Update the parameters using values from a dictionary.

        :param source_dict: A dictionary with an entry for each parameter to
            update.
        """
        for key, value in source_dict.items():
            if not hasattr(self, key):
                raise ValueError("Unknown key {} in {} object".format(key, type(self).__name__))

            setattr(self, key, value)

    def connect(self, task):
        """
        Connect to a task.

        When running locally, the task will save the parameters from self.
        When running with a worker, self will be updated according to the task's
        saved parameters.

        :param task: The task to connect to.
        :type task: .Task
        """

        return task.connect(self)