import logging
import sys
import threading
from time import time

from ..binding.frameworks import _patched_call  # noqa
from ..config import running_remotely, config


class StdStreamPatch(object):
    """
    Install and remove the process-wide hooks that mirror console output
    (``sys.stdout`` / ``sys.stderr``) and ``logging`` handler output into a logger object.

    All state is stored on the class itself, so there is a single patch per process.
    """
    # PrintPatchLogger proxies created by patch_std_streams() (None when the stream is not patched)
    _stdout_proxy = None
    _stderr_proxy = None
    # the ``write`` callables of sys.stdout / sys.stderr as they were before the first patch
    _stdout_original_write = None
    _stderr_original_write = None

    @staticmethod
    def patch_std_streams(a_logger, connect_stdout=True, connect_stderr=True):
        """
        Route everything written to stdout / stderr through ``a_logger``.

        Nothing is done when ``running_remotely()`` is true or when both ``connect_*`` flags are False.

        On the first call (``PrintPatchLogger.patched`` is still False) the requested streams are wrapped
        with a ``PrintPatchLogger`` and ``sys.stdout`` / ``sys.stderr`` (and ``sys.__stdout__`` /
        ``sys.__stderr__``) are replaced by the proxies. On later calls the already installed proxies
        are only re-connected to ``a_logger``.

        :param a_logger: logger object handed to the PrintPatchLogger proxies
        :param connect_stdout: if True, capture ``sys.stdout`` (reported with level ``logging.INFO``)
        :param connect_stderr: if True, capture ``sys.stderr`` (reported with level ``logging.ERROR``)
        """
        if (connect_stdout or connect_stderr) and not PrintPatchLogger.patched and not running_remotely():
            StdStreamPatch._stdout_proxy = PrintPatchLogger(sys.stdout, a_logger, level=logging.INFO) \
                if connect_stdout else None
            StdStreamPatch._stderr_proxy = PrintPatchLogger(sys.stderr, a_logger, level=logging.ERROR) \
                if connect_stderr else None

            if StdStreamPatch._stdout_proxy:
                # noinspection PyBroadException
                try:
                    # remember the very first write function, so it can always be reached later
                    if StdStreamPatch._stdout_original_write is None:
                        StdStreamPatch._stdout_original_write = sys.stdout.write

                    # this will only work in python 3, guard it with try/catch
                    # (code holding a reference to the original stream object will still reach the proxy)
                    if not hasattr(sys.stdout, '_original_write'):
                        sys.stdout._original_write = sys.stdout.write
                    sys.stdout.write = StdStreamPatch._stdout__patched__write__
                except Exception:
                    pass
                sys.stdout = StdStreamPatch._stdout_proxy

                # patch the base streams of sys (this way colorama will keep its ANSI colors)
                # noinspection PyBroadException
                try:
                    sys.__stdout__ = sys.stdout
                except Exception:
                    pass

            if StdStreamPatch._stderr_proxy:
                # noinspection PyBroadException
                try:
                    # remember the very first write function, so it can always be reached later
                    if StdStreamPatch._stderr_original_write is None:
                        StdStreamPatch._stderr_original_write = sys.stderr.write
                    if not hasattr(sys.stderr, '_original_write'):
                        sys.stderr._original_write = sys.stderr.write
                    sys.stderr.write = StdStreamPatch._stderr__patched__write__
                except Exception:
                    pass
                sys.stderr = StdStreamPatch._stderr_proxy

                # patch the base streams of sys (this way colorama will keep its ANSI colors)
                # noinspection PyBroadException
                try:
                    sys.__stderr__ = sys.stderr
                except Exception:
                    pass

            # now check if we have loguru and make it re-register the handlers
            # because it stores internally the stream.write function, which we cant patch
            # noinspection PyBroadException
            try:
                from loguru import logger  # noqa
                register_stderr = None
                register_stdout = None
                for k, v in logger._handlers.items():  # noqa
                    if connect_stderr and v._name == '<stderr>':  # noqa
                        register_stderr = k
                    elif connect_stdout and v._name == '<stdout>':  # noqa
                        # keep the stdout handler id separately from the stderr one,
                        # otherwise the stdout handler is never re-registered
                        register_stdout = k
                if register_stderr is not None:
                    logger.remove(register_stderr)
                    logger.add(sys.stderr)
                if register_stdout is not None:
                    logger.remove(register_stdout)
                    logger.add(sys.stdout)
            except Exception:
                # loguru is optional (and we rely on its private members), best effort only
                pass

        elif (connect_stdout or connect_stderr) and not running_remotely():
            # streams are already patched, just point the existing proxies to the new logger
            if StdStreamPatch._stdout_proxy and connect_stdout:
                StdStreamPatch._stdout_proxy.connect(a_logger)
            if StdStreamPatch._stderr_proxy and connect_stderr:
                StdStreamPatch._stderr_proxy.connect(a_logger)

    @staticmethod
    def patch_logging_formatter(a_logger, logging_handler=None):
        """
        Wrap ``logging_handler.format`` so every formatted record is also reported to ``a_logger``.

        :param a_logger: logger object passed to ``HandlerFormat``
        :param logging_handler: handler class (or object) to patch, defaults to ``logging.Handler``
        """
        if not logging_handler:
            import logging
            logging_handler = logging.Handler
        logging_handler.format = _patched_call(logging_handler.format, HandlerFormat(a_logger))

    @staticmethod
    def remove_patch_logging_formatter(logging_handler=None):
        """
        Undo ``patch_logging_formatter()``. Any failure is silently ignored
        (for example when the handler was never patched).

        :param logging_handler: handler class (or object) to restore, defaults to ``logging.Handler``
        """
        if not logging_handler:
            import logging
            logging_handler = logging.Handler
        # remove the function, Hack calling patched logging.Handler.format() returns the original function
        # noinspection PyBroadException
        try:
            logging_handler.format = logging_handler.format()  # noqa
        except Exception:
            pass

    @staticmethod
    def remove_std_logger(logger=None):
        """
        Disconnect ``logger`` from the stdout / stderr proxies (the streams themselves stay patched).

        :param logger: logger to disconnect; passed as-is to ``PrintPatchLogger.disconnect()``
        """
        if isinstance(sys.stdout, PrintPatchLogger):
            # noinspection PyBroadException
            try:
                sys.stdout.disconnect(logger)
            except Exception:
                pass
        if isinstance(sys.stderr, PrintPatchLogger):
            # noinspection PyBroadException
            try:
                sys.stderr.disconnect(logger)
            except Exception:
                pass

    @staticmethod
    def stdout_original_write(*args, **kwargs):
        """Write using the stdout ``write`` recorded before patching (or ``sys.stdout.write`` if none)."""
        if StdStreamPatch._stdout_original_write:
            StdStreamPatch._stdout_original_write(*args, **kwargs)
        else:
            sys.stdout.write(*args, **kwargs)

    @staticmethod
    def stderr_original_write(*args, **kwargs):
        """Write using the stderr ``write`` recorded before patching (or ``sys.stderr.write`` if none)."""
        if StdStreamPatch._stderr_original_write:
            StdStreamPatch._stderr_original_write(*args, **kwargs)
        else:
            sys.stderr.write(*args, **kwargs)

    @staticmethod
    def _stdout__patched__write__(*args, **kwargs):
        """Replacement for the original stdout object's ``write``: forward to the proxy if there is one."""
        if StdStreamPatch._stdout_proxy:
            return StdStreamPatch._stdout_proxy.write(*args, **kwargs)
        return sys.stdout._original_write(*args, **kwargs)  # noqa

    @staticmethod
    def _stderr__patched__write__(*args, **kwargs):
        """Replacement for the original stderr object's ``write``: forward to the proxy if there is one."""
        if StdStreamPatch._stderr_proxy:
            return StdStreamPatch._stderr_proxy.write(*args, **kwargs)
        return sys.stderr._original_write(*args, **kwargs)  # noqa


class HandlerFormat(object):
    """
    Callable used as the patch for a logging handler ``format`` function:
    it formats the record with the original function and reports the resulting
    text to the logger it was created with.
    """

    def __init__(self, logger):
        """
        :param logger: object exposing ``report_text(msg=..., level=..., print_console=...)``
        """
        self._logger = logger

    def __call__(self, original_format_func, *args):
        """
        :param original_format_func: the un-patched ``format`` function
        :param args: either ``(record,)`` or ``(handler, record, ...)``
            (presumably supplied by the ``_patched_call`` wrapper - verify against its implementation)
        :return: the formatted message, or ``original_format_func`` itself when
            no argument (or only ``None`` arguments) was passed
        """
        # hack: no real argument means the caller wants the original function back,
        # this is how the patch gets removed
        if not any(arg is not None for arg in args):
            return original_format_func

        if len(args) == 1:
            # called with the record only
            record = args[0]
            forwarded = (record,)
        else:
            # called with the handler first and the record second
            record = args[1]
            forwarded = (args[0], record)

        formatted = original_format_func(*forwarded)
        self._logger.report_text(msg=formatted, level=record.levelno, print_console=False)
        return formatted


class PrintPatchLogger(object):
    """
    Stream proxy that tees everything written to it into a logger.

    Every write is forwarded to the wrapped stream, and the text is also buffered
    until a full line is available, at which point the line is passed to the
    connected logger. Used for capturing and logging stdout and stderr when running
    in development mode pseudo worker.

    Any attribute that is not part of the proxy's own state is read from / written
    to the wrapped stream, so the proxy can be used in place of the original stream.
    """
    # set to True once any proxy instance was created
    patched = False
    # protects the line buffer (shared by all proxy instances)
    lock = threading.Lock()
    # held while we call the terminal / the logger, used to detect re-entrant writes
    recursion_protect_lock = threading.RLock()
    # when set (truthy), lines ending with '\r' only are buffered until force_lf_flush() is called
    cr_flush_period = config.get("development.worker.console_cr_flush_period", 0)

    # attributes stored on the proxy itself, everything else is delegated to the wrapped stream
    _own_attributes = ('_log', '_terminal', '_log_level', '_cur_line', '_cr_overwrite', '_force_lf_flush')

    def __init__(self, stream, logger=None, level=logging.INFO):
        """
        :param stream: the stream to wrap (e.g. ``sys.stdout``)
        :param logger: logger object receiving the captured lines (its ``_console()`` is called)
        :param level: logging level used when reporting the captured lines
        """
        PrintPatchLogger.patched = True
        self._terminal = stream
        self._log = logger
        self._log_level = level
        self._cur_line = ''
        self._force_lf_flush = False

    def _write_to_terminal(self, message):
        """Write ``message`` to the wrapped stream, bypassing a patched ``write`` if there is one."""
        if hasattr(self._terminal, '_original_write'):
            self._terminal._original_write(message)  # noqa
        else:
            self._terminal.write(message)

    def write(self, message):
        """
        Write ``message`` to the wrapped stream and report completed lines to the logger.

        :param message: text to write
        """
        # make sure that we do not end up in infinite loop (i.e. log.console ends up calling us)
        if not self._log or PrintPatchLogger.recursion_protect_lock._is_owned():  # noqa
            self._write_to_terminal(message)
            return

        # use the lock as a context manager: it is released only if it was actually acquired
        with self.lock:
            with PrintPatchLogger.recursion_protect_lock:
                self._write_to_terminal(message)

            do_flush = '\n' in message
            do_cr = '\r' in message
            self._cur_line += message

            # a flush was requested (see force_lf_flush): treat a CR-only update as a full line
            if not do_flush and do_cr and PrintPatchLogger.cr_flush_period and self._force_lf_flush:
                self._cur_line += '\n'
                do_flush = True

            self._force_lf_flush = False

            # keep buffering if there is no line feed, unless this is a CR update
            # and CR buffering (cr_flush_period) is turned off
            if (not do_flush and (PrintPatchLogger.cr_flush_period or not do_cr)) or not message:
                return

            # with CR buffering, keep only the text following the last '\r' of every line
            if PrintPatchLogger.cr_flush_period and self._cur_line:
                self._cur_line = '\n'.join(line.split('\r')[-1] for line in self._cur_line.split('\n'))

            # report everything up to the last line break, the remainder stays in the buffer
            last_lf = self._cur_line.rindex('\n' if do_flush else '\r')
            next_line = self._cur_line[last_lf + 1:]
            cur_line = self._cur_line[:last_lf + 1].rstrip()
            self._cur_line = next_line

        # report outside of the buffer lock, the recursion lock makes nested writes go to the terminal only
        if cur_line:
            with PrintPatchLogger.recursion_protect_lock:
                # noinspection PyBroadException
                try:
                    if self._log:
                        # noinspection PyProtectedMember
                        self._log._console(cur_line, level=self._log_level, omit_console=True)
                except Exception:
                    # what can we do, nothing
                    pass

    def connect(self, logger):
        """Report captured lines to ``logger`` from now on (the pending line buffer is dropped)."""
        self._cur_line = ''
        self._log = logger

    def disconnect(self, logger=None):
        """
        Stop reporting captured lines.

        :param logger: if given, disconnect only if this is the currently connected logger
        """
        # disconnect the logger only if it was registered
        if not logger or self._log == logger:
            self.connect(None)

    def force_lf_flush(self):
        """Request that the next CR-only write is reported as a complete line (see ``cr_flush_period``)."""
        self._force_lf_flush = True

    def __getattr__(self, attr):
        # only called when regular lookup fails: own state defaults to None,
        # anything else is taken from the wrapped stream
        if attr in PrintPatchLogger._own_attributes:
            return self.__dict__.get(attr)
        return getattr(self._terminal, attr)

    def __setattr__(self, key, value):
        # own state is stored on the proxy, anything else is set on the wrapped stream
        if key in PrintPatchLogger._own_attributes:
            self.__dict__[key] = value
        else:
            return setattr(self._terminal, key, value)

class LogFlusher(threading.Thread):
    """
    Daemon thread calling ``logger.flush()`` periodically.

    When ``PrintPatchLogger.cr_flush_period`` is set, the thread also asks the
    patched stdout / stderr streams to flush their CR-only (progress bar style)
    output once per that period.
    """

    def __init__(self, logger, period, **kwargs):
        """
        :param logger: object exposing ``flush()``
        :param period: seconds to wait between flushes (a falsy value waits 1.0 second)
        :param kwargs: passed as-is to ``threading.Thread``
        """
        super(LogFlusher, self).__init__(**kwargs)
        self.daemon = True

        self._period = period
        self._logger = logger
        # set by exit() / set_period() in order to interrupt the current wait
        self._exit_event = threading.Event()
        self._cr_last_flush = 0
        try:
            self._cr_flush_period = float(PrintPatchLogger.cr_flush_period)
        except (ValueError, TypeError):
            self._cr_flush_period = 0

    @property
    def period(self):
        """The currently configured flush period (None after ``exit()`` was called)."""
        return self._period

    def _stop_requested(self):
        """Return True if the thread should exit, i.e. the period is None or negative."""
        # read the attribute once: exit() may set it to None from another thread
        # between two reads, and `None < 0` raises TypeError on python 3
        period = self._period
        return period is None or period < 0

    def run(self):
        """Flush the logger every period, until ``exit()`` is called (or a negative period is set)."""
        self._logger.flush()
        while True:
            # the wait period used until the next exit() / set_period() call
            period = self._period
            while not self._exit_event.wait(period or 1.0):
                if self._cr_flush_period and time() - self._cr_last_flush > self._cr_flush_period:
                    if isinstance(sys.stdout, PrintPatchLogger):
                        sys.stdout.force_lf_flush()
                    if isinstance(sys.stderr, PrintPatchLogger):
                        sys.stderr.force_lf_flush()
                    self._cr_last_flush = time()
                # now signal the real flush
                self._logger.flush()

            # check if period is negative or None we should exit
            if self._stop_requested():
                break
            # check if period was changed, we should restart
            self._exit_event.clear()
            # exit() might have been called after the check above and before the clear(),
            # in which case its wake-up signal was just erased - check again so it is not lost
            if self._stop_requested():
                break

    def exit(self):
        """Ask the thread to stop (does not wait for it to finish)."""
        self._period = None
        self._exit_event.set()

    def set_period(self, period):
        """
        Change the flush period.

        :param period: new period in seconds, a negative value (or None) makes the thread exit
        """
        self._period = period
        # make sure we exit the previous wait
        self._exit_event.set()