""" absl-py FLAGS binding utility functions """
import six

from ..backend_interface.task.args import _Arguments
from ..config import running_remotely


class PatchAbsl(object):
    """Bind absl-py flags to the current task's parameters.

    Two absl entry points are monkey-patched:

    * ``absl.flags._defines.DEFINE_flag`` - each newly defined flag is recorded
      on the task (local run) or overridden from the task (remote run).
    * ``absl.flags._flagvalues.FlagValues.__call__`` - after argv parsing, all
      flags are synchronised with the task again.

    All task parameters written here live under ``_Arguments._prefix_tf_defines``.
    """

    # Original absl ``DEFINE_flag`` captured before patching (None = not patched).
    _original_DEFINE_flag = None
    # Original ``FlagValues.__call__`` captured before patching (None = not patched).
    _original_FLAGS_parse_call = None
    # Task the flags are bound to; a falsy value disables the synchronisation.
    _task = None

    @classmethod
    def update_current_task(cls, current_task):
        """Bind flags to ``current_task`` and make sure the absl hooks are installed."""
        cls._task = current_task
        cls._patch_absl()

    @classmethod
    def _patch_absl(cls):
        """Install the absl hooks (best-effort; silently skipped if absl is missing)."""
        if cls._original_DEFINE_flag:
            return
        # noinspection PyBroadException
        try:
            from absl.flags import _defines
            if six.PY2:
                # Stored on the class, so wrap to prevent it turning into an unbound method.
                cls._original_DEFINE_flag = staticmethod(_defines.DEFINE_flag)
            else:
                cls._original_DEFINE_flag = _defines.DEFINE_flag
            _defines.DEFINE_flag = cls._patched_define_flag
        except Exception:
            # there is no absl
            pass

        # Guarded on its own: if the DEFINE_flag hook failed, the early return above
        # does not fire on the next call, and re-wrapping an already patched
        # __call__ would make the saved "original" our own wrapper (infinite recursion).
        if cls._original_FLAGS_parse_call is None:
            # noinspection PyBroadException
            try:
                from absl.flags._flagvalues import FlagValues
                if six.PY2:
                    cls._original_FLAGS_parse_call = staticmethod(FlagValues.__call__)
                else:
                    cls._original_FLAGS_parse_call = FlagValues.__call__
                FlagValues.__call__ = cls._patched_FLAGS_parse_call
            except Exception:
                # there is no absl
                pass

        if cls._original_DEFINE_flag:
            # noinspection PyBroadException
            try:
                # Flags defined before the hook was installed were not intercepted,
                # so push the current FLAGS state to the task now.
                from absl import flags
                cls._update_current_flags(flags.FLAGS)
            except Exception:
                # there is no absl
                pass

    @staticmethod
    def _patched_define_flag(*args, **kwargs):
        """Replacement for ``DEFINE_flag(flag, flag_values, module_name, ...)``.

        The task parameter name is ``module_name + _prefix_sep + flag.name``
        (or just ``flag.name`` when no module name is given).

        * Local run: the flag's default value is stored on the task.
        * Remote run: the flag's value is replaced by the task's stored value.

        Task interaction is best-effort; the original ``DEFINE_flag`` is always
        called so the user's flag definition never fails because of the binding.

        :return: the original ``DEFINE_flag`` result, or ``None`` if it was never captured.
        """
        original = PatchAbsl._original_DEFINE_flag
        if not original:
            return None
        if not PatchAbsl._task:
            return original(*args, **kwargs)

        # noinspection PyBroadException
        try:
            # Accept both positional and keyword forms of ``flag`` / ``module_name``.
            flag = args[0] if args else kwargs.get('flag')
            module_name = args[2] if len(args) >= 3 else kwargs.get('module_name')
            param_name = None
            if flag:
                param_name = ((module_name + _Arguments._prefix_sep) if module_name else '') + flag.name
        except Exception:
            flag = None
            param_name = None

        if flag and param_name:
            # noinspection PyBroadException
            try:
                if running_remotely():
                    # Override the default with whatever the task has stored for this flag.
                    param_dict = PatchAbsl._task._arguments.copy_to_dict(
                        {param_name: flag.value}, prefix=_Arguments._prefix_tf_defines)
                    flag.value = param_dict.get(param_name, flag.value)
                else:
                    PatchAbsl._task.update_parameters({_Arguments._prefix_tf_defines + param_name: flag.value})
            except Exception:
                # Binding is best-effort: never prevent the flag from being defined.
                pass

        return original(*args, **kwargs)

    @staticmethod
    def _patched_FLAGS_parse_call(self, *args, **kwargs):
        """Replacement for ``FlagValues.__call__``: parse argv, then re-sync flags with the task."""
        ret = PatchAbsl._original_FLAGS_parse_call(self, *args, **kwargs)
        # noinspection PyBroadException
        try:
            PatchAbsl._update_current_flags(self)
        except Exception:
            pass
        return ret

    @classmethod
    def _update_current_flags(cls, FLAGS):
        """Synchronise every flag in ``FLAGS`` with the bound task (no-op without a task).

        * Remote run: flag values are overwritten with the values returned by the
          task's ``copy_to_dict``.
        * Local run: all current flag values are handed to ``copy_from_dict``
          (per the original note this clears previously stored parameters - depends
          on ``copy_from_dict``; verify).
        """
        if not cls._task:
            return
        # noinspection PyBroadException
        try:
            current = {name: FLAGS[name].value for name in FLAGS}
            if not running_remotely():
                cls._task._arguments.copy_from_dict(current, prefix=_Arguments._prefix_tf_defines)
                return

            overrides = cls._task._arguments.copy_to_dict(current, prefix=_Arguments._prefix_tf_defines)
            for key, value in overrides.items():
                # noinspection PyBroadException
                try:
                    # Keys come from iterating FLAGS, so an exact match is the right flag;
                    # otherwise fall back to the segment before the first separator.
                    name = key if key in FLAGS else key.split(_Arguments._prefix_sep)[0]
                    if name in FLAGS:
                        FLAGS[name].value = value
                except Exception:
                    pass
        except Exception:
            pass