import logging
import sys
import threading

from ..backend_interface.task.development.worker import DevWorker
from ..backend_interface.task.log import TaskHandler
from ..config import running_remotely


class StdStreamPatch(object):
    _stdout_proxy = None
    _stderr_proxy = None
    _stdout_original_write = None

    @staticmethod
    def patch_std_streams(logger):
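        # Replace sys.stdout / sys.stderr with PrintPatchLogger proxies so console output
        # is also forwarded to the task logger. On subsequent calls the existing proxies
        # are simply reconnected to the new logger.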
        if DevWorker.report_stdout and not PrintPatchLogger.patched and not running_remotely():
            StdStreamPatch._stdout_proxy = PrintPatchLogger(sys.stdout, logger, level=logging.INFO)
            StdStreamPatch._stderr_proxy = PrintPatchLogger(sys.stderr, logger, level=logging.ERROR)
            logger._task_handler = TaskHandler(logger._task.session, logger._task.id, capacity=100)
            # noinspection PyBroadException
            try:
                if StdStreamPatch._stdout_original_write is None:
                    StdStreamPatch._stdout_original_write = sys.stdout.write
                # this will only work in python 3, guard it with try/except
                if not hasattr(sys.stdout, '_original_write'):
                    sys.stdout._original_write = sys.stdout.write
                sys.stdout.write = StdStreamPatch._stdout__patched__write__
                if not hasattr(sys.stderr, '_original_write'):
                    sys.stderr._original_write = sys.stderr.write
                sys.stderr.write = StdStreamPatch._stderr__patched__write__
            except Exception:
                pass
            sys.stdout = StdStreamPatch._stdout_proxy
            sys.stderr = StdStreamPatch._stderr_proxy
            # patch the base streams of sys (this way colorama will keep its ANSI colors)
            # noinspection PyBroadException
            try:
                sys.__stderr__ = sys.stderr
            except Exception:
                pass
            # noinspection PyBroadException
            try:
                sys.__stdout__ = sys.stdout
            except Exception:
                pass

            # now check if we have loguru and make it re-register the handlers
            # because it internally stores the stream.write function, which we cannot patch
            # noinspection PyBroadException
            try:
                from loguru import logger as loguru_logger
                register_stderr = None
                register_stdout = None
                for k, v in loguru_logger._handlers.items():
                    if v._name == '<stderr>':
                        register_stderr = k
                    elif v._name == '<stdout>':
                        register_stdout = k
                if register_stderr is not None:
                    loguru_logger.remove(register_stderr)
                    loguru_logger.add(sys.stderr)
                if register_stdout is not None:
                    loguru_logger.remove(register_stdout)
                    loguru_logger.add(sys.stdout)
            except Exception:
                pass

        elif DevWorker.report_stdout and not running_remotely():
            logger._task_handler = TaskHandler(logger._task.session, logger._task.id, capacity=100)
            if StdStreamPatch._stdout_proxy:
                StdStreamPatch._stdout_proxy.connect(logger)
            if StdStreamPatch._stderr_proxy:
                StdStreamPatch._stderr_proxy.connect(logger)

    @staticmethod
    def remove_std_logger():
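        # Detach the logger from the patched streams; output keeps flowing to the original terminal.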
        if isinstance(sys.stdout, PrintPatchLogger):
            # noinspection PyBroadException
            try:
                sys.stdout.connect(None)
            except Exception:
                pass
        if isinstance(sys.stderr, PrintPatchLogger):
            # noinspection PyBroadException
            try:
                sys.stderr.connect(None)
            except Exception:
                pass

    @staticmethod
    def stdout_original_write(*args, **kwargs):
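        # Write directly through the original (pre-patch) stdout write, bypassing the proxy.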
        if StdStreamPatch._stdout_original_write:
            StdStreamPatch._stdout_original_write(*args, **kwargs)

    @staticmethod
    def _stdout__patched__write__(*args, **kwargs):
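        # Route writes through the stdout proxy when available, otherwise fall back to the saved original write.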
        if StdStreamPatch._stdout_proxy:
            return StdStreamPatch._stdout_proxy.write(*args, **kwargs)
        return sys.stdout._original_write(*args, **kwargs)

    @staticmethod
    def _stderr__patched__write__(*args, **kwargs):
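        # Route writes through the stderr proxy when available, otherwise fall back to the saved original write.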
        if StdStreamPatch._stderr_proxy:
            return StdStreamPatch._stderr_proxy.write(*args, **kwargs)
        return sys.stderr._original_write(*args, **kwargs)


class PrintPatchLogger(object):
    """
    Allowed patching a stream into the logger.
    Used for capturing and logging stdin and stderr when running in development mode pseudo worker.
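
    Minimal usage sketch (assumes the supplied logger exposes the internal _console() method
    this class calls when forwarding complete lines):

        patched_stdout = PrintPatchLogger(sys.stdout, logger, level=logging.INFO)
        sys.stdout = patched_stdout   # prints now reach both the terminal and the logger
        # ...
        patched_stdout.connect(None)  # detach the logger, keep writing to the terminal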
    """
    patched = False
    lock = threading.Lock()
    recursion_protect_lock = threading.RLock()

    def __init__(self, stream, logger=None, level=logging.INFO):
        PrintPatchLogger.patched = True
        self._terminal = stream
        self._log = logger
        self._log_level = level
        self._cur_line = ''

    def write(self, message):
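        # Always write through to the underlying terminal; in addition, buffer the text and
        # forward complete lines (terminated by '\n' or '\r') to the attached logger.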
        # make sure that we do not end up in infinite loop (i.e. log.console ends up calling us)
        if self._log and not PrintPatchLogger.recursion_protect_lock._is_owned():
            try:
                self.lock.acquire()
                with PrintPatchLogger.recursion_protect_lock:
                    if hasattr(self._terminal, '_original_write'):
                        self._terminal._original_write(message)
                    else:
                        self._terminal.write(message)

                do_flush = '\n' in message
                do_cr = '\r' in message
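                # a '\r' (e.g. progress bars) also flushes the buffered line to the logger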
                self._cur_line += message
                if (not do_flush and not do_cr) or not message:
                    return
                last_lf = self._cur_line.rindex('\n' if do_flush else '\r')
                next_line = self._cur_line[last_lf + 1:]
                cur_line = self._cur_line[:last_lf + 1].rstrip()
                self._cur_line = next_line
            finally:
                self.lock.release()

            if cur_line:
                with PrintPatchLogger.recursion_protect_lock:
                    # noinspection PyBroadException
                    try:
                        if self._log:
                            self._log._console(cur_line, level=self._log_level, omit_console=True)
                    except Exception:
                        # what can we do, nothing
                        pass
        else:
            if hasattr(self._terminal, '_original_write'):
                self._terminal._original_write(message)
            else:
                self._terminal.write(message)

    def connect(self, logger):
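        # Attach a new logger (or None to detach) and drop any partially buffered line.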
        self._cur_line = ''
        self._log = logger

    def __getattr__(self, attr):
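        # Delegate everything except our own private fields to the wrapped stream,
        # so the proxy still looks like the original stdout/stderr (encoding, isatty, etc.)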
        if attr in ['_log', '_terminal', '_log_level', '_cur_line']:
            return self.__dict__.get(attr)
        return getattr(self._terminal, attr)

    def __setattr__(self, key, value):
        if key in ['_log', '_terminal', '_log_level', '_cur_line']:
            self.__dict__[key] = value
        else:
            return setattr(self._terminal, key, value)


class LogFlusher(threading.Thread):
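    """Daemon thread that flushes the given logger every `period` seconds until exit() is called."""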
    def __init__(self, logger, period, **kwargs):
        super(LogFlusher, self).__init__(**kwargs)
        self.daemon = True

        self._period = period
        self._logger = logger
        self._exit_event = threading.Event()

    @property
    def period(self):
        return self._period

    def run(self):
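        # flush once on startup so already buffered records are not delayed by a full period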
        self._logger.flush()
        # outer loop: restarted whenever the period changes, exits once the period is None or negative
        while True:
            period = self._period
            while not self._exit_event.wait(period or 1.0):
                self._logger.flush()
            # exit() sets the period to None; a None or negative period stops the thread
            if self._period is None or self._period < 0:
                break
            # otherwise the period was changed via set_period(); clear the event and restart the wait
            self._exit_event.clear()

    def exit(self):
        self._period = None
        self._exit_event.set()

    def set_period(self, period):
        self._period = period
        # make sure we exit the previous wait
        self._exit_event.set()