import abc
import re
import shutil
import subprocess
from distutils.spawn import find_executable
from hashlib import md5
from os import environ
from random import random
from threading import Lock
from typing import Text, Sequence, Mapping, Iterable, TypeVar, Callable, Tuple, Optional

import attr
from furl import furl
from pathlib2 import Path

import six

from clearml_agent.definitions import ENV_AGENT_GIT_USER, ENV_AGENT_GIT_PASS, ENV_AGENT_GIT_HOST
from clearml_agent.helper.console import ensure_text, ensure_binary
from clearml_agent.errors import CommandFailedError
from clearml_agent.helper.base import (
    select_for_platform,
    rm_tree,
    ExecutionInfo,
    normalize_path,
    create_file_if_not_exists,
)
from clearml_agent.helper.os.locks import FileLock
from clearml_agent.helper.process import DEVNULL, Argv, PathLike, COMMAND_SUCCESS
from clearml_agent.session import Session


class VcsFactory(object):
    """
    Builds the :class:`VCS` implementation used for a task's repository.
    """

    # repository URL suffix of git repositories (not used for selection: only git is supported)
    GIT_SUFFIX = ".git"

    @classmethod
    def create(cls, session, execution_info, location):
        # type: (Session, ExecutionInfo, PathLike) -> VCS
        """
        Create a VCS instance for the repository of `execution_info`.

        The requested revision is the first available of: version number (commit id), tag,
        or the remote name(s) of the requested branch (falling back to the VCS main branch).

        :param session: program session
        :param execution_info: task ExecutionInfo
        :param location: (desired) clone location
        """
        # hg is deprecated - git is always used, whatever the repository url looks like
        implementation = Git
        revision = execution_info.version_num or execution_info.tag
        if not revision:
            branch = execution_info.branch or implementation.main_branch
            revision = implementation.remote_branch_name(branch)
        return implementation(session, execution_info.repository, location, revision)


# noinspection PyUnresolvedReferences
@attr.s
class RepoInfo(object):
    """
    Cloned repository information (an ``attrs`` class; the fields are the constructor arguments).
    :param type: VCS type (the VCS executable name, e.g. "git")
    :param url: repository url
    :param branch: revision branch
    :param commit: revision number
    :param root: clone location path
    """

    # filled from ``VCS.executable_name`` by ``VCS.get_repository_copy_info``
    type = attr.ib(type=str)
    # the remaining fields are filled from the output of the matching ``info_commands`` entry;
    # ``clone_repository_cached`` replaces ``url`` with its password-less form before returning it
    url = attr.ib(type=str)
    branch = attr.ib(type=str)
    commit = attr.ib(type=str)
    root = attr.ib(type=str)


# Return type of the subprocess callable passed to ``VCS._call_subprocess`` (propagated to its caller)
RType = TypeVar("RType")


@six.add_metaclass(abc.ABCMeta)
class VCS(object):

    """
    Provides overloaded utilities for handling repositories of different types.

    NOTE: several members below are declared as abstract methods, but the concrete subclasses
    provide them as plain class attributes (``executable_name``, ``main_branch``, ``checkout_flags``,
    ``patch_base``, ``info_commands``) and this class reads them as attributes, not as callables.
    """

    # additional environment variables for VCS commands
    # (variables of the current process environment take precedence, see ``_call_subprocess``)
    COMMAND_ENV = {}

    # "--- a/<path>" header lines of a unified diff; ``patch`` pre-creates the referenced files
    PATCH_ADDED_FILE_RE = re.compile(r"^--- a/(?P<path>.*)")

    def __init__(self, session, url, location, revision):
        # type: (Session, Text, PathLike, Text) -> ()
        """
        Create a VCS instance for config and url
        :param session: program session
        :param url: repository url
        :param location: (desired) clone location
        :param revision: desired clone revision (subclasses may accept a sequence of candidates)
        """
        self.session = session
        self.url = url
        self.location = Text(location)
        self._revision = revision
        # logger named after this module
        self.log = self.session.get_logger(__name__)

    @property
    def url_with_auth(self):
        """
        Return URL with configured user/password
        """
        return self.add_auth(self.session.config, self.url)

    @abc.abstractmethod
    def executable_name(self):
        """
        Name of command executable
        """
        pass

    @abc.abstractmethod
    def main_branch(self):
        """
        Name of default/main branch
        """
        pass

    @abc.abstractmethod
    def checkout_flags(self):
        # type: () -> Sequence[Text]
        """
        Command-line flags for checkout command
        """
        pass

    @abc.abstractmethod
    def patch_base(self):
        # type: () -> Sequence[Text]
        """
        Command and flags for applying patches
        """
        pass

    def patch(self, location, patch_content):
        # type: (PathLike, Text) -> bool
        """
        Apply a patch (diff) to the repository at `location`.

        Files referenced by ``--- a/<path>`` diff header lines are created first if missing
        (best effort, failures are only logged), then the patch is fed to the ``patch_base``
        command on stdin.

        :param location: repository working directory the patch is applied in
        :param patch_content: patch text
        :return: True if the patch command succeeded, False otherwise
        """
        self.log.info("applying diff to %s", location)

        # noinspection PyBroadException
        try:
            for match in filter(None, map(self.PATCH_ADDED_FILE_RE.match, patch_content.splitlines())):
                file_path = None
                # noinspection PyBroadException
                try:
                    file_path = normalize_path(location, match.group("path"))
                    create_file_if_not_exists(file_path)
                except Exception:
                    if file_path:
                        self.log.warning("failed creating file for git diff (%s)" % file_path)
        except Exception:
            # pre-creating files is only a convenience, the patch command decides success
            pass

        return_code, errors = self.call_with_stdin(patch_content, *self.patch_base, cwd=location)
        if not return_code:
            self.log.info("successfully applied uncommitted changes")
            return True

        self.log.error("Failed applying diff")
        if any("no such file or directory" in line.lower() for line in errors.splitlines()):
            self.log.warning(
                "NOTE: files were not found when applying diff, perhaps you forgot to push your changes?"
            )
        return False

    # Command-line flags for clone command
    clone_flags = ()

    @abc.abstractmethod
    def executable_not_found_error_help(self):
        # type: () -> Text
        """
        Instructions for when executable is not found
        """
        pass

    @staticmethod
    def remote_branch_name(branch):
        # type: (Text) -> Text
        """
        Creates name of remote branch from name of local/ambiguous branch.
        Returns same name by default.
        """
        return branch

    # parse scp-like git ssh URLs, e.g: git@host:user/project.git
    SSH_URL_GIT_SYNTAX = re.compile(
        r"""
        ^
        (?:(?P<user>{regular}*?)@)?
        (?P<host>{regular}*?)
        :
        (?P<path>{regular}.*)?
        $
        """.format(
            regular=r"[^/@:#]"
        ),
        re.VERBOSE,
    )

    @classmethod
    def replace_ssh_url(cls, url):
        # type: (Text) -> Text
        """
        Replace SSH URL (``ssh://`` or scp-like ``user@host:path``) with HTTPS URL when applicable.
        Any other URL is returned unchanged.
        """

        def get_username(user_, password=None):
            """
            Remove special SSH users hg/git
            """
            return (
                None
                if user_ and user_.lower() in ["hg", "git"] and not password
                else user_
            )

        match = cls.SSH_URL_GIT_SYNTAX.match(url)
        if match:
            user, host, path = match.groups()
            return (
                furl()
                .set(scheme="https", username=get_username(user), host=host, path=path)
                .url
            )
        parsed_url = furl(url)
        if parsed_url.scheme == "ssh":
            return parsed_url.set(
                scheme="https",
                username=get_username(
                    parsed_url.username, password=parsed_url.password
                ),
            ).url
        return url

    @classmethod
    def replace_http_url(cls, url, port=None, username=None):
        # type: (Text, Optional[int], Optional[str]) -> Text
        """
        Replace HTTPS URL with SSH URL when applicable
        """
        parsed_url = furl(url)
        if parsed_url.scheme == "https":
            parsed_url.scheme = "ssh"
            parsed_url.username = username or "git"
            parsed_url.password = None
            # make sure there is no port in the final url (safe_furl support)
            # the original port was an https port, and we do not know if there is a different ssh port,
            # so we have to clear the original port specified (https) and use the default ssh schema port.
            parsed_url.port = port or None
            url = parsed_url.url
        return url

    @classmethod
    def rewrite_ssh_url(cls, url, port=None, username=None):
        # type: (Text, Optional[int], Optional[str]) -> Text
        """
        Rewrite SSH URL with custom port and username.
        Non-ssh URLs are returned unchanged.
        """
        parsed_url = furl(url)
        if parsed_url.scheme == "ssh":
            parsed_url.username = username or "git"
            parsed_url.port = port or None
            url = parsed_url.url
        return url

    def _set_ssh_url(self):
        """
        Replace instance URL according to the agent configuration and report the substitution.

        * ``agent.force_git_ssh_protocol``: https URLs are replaced with ssh URLs; ssh URLs are
          rewritten when ``agent.force_git_ssh_port`` / ``agent.force_git_ssh_user`` are set.
        * ``agent.translate_ssh`` with git user and password configured: ssh URLs are replaced
          with https URLs.

        Both rewrites only apply to the host from ``ENV_AGENT_GIT_HOST`` / ``agent.git_host`` (if set).
        According to ``man ssh-add``, ``SSH_AUTH_SOCK`` must be set in order for ``ssh-add`` to work.
        """
        config = self.session.config
        # only apply to a specific domain (if requested)
        config_domain = ENV_AGENT_GIT_HOST.get() or config.get("agent.git_host", None)

        if config.get('agent.force_git_ssh_protocol', None) and self.url:
            parsed_url = furl(self.url)
            if config_domain and config_domain != parsed_url.host:
                return
            force_port = config.get('agent.force_git_ssh_port', None)
            force_user = config.get('agent.force_git_ssh_user', None)
            if parsed_url.scheme == "https":
                new_url = self.replace_http_url(self.url, port=force_port, username=force_user)
                if new_url != self.url:
                    print("Using SSH credentials - replacing https url '{}' with ssh url '{}'".format(
                        self.url, new_url))
                    self.url = new_url
                return
            # rewrite ssh URLs only if either ssh port or ssh user are forced in config
            if parsed_url.scheme == "ssh" and (force_port or force_user):
                new_url = self.rewrite_ssh_url(self.url, port=force_port, username=force_user)
                if new_url != self.url:
                    print("Using SSH credentials - replacing ssh url '{}' with ssh url '{}'".format(
                        self.url, new_url))
                    self.url = new_url

        if not config.agent.translate_ssh or not self.url:
            return

        # if we have git_user / git_pass replace ssh credentials with https authentication
        has_user = ENV_AGENT_GIT_USER.get() or config.get('agent.git_user', None)
        has_pass = ENV_AGENT_GIT_PASS.get() or config.get('agent.git_pass', None)
        if not (has_user and has_pass):
            return

        new_url = self.replace_ssh_url(self.url)
        # compare the domain against the translated (https) URL: scp-like URLs (user@host:path)
        # are not standard URLs, so their host is only reliably available after translation
        if config_domain and config_domain != furl(new_url).host:
            return
        if new_url != self.url:
            print("Using user/pass credentials - replacing ssh url '{}' with https url '{}'".format(
                self.url, new_url))
            self.url = new_url

    def clone(self, branch=None):
        # type: (Text) -> None
        """
        Clone repository to destination (all branches; the revision is checked out separately).
        If not in debug mode, filter VCS password from output.
        :param branch: unused, kept for interface compatibility
        """
        self._set_ssh_url()
        auth_url = self.url_with_auth
        # clone all branches regardless of when we want to later checkout
        clone_command = ("clone", auth_url, self.location) + self.clone_flags
        if self.session.debug_mode:
            self.call(*clone_command)
            return

        # the command line uses `auth_url`, which may carry a password taken from the configuration,
        # so censor it as well as the plain repository url
        censored_urls = [auth_url] if auth_url == self.url else [auth_url, self.url]

        def normalize_output(result):
            """
            Returns result string without user's password.
            NOTE: ``self.get_stderr``'s result might or might not have the same type as ``e.output`` in case of error.
            """
            string_type = ensure_text if isinstance(result, six.text_type) else ensure_binary
            for url in censored_urls:
                result = result.replace(
                    string_type(url),
                    string_type(furl(url).remove(password=True).tostr()),
                )
            return result

        def print_output(output):
            print(ensure_text(output))

        try:
            print_output(normalize_output(self.get_stderr(*clone_command)))
        except subprocess.CalledProcessError as e:
            # In Python 3, subprocess.CalledProcessError has a `stderr` attribute,
            # but since stderr is redirect to `subprocess.PIPE` it will appear in the usual `output` attribute
            if e.output:
                e.output = normalize_output(e.output)
                print_output(e.output)
            raise

    def checkout(self):
        # type: () -> None
        """
        Checkout repository at specified revision
        """
        self.call("checkout", self._revision, *self.checkout_flags, cwd=self.location)

    @abc.abstractmethod
    def pull(self):
        # type: () -> None
        """
        Pull remote changes for revision
        """
        pass

    def call(self, *argv, **kwargs):
        """
        Execute argv without stdout/stdin.
        Remove stdin so git/hg can't ask for passwords.
        ``kwargs`` can override all arguments passed to subprocess.
        """
        return self._call_subprocess(subprocess.check_call, argv, **kwargs)

    def call_with_stdin(self, input_, *argv, **kwargs):
        # type: (...) -> Tuple[int, Text]
        """
        Run command with `input_` as stdin.
        :param input_: text fed to stdin (UTF-8 encoded, an extra trailing newline is always added)
        :return: tuple of (return code, stderr output as text)
        """
        input_ = input_.encode("utf-8")
        # always add extra empty line
        # (there is no downside, and it solves empty lines issue at end of patch cause corrupt message)
        input_ += b"\n"
        process = self._call_subprocess(
            subprocess.Popen,
            argv,
            **dict(
                kwargs,
                stdin=subprocess.PIPE,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
            )
        )
        _, stderr = process.communicate(input_)
        # decode instead of ``Text(stderr)``, which for bytes yields the "b'...'" repr
        # and breaks line splitting in callers
        if isinstance(stderr, six.binary_type):
            stderr = stderr.decode("utf-8", "replace")
        stderr = stderr or ""
        if stderr:
            self.log.warning("%s: %s", self._get_vcs_command(argv), stderr)
        return process.returncode, stderr

    def get_stderr(self, *argv, **kwargs):
        """
        Execute argv without stdout/stdin in <cwd> and get stderr output.
        Remove stdin so git/hg can't ask for passwords.
        ``kwargs`` can override all arguments passed to subprocess.
        :raises subprocess.CalledProcessError: on non-zero exit code (stderr in ``output``)
        """
        process = self._call_subprocess(
            subprocess.Popen, argv, **dict(kwargs, stderr=subprocess.PIPE, stdout=None)
        )
        _, stderr = process.communicate()
        code = process.poll()
        if code == COMMAND_SUCCESS:
            return stderr
        with Argv.normalize_exception(censor_password=True):
            raise subprocess.CalledProcessError(
                returncode=code, cmd=argv, output=stderr
            )

    def _call_subprocess(self, func, argv, **kwargs):
        # type: (Callable[..., RType], Iterable[Text], dict) ->  RType
        """
        Run the VCS command `argv` through `func` with stdin/stdout redirected to DEVNULL,
        passwords censored and ``COMMAND_ENV`` added to the environment (current process
        environment variables take precedence). ``kwargs`` override any of these defaults.
        """
        cwd = kwargs.pop("cwd", None)
        cwd = cwd and str(cwd)
        kwargs = dict(
            dict(
                censor_password=True,
                cwd=cwd,
                stdin=DEVNULL,
                stdout=DEVNULL,
                env=dict(self.COMMAND_ENV, **environ),
            ),
            **kwargs
        )
        command = self._get_vcs_command(argv)
        self.log.debug("Running: %s", list(command))
        return command.call_subprocess(func, **kwargs)

    def _get_vcs_command(self, argv):
        # type: (Iterable[PathLike]) -> Argv
        return Argv(self.executable_name, *argv)

    @classmethod
    def add_auth(cls, config, url):
        """
        Add username and password to URL if missing from URL and present in config.
        Does not modify ssh URLs.
        """
        try:
            parsed_url = furl(url)
        except ValueError:
            return url
        if parsed_url.scheme in ["", "ssh"] or (parsed_url.scheme or '').startswith("git"):
            return parsed_url.url
        config_user = ENV_AGENT_GIT_USER.get() or config.get("agent.{}_user".format(cls.executable_name), None)
        config_pass = ENV_AGENT_GIT_PASS.get() or config.get("agent.{}_pass".format(cls.executable_name), None)
        config_domain = ENV_AGENT_GIT_HOST.get() or config.get("agent.{}_host".format(cls.executable_name), None)
        if (
            (not (parsed_url.username and parsed_url.password))
            and config_user
            and config_pass
            and (not config_domain or config_domain.lower() == parsed_url.host)
        ):
            parsed_url.set(username=config_user, password=config_pass)
        return parsed_url.url

    @abc.abstractmethod
    def info_commands(self):
        # type: () -> Mapping[Text, Argv]
        """
        Mapping from `RepoInfo` attribute name (except `type`) to command which acquires it
        """
        pass

    def get_repository_copy_info(self, path):
        """
        Get `RepoInfo` instance from copy of clone in `path`
        """
        path = Text(path)
        commands_result = {
            name: command.get_output(cwd=path)
            for name, command in self.info_commands.items()
        }
        return RepoInfo(type=self.executable_name, **commands_result)


class Git(VCS):
    """
    Git implementation of :class:`VCS`.
    """

    executable_name = "git"
    # candidate default branches, tried in order by ``checkout``
    main_branch = ("master", "main")
    clone_flags = ("--quiet", "--recursive")
    checkout_flags = ("--force",)
    COMMAND_ENV = {
        # do not prompt for password
        "GIT_TERMINAL_PROMPT": "0",
        # do not prompt for ssh key passphrase
        "GIT_SSH_COMMAND": "ssh -oBatchMode=yes",
    }

    @staticmethod
    def remote_branch_name(branch):
        """
        Return the list of remote-tracking names (``origin/<name>``) for `branch`.
        :param branch: a branch name, or a sequence of candidate branch names
        """
        # six.string_types so unicode branch names (Python 2) are not iterated char by char
        names = [branch] if isinstance(branch, six.string_types) else branch
        return ["origin/{}".format(b) for b in names]

    def executable_not_found_error_help(self):
        return 'Cannot find "{}" executable. {}'.format(
            self.executable_name,
            select_for_platform(
                linux="You can install it by running: sudo apt-get install {}".format(
                    self.executable_name
                ),
                windows="You can download it here: {}".format(
                    "https://gitforwindows.org/"
                ),
            ),
        )

    def pull(self):
        self.call("fetch", "--all", "--recurse-submodules", cwd=self.location)

    def checkout(self):  # type: () -> None
        """
        Checkout repository at specified revision.

        ``self._revision`` may be a single revision or a sequence of candidates; candidates are
        tried in order and the error of the last one is raised if none can be checked out.
        Submodules are then updated (best effort).
        """
        revisions = [self._revision] if isinstance(self._revision, six.string_types) else self._revision
        for i, revision in enumerate(revisions):
            try:
                self.call("checkout", revision, *self.checkout_flags, cwd=self.location)
                break
            except subprocess.CalledProcessError:
                if i == len(revisions) - 1:
                    raise

        # noinspection PyBroadException
        try:
            self.call("submodule", "update", "--recursive", cwd=self.location)
        except Exception as ex:
            # best effort: a failed submodule update does not fail the checkout
            # (but KeyboardInterrupt / SystemExit are no longer swallowed)
            self.log.warning("git submodule update failed: %s", ex)

    info_commands = dict(
        url=Argv(executable_name, "ls-remote", "--get-url", "origin"),
        branch=Argv(executable_name, "rev-parse", "--abbrev-ref", "HEAD"),
        commit=Argv(executable_name, "rev-parse", "HEAD"),
        root=Argv(executable_name, "rev-parse", "--show-toplevel"),
    )

    patch_base = ("apply", "--unidiff-zero", )


class Hg(VCS):
    """
    Mercurial implementation of :class:`VCS`.
    NOTE: ``VcsFactory.create`` currently always selects ``Git`` (hg is deprecated).
    """

    executable_name = "hg"
    main_branch = "default"
    checkout_flags = ("--clean",)
    patch_base = ("import", "--no-commit")

    # commands used by ``get_repository_copy_info`` to fill the matching ``RepoInfo`` fields
    info_commands = {
        "url": Argv(executable_name, "paths", "--verbose"),
        "branch": Argv(executable_name, "--debug", "id", "-b"),
        "commit": Argv(executable_name, "--debug", "id", "-i"),
        "root": Argv(executable_name, "root"),
    }

    def executable_not_found_error_help(self):
        """
        Return the error message shown when the ``hg`` executable cannot be found.
        """
        linux_hint = "You can install it by running: sudo apt-get install {}".format(self.executable_name)
        windows_hint = "You can download it here: https://www.mercurial-scm.org/wiki/Download"
        hint = select_for_platform(linux=linux_hint, windows=windows_hint)
        return 'Cannot find "{}" executable. {}'.format(self.executable_name, hint)

    def pull(self):
        """
        Pull from the repository URL (with configured credentials) into ``self.location``,
        restricted to ``self._revision`` when one is set.
        """
        revision_args = ("-r", self._revision) if self._revision else ()
        self.call("pull", self.url_with_auth, *revision_args, cwd=self.location)


def clone_repository_cached(session, execution, destination):
    # type: (Session, ExecutionInfo, Path) -> Tuple[VCS, RepoInfo]
    """
    Clone a remote repository (through the local VCS cache, unless in standalone mode).
    :param session: program session
    :param execution: execution info
    :param destination: directory to clone to (in which a directory for the repository will be created).
        NOTE: when not in standalone mode, `destination` is removed before the repository copy is placed in it.
    :return: tuple of (VCS instance, repository information without password in the url)
    :raises: CommandFailedError if git/hg is not installed or copying the repository failed
    """
    # mock lock (used in standalone mode and when the cache lock cannot be acquired)
    repo_lock = Lock()
    repo_lock_timeout_sec = 300
    repo_url = execution.repository or ''  # type: str
    parsed_url = furl(repo_url)
    no_password_url = parsed_url.copy().remove(password=True).url

    clone_folder_name = Path(str(parsed_url.path)).name  # type: str
    clone_folder = Path(destination) / clone_folder_name
    # per-run cache folder used when the shared cache is locked; it is never reused, so remove it when done
    temp_cache_folder = None

    standalone_mode = session.config.get("agent.standalone_mode", False)
    if standalone_mode:
        cached_repo_path = clone_folder
    else:
        vcs_cache_path = Path(session.config["agent.vcs_cache.path"]).expanduser()
        repo_hash = md5(ensure_binary(repo_url)).hexdigest()
        # create lock
        repo_lock = FileLock(filename=(vcs_cache_path / '{}.lock'.format(repo_hash)).as_posix())
        # noinspection PyBroadException
        try:
            repo_lock.acquire(timeout=repo_lock_timeout_sec)
        except Exception:
            # (KeyboardInterrupt / SystemExit while waiting for the lock are propagated)
            print('Could not lock cache folder "{}" (timeout {} sec), using temp vcs cache.'.format(
                clone_folder_name, repo_lock_timeout_sec))
            repo_hash = '{}_{}'.format(repo_hash, str(random()).replace('.', ''))
            # use mock lock for the context
            repo_lock = Lock()
            temp_cache_folder = vcs_cache_path / "{}.{}".format(clone_folder_name, repo_hash)
        # select vcs cache folder
        cached_repo_path = vcs_cache_path / "{}.{}".format(clone_folder_name, repo_hash) / clone_folder_name

    with repo_lock:
        vcs = VcsFactory.create(
            session, execution_info=execution, location=cached_repo_path
        )
        # shutil.which is the stdlib replacement for the deprecated distutils find_executable
        if not shutil.which(vcs.executable_name):
            raise CommandFailedError(vcs.executable_not_found_error_help())

        if not standalone_mode:
            try:
                if session.config["agent.vcs_cache.enabled"] and cached_repo_path.exists():
                    print('Using cached repository in "{}"'.format(cached_repo_path))

                else:
                    print("cloning: {}".format(no_password_url))
                    rm_tree(cached_repo_path)
                    # We clone the entire repository, not a specific branch
                    vcs.clone()

                vcs.pull()
                rm_tree(destination)
                shutil.copytree(Text(cached_repo_path), Text(clone_folder),
                                symlinks=select_for_platform(linux=True, windows=False),
                                ignore_dangling_symlinks=True)
                if not clone_folder.is_dir():
                    raise CommandFailedError(
                        "copying of repository failed: from {} to {}".format(
                            cached_repo_path, clone_folder
                        )
                    )

                # checkout in the newly copied destination
                vcs.location = Text(clone_folder)
                vcs.checkout()
            finally:
                if temp_cache_folder is not None:
                    rm_tree(temp_cache_folder)

    repo_info = vcs.get_repository_copy_info(clone_folder)

    # make sure we have no user/pass in the returned repository structure
    repo_info = attr.evolve(repo_info, url=no_password_url)

    return vcs, repo_info


def fix_package_import_diff_patch(entry_script_file):
    """
    Move the two injected clearml lines of `entry_script_file` below its ``__future__`` imports.

    The script is expected to start with two injected lines: one starting with ``from clearml ``
    and one containing ``Task.init``. ``from __future__`` imports must be the first statements of
    a module, so when such imports follow the injected lines (possibly after a docstring, comments
    or empty lines), the two injected lines are moved to right after the last ``__future__`` import.
    The file is rewritten in place. If it cannot be read or written, does not start with the injected
    lines, has no leading ``__future__`` import, or has an unterminated import statement, it is
    left untouched.

    :param entry_script_file: path of the script to fix
    """
    # noinspection PyBroadException
    try:
        with open(entry_script_file, 'rt') as f:
            lines = f.readlines()
    except Exception:
        return
    # make sure we are the first import (i.e. we patched the source code)
    if len(lines) < 2 or not lines[0].strip().startswith('from clearml ') or 'Task.init' not in lines[1]:
        return

    original_lines = lines
    # skip over the first two lines, they are ours,
    # then skip over empty or comment lines; keep (original line index, code without trailing comment)
    lines = [(i, line.split('#', 1)[0].rstrip()) for i, line in enumerate(original_lines)
             if i >= 2 and line.strip('\r\n\t ') and not line.strip().startswith('#')]

    # remove triple quoted (""") blocks, e.g. the module docstring
    nested_c = -1
    skip_lines = set()
    for i, (_, code) in enumerate(lines):
        for _ in range(code.count('"""')):
            if nested_c >= 0:
                skip_lines.update(range(nested_c, i + 1))
                nested_c = -1
            else:
                nested_c = i
    # now keep only the lines outside the triple quoted blocks
    lines = [pair for i, pair in enumerate(lines) if i not in skip_lines]

    from_future = re.compile(r"^from[\s]*__future__[\s]*")
    import_future = re.compile(r"^import[\s]*__future__[\s]*")
    # find the last line of the leading __future__ import block
    found_index = -1
    for a_i, (_, a_line) in enumerate(lines):
        if found_index >= a_i:
            # already consumed as a continuation line of the previous import statement
            continue
        if not (from_future.match(a_line) or import_future.match(a_line)):
            # __future__ imports must come first, stop at the first other statement
            break
        found_index = a_i
        line = a_line
        # whether we have \\ character at the end of the line or an open parenthesis
        parenthesized_lines = '(' in line and ')' not in line
        while line.endswith('\\') or parenthesized_lines:
            found_index += 1
            if found_index >= len(lines):
                # the import statement is never terminated (malformed script): leave it untouched
                return
            _, line = lines[found_index]
            if ')' in line:
                break

    # no imports found
    if found_index < 0:
        return

    # now we need to move back the patched two lines
    entry_line = lines[found_index][0]
    new_lines = original_lines[2:entry_line + 1] + original_lines[0:2] + original_lines[entry_line + 1:]
    # noinspection PyBroadException
    try:
        with open(entry_script_file, 'wt') as f:
            f.writelines(new_lines)
    except Exception:
        return