import io
import sys

from ..config import running_remotely, get_remote_task_id


class PatchHydra(object):
    _original_run_job = None
    _current_task = None
    _last_untracked_state = {}
    _config_section = 'OmegaConf'
    _parameter_section = 'Hydra'
    _parameter_allow_full_edit = '_allow_omegaconf_edit_'

    @classmethod
    def patch_hydra(cls):
        # noinspection PyBroadException
        try:
            # only once
            if cls._original_run_job:
                return True
            # if hydra is not loaded, do not patch anything
            if not sys.modules.get('hydra'):
                return False

            from hydra._internal import hydra as internal_hydra  # noqa

            cls._original_run_job = internal_hydra.Hydra.run
            internal_hydra.Hydra.run = cls._patched_run_job
            return True
        except Exception:
            return False

    @staticmethod
    def update_current_task(task):
        # set current Task before patching
        PatchHydra._current_task = task
        if PatchHydra.patch_hydra():
            # check if we have an untracked state, store it.
            if PatchHydra._last_untracked_state.get('connect'):
                PatchHydra._current_task.connect(**PatchHydra._last_untracked_state['connect'])
            if PatchHydra._last_untracked_state.get('_set_configuration'):
                # noinspection PyProtectedMember
                PatchHydra._current_task._set_configuration(**PatchHydra._last_untracked_state['_set_configuration'])
            PatchHydra._last_untracked_state = {}
        else:
            # if patching failed set it to None
            PatchHydra._current_task = None

    @staticmethod
    def _patched_run_job(self, config_name, task_function, overrides, *args, **kwargs):
        allow_omegaconf_edit = False

        def patched_task_function(a_config, *a_args, **a_kwargs):
            from omegaconf import OmegaConf  # noqa
            if not running_remotely() or not allow_omegaconf_edit:
                PatchHydra._register_omegaconf(a_config)
            else:
                # noinspection PyProtectedMember
                omega_yaml = PatchHydra._current_task._get_configuration_text(PatchHydra._config_section)
                loaded_config = OmegaConf.load(io.StringIO(omega_yaml))
                a_config = OmegaConf.merge(a_config, loaded_config)
                PatchHydra._register_omegaconf(a_config, is_read_only=False)
            return task_function(a_config, *a_args, **a_kwargs)

        # store the config
        # noinspection PyBroadException
        try:
            if running_remotely():
                if not PatchHydra._current_task:
                    from ..task import Task
                    PatchHydra._current_task = Task.get_task(task_id=get_remote_task_id())
                # get the _parameter_allow_full_edit casted back to boolean
                connected_config = dict()
                connected_config[PatchHydra._parameter_allow_full_edit] = False
                PatchHydra._current_task.connect(connected_config, name=PatchHydra._parameter_section)
                allow_omegaconf_edit = connected_config.pop(PatchHydra._parameter_allow_full_edit, None)
                # get all the overrides
                full_parameters = PatchHydra._current_task.get_parameters(backwards_compatibility=False)
                stored_config = {k[len(PatchHydra._parameter_section)+1:]: v for k, v in full_parameters.items()
                                 if k.startswith(PatchHydra._parameter_section+'/')}
                stored_config.pop(PatchHydra._parameter_allow_full_edit, None)
                overrides = ['{}={}'.format(k, v) for k, v in stored_config.items()]
            else:
                stored_config = dict(arg.split('=', 1) for arg in overrides)
                stored_config[PatchHydra._parameter_allow_full_edit] = False
                if PatchHydra._current_task:
                    PatchHydra._current_task.connect(stored_config, name=PatchHydra._parameter_section)
                    PatchHydra._last_untracked_state.pop('connect', None)
                else:
                    PatchHydra._last_untracked_state['connect'] = dict(
                        mutable=stored_config, name=PatchHydra._parameter_section)
                # todo: remove the overrides section from the Args (we have it here)
                # PatchHydra._current_task.delete_parameter('Args/overrides')
        except Exception:
            pass
        return PatchHydra._original_run_job(self, config_name, patched_task_function, overrides, *args, **kwargs)

    @staticmethod
    def _register_omegaconf(config, is_read_only=True):
        from omegaconf import OmegaConf  # noqa

        if is_read_only:
            description = \
                'Full OmegaConf YAML configuration. ' \
                'This is a read-only section, unless \'{}/{}\' is set to True'.format(
                    PatchHydra._parameter_section, PatchHydra._parameter_allow_full_edit)
        else:
            description = 'Full OmegaConf YAML configuration overridden! ({}/{}=True)'.format(
                PatchHydra._parameter_section, PatchHydra._parameter_allow_full_edit)

        configuration = dict(
            name=PatchHydra._config_section,
            description=description,
            config_type='OmegaConf YAML',
            config_text=OmegaConf.to_yaml({k: v for k, v in config.items() if k not in ('hydra', )})
        )
        if PatchHydra._current_task:
            # noinspection PyProtectedMember
            PatchHydra._current_task._set_configuration(**configuration)
            PatchHydra._last_untracked_state.pop('_set_configuration', None)
        else:
            PatchHydra._last_untracked_state['_set_configuration'] = configuration


def __global_hydra_bind():
    # noinspection PyBroadException
    try:
        import hydra  # noqa
        PatchHydra.patch_hydra()
    except Exception:
        pass


# patch hydra
__global_hydra_bind()