from __future__ import absolute_import, unicode_literals import operator import os import re from abc import ABCMeta, abstractmethod from copy import deepcopy from itertools import chain, starmap from operator import itemgetter from os import path from typing import Text, List, Type, Optional, Tuple, Dict from pathlib2 import Path from pyhocon import ConfigTree import six from trains_agent.definitions import PIP_EXTRA_INDICES from trains_agent.helper.base import warning, is_conda, which, join_lines, is_windows_platform from trains_agent.helper.process import Argv, PathLike from trains_agent.session import Session, normalize_cuda_version from trains_agent.external.requirements_parser import parse from trains_agent.external.requirements_parser.requirement import Requirement from .translator import RequirementsTranslator class SpecsResolutionError(Exception): pass class FatalSpecsResolutionError(Exception): pass @six.python_2_unicode_compatible class MarkerRequirement(object): def __init__(self, req): # type: (Requirement) -> None self.req = req @property def marker(self): match = re.search(r';\s*(.*)', self.req.line) if match: return match.group(1) return None def tostr(self, markers=True): if not self.uri: parts = [self.name or self.line] if self.extras: parts.append('[{0}]'.format(','.join(sorted(self.extras)))) if self.specifier: parts.append(self.format_specs()) elif self.vcs: # leave the line as is, let pip handle it if self.line: return self.line else: # let's build the line manually parts = [ self.uri, '@{}'.format(self.revision) if self.revision else '', '#subdirectory={}'.format(self.subdirectory) if self.subdirectory else '' ] else: parts = [self.uri] if markers and self.marker: parts.append('; {0}'.format(self.marker)) return ''.join(parts) __str__ = tostr def __repr__(self): return '{self.__class__.__name__}[{self}]'.format(self=self) def format_specs(self, num_parts=None, max_num_parts=None): max_num_parts = max_num_parts or num_parts if max_num_parts is 
None or not self.specs: return ','.join(starmap(operator.add, self.specs)) op, version = self.specs[0] for v in self._sub_versions_pep440: version = version.replace(v, '.') if num_parts: version = (version.strip('.').split('.') + ['0'] * num_parts)[:max_num_parts] else: version = version.strip('.').split('.')[:max_num_parts] return op+'.'.join(version) def __getattr__(self, item): return getattr(self.req, item) @property def specs(self): # type: () -> List[Tuple[Text, Text]] return self.req.specs @specs.setter def specs(self, value): # type: (List[Tuple[Text, Text]]) -> None self.req.specs = value def fix_specs(self): def solve_by(func, op_is, specs): return func([(op, version) for op, version in specs if op == op_is]) def solve_equal(specs): if len(set(version for _, version in self.specs)) > 1: raise SpecsResolutionError('more than one "==" spec: {}'.format(specs)) return specs greater = solve_by(lambda specs: [max(specs, key=itemgetter(1))], '<=', self.specs) smaller = solve_by(lambda specs: [min(specs, key=itemgetter(1))], '>=', self.specs) equal = solve_by(solve_equal, '==', self.specs) if equal: self.specs = equal else: self.specs = greater + smaller def compare_version(self, requested_version, op=None, num_parts=3): """ compare the requested version with the one we have in the spec, If the requested version is 1.2.3 the self.spec should be 1.2.3* If the requested version is 1.2 the self.spec should be 1.2* etc. :param str requested_version: :param str op: '==', '>', '>=', '<=', '<', '~=' :param int num_parts: number of parts to compare :return: True if we answer the requested version """ # if we have no specific version, we cannot compare, so assume it's okay if not self.specs: return True version = self.specs[0][1] op = (op or self.specs[0][0]).strip() return SimpleVersion.compare_versions(requested_version, op, version) class SimpleVersion: _sub_versions_pep440 = ['a', 'b', 'rc', '.post', '.dev', '+', ] VERSION_PATTERN = r""" v? (?: (?:(?P[0-9]+)!)? 
# epoch (?P[0-9]+(?:\.[0-9]+)*) # release segment (?P
                                          # pre-release
                [-_\.]?
                (?P(a|b|c|rc|alpha|beta|pre|preview))
                [-_\.]?
                (?P[0-9]+)?
            )?
            (?P                                         # post release
                (?:-(?P[0-9]+))
                |
                (?:
                    [-_\.]?
                    (?Ppost|rev|r)
                    [-_\.]?
                    (?P[0-9]+)?
                )
            )?
            (?P                                          # dev release
                [-_\.]?
                (?Pdev)
                [-_\.]?
                (?P[0-9]+)?
            )?
        )
        (?:\+(?P[a-z0-9]+(?:[-_\.][a-z0-9]+)*))?       # local version
    """
    _local_version_separators = re.compile(r"[\._-]")
    _regex = re.compile(r"^\s*" + VERSION_PATTERN + r"\s*$", re.VERBOSE | re.IGNORECASE)

    @classmethod
    def compare_versions(cls, version_a, op, version_b, ignore_sub_versions=True):
        """
        Compare two versions based on the op operator
        returns bool(version_a op version_b)
        Notice: Ignores a/b/rc/post/dev markers on the version

        :param str version_a:
        :param str op: '==', '===', '>', '>=', '<=', '<', '~='
        :param str version_b:
        :param bool ignore_sub_versions: if true compare only major.minor.patch
            (ignore a/b/rc/post/dev in the comparison)
        :return bool: version_a op version_b
        """

        if not version_b:
            return True
        num_parts = 3

        if op == '~=':
            num_parts = max(num_parts, 2)
            op = '=='
            ignore_sub_versions = True
        elif op == '===':
            op = '=='

        try:
            version_a_key = cls._get_match_key(cls._regex.search(version_a), num_parts, ignore_sub_versions)
            version_b_key = cls._get_match_key(cls._regex.search(version_b), num_parts, ignore_sub_versions)
        except:
            # revert to string based
            for v in cls._sub_versions_pep440:
                version_a = version_a.replace(v, '.')
                version_b = version_b.replace(v, '.')

            version_a = (version_a.strip('.').split('.') + ['0'] * num_parts)[:num_parts]
            version_b = (version_b.strip('.').split('.') + ['0'] * num_parts)[:num_parts]
            version_a_key = ''
            version_b_key = ''
            for i in range(num_parts):
                pad = '{:0>%d}.' % max([9, 1 + len(version_a[i]), 1 + len(version_b[i])])
                version_a_key += pad.format(version_a[i])
                version_b_key += pad.format(version_b[i])

        if op == '==':
            return version_a_key == version_b_key
        if op == '<=':
            return version_a_key <= version_b_key
        if op == '>=':
            return version_a_key >= version_b_key
        if op == '>':
            return version_a_key > version_b_key
        if op == '<':
            return version_a_key < version_b_key
        raise ValueError('Unrecognized comparison operator [{}]'.format(op))

    @staticmethod
    def _parse_letter_version(
            letter,  # type: str
            number,  # type: Union[str, bytes, SupportsInt]
    ):
        # type: (...) -> Optional[Tuple[str, int]]

        if letter:
            # We consider there to be an implicit 0 in a pre-release if there is
            # not a numeral associated with it.
            if number is None:
                number = 0

            # We normalize any letters to their lower case form
            letter = letter.lower()

            # We consider some words to be alternate spellings of other words and
            # in those cases we want to normalize the spellings to our preferred
            # spelling.
            if letter == "alpha":
                letter = "a"
            elif letter == "beta":
                letter = "b"
            elif letter in ["c", "pre", "preview"]:
                letter = "rc"
            elif letter in ["rev", "r"]:
                letter = "post"

            return letter, int(number)
        if not letter and number:
            # We assume if we are given a number, but we are not given a letter
            # then this is using the implicit post release syntax (e.g. 1.0-1)
            letter = "post"

            return letter, int(number)

        return ()

    @staticmethod
    def _get_match_key(match, num_parts, ignore_sub_versions):
        if ignore_sub_versions:
            return (0, tuple(int(i) for i in match.group("release").split(".")[:num_parts]),
                    (), (), (), (),)
        return (
            int(match.group("epoch")) if match.group("epoch") else 0,
            tuple(int(i) for i in match.group("release").split(".")[:num_parts]),
            SimpleVersion._parse_letter_version(match.group("pre_l"), match.group("pre_n")),
            SimpleVersion._parse_letter_version(
                match.group("post_l"), match.group("post_n1") or match.group("post_n2")
            ),
            SimpleVersion._parse_letter_version(match.group("dev_l"), match.group("dev_n")),
            SimpleVersion._parse_local_version(match.group("local")),
        )

    @staticmethod
    def _parse_local_version(local):
        # type: (str) -> Optional[LocalType]
        """
        Takes a string like abc.1.twelve and turns it into ("abc", 1, "twelve").
        """
        if local is not None:
            return tuple(
                part.lower() if not part.isdigit() else int(part)
                for part in SimpleVersion._local_version_separators.split(local)
            )
        return ()


@six.add_metaclass(ABCMeta)
class RequirementSubstitution(object):
    """
    Base class for rules that rewrite a single requirement line.

    ``suffix`` is built from the configured CUDA/cuDNN versions
    (e.g. '.post<cuda>.dev<cudnn>') for subclasses that pin CUDA-specific builds.
    """

    _pip_extra_index_url = PIP_EXTRA_INDICES

    def __init__(self, session):
        # type: (Session) -> None
        self._session = session
        self.config = session.config  # type: ConfigTree
        self.suffix = '.post{config[agent.cuda_version]}.dev{config[agent.cudnn_version]}'.format(config=self.config)
        self.package_manager = self.config['agent.package_manager.type']

    @abstractmethod
    def match(self, req):  # type: (MarkerRequirement) -> bool
        """
        Returns whether a requirement needs to be modified by this substitution.
        """
        pass

    @abstractmethod
    def replace(self, req):  # type: (MarkerRequirement) -> Text
        """
        Replace a requirement
        """
        pass

    def post_install(self, session):
        """
        Hook called after installation; does nothing by default.
        """
        pass

    @classmethod
    def get_pip_version(cls, package):
        """
        Look up the version of ``package`` listed by ``pip search``
        (also passing every index in ``_pip_extra_index_url`` with ``-i``).

        :param str package: package name
        :return: the listed version (at least major.minor.patch)
        :raises ValueError: if the output does not list the package with such a version
        """
        # NOTE(review): `pip search` may be unsupported by the configured index; verify it still works.
        output = Argv(
            'pip',
            'search',
            package,
            *(chain.from_iterable(('-i', x) for x in cls._pip_extra_index_url))
        ).get_output()
        # ad-hoc pattern to duplicate the behavior of the old code
        # (the name is escaped: package names may contain regex metacharacters such as '.')
        match = re.search(r'{} \((\d+\.\d+\.[^.]+)'.format(re.escape(package)), output)
        if not match:
            raise ValueError('could not find a version of "{}" in pip search output'.format(package))
        return match.group(1)

    @property
    def cuda_version(self):
        return self.config['agent.cuda_version']

    @property
    def cudnn_version(self):
        return self.config['agent.cudnn_version']


class SimpleSubstitution(RequirementSubstitution):
    """
    Substitution for a single package (``name``): pins it to a version ending with ``self.suffix``,
    or injects the suffix into an http(s) wheel URL of that package.
    """

    @property
    @abstractmethod
    def name(self):
        pass

    def match(self, req):  # type: (MarkerRequirement) -> bool
        """
        :return: True if ``req`` is this package by name, or an http(s) URI containing the package name
        """
        return bool(self.name == req.name or (
            req.uri and
            re.match(r'https?://', req.uri) and
            self.name in req.uri
        ))

    def replace(self, req):  # type: (MarkerRequirement) -> Text
        """
        Replace a requirement

        For a URI requirement the suffix is inserted right before the first '-cp' following the
        package name. Otherwise the requirement is pinned to '==<version><suffix>', where the
        version is taken from the first spec or, if there is none, from get_pip_version
        (which may raise if the version cannot be determined).
        """
        if req.uri:
            suffix = self.suffix
            # use a function replacement so the (config derived) suffix is inserted literally
            return re.sub(
                r'({})(.*?)(-cp)'.format(re.escape(self.name)),
                lambda m: m.group(1) + m.group(2) + suffix + m.group(3),
                req.uri,
                count=1)

        if req.specs:
            _, version_number = req.specs[0]
        else:
            version_number = self.get_pip_version(self.name)

        req.specs = [('==', version_number + self.suffix)]
        return Text(req)

    def replace_back(self, list_of_requirements):  # type: (Dict) -> Dict
        """
        Reverse the substitution on an installed-requirements listing; a no-op here.

        :param list_of_requirements: {'pip': ['a==1.0', ]}
        :return: {'pip': ['a==1.0', ]}
        """
        return list_of_requirements


@six.add_metaclass(ABCMeta)
class CudaSensitiveSubstitution(SimpleSubstitution):
    """
    A SimpleSubstitution that only applies when both the CUDA and the cuDNN versions are set.
    """

    def match(self, req):  # type: (MarkerRequirement) -> bool
        # both versions must be set; otherwise hand back the falsy value itself
        versions_known = self.cuda_version and self.cudnn_version
        if not versions_known:
            return versions_known
        return super(CudaSensitiveSubstitution, self).match(req)


class CudaNotFound(Exception):
    """Raised while probing for CUDA/cuDNN when a required tool or header cannot be found."""


class RequirementsManager(object):
    """
    Rewrites requirement lines with the registered substitution handlers and, unless the
    package manager is conda, passes every resulting line through the requirements translator.
    """

    def __init__(self, session, base_interpreter=None):
        # type: (Session, PathLike) -> None
        self._session = session
        self.config = deepcopy(session.config)  # type: ConfigTree
        self.handlers = []  # type: List[RequirementSubstitution]
        agent = self.config['agent']
        # handlers are skipped entirely when the agent is configured as cpu_only
        self.active = not agent.get('cpu_only', False)
        self.found_cuda = False
        if self.active:
            try:
                agent['cuda_version'], agent['cudnn_version'] = self.get_cuda_version(self.config)
                self.found_cuda = True
            except Exception:
                # if we have a cuda version, it is good enough (we dont have to have cudnn version)
                if agent.get('cuda_version'):
                    self.found_cuda = True
        # separate pip download cache per cuda version ('cu<version>') or 'cpu';
        # format() also works when the configured version is not a string
        pip_cache_dir = Path(self.config["agent.pip_download_cache.path"]).expanduser() / (
            'cu{}'.format(agent['cuda_version']) if self.found_cuda else 'cpu')
        self.translator = RequirementsTranslator(session, interpreter=base_interpreter,
                                                 cache_dir=pip_cache_dir.as_posix())

    def register(self, cls):  # type: (Type[RequirementSubstitution]) -> None
        """Instantiate the substitution class with this session and add it to the handlers."""
        self.handlers.append(cls(self._session))

    def _replace_one(self, req):  # type: (MarkerRequirement) -> Optional[Text]
        """
        :return: the replacement line from the first matching handler, or None to keep the original
        """
        match = re.search(r';\s*(.*)', Text(req))
        if match:
            req.markers = match.group(1).split(',')
        if not self.active:
            return None
        for handler in self.handlers:
            if handler.match(req):
                return handler.replace(req)
        return None

    def replace(self, requirements):  # type: (Text) -> Text
        """
        Apply the substitution handlers to every requirement line.

        :param requirements: requirements text (split into lines) or an iterable of lines
        :return: the rewritten requirements joined with join_lines
            (the input itself is returned when there are no lines)
        :raises FatalSpecsResolutionError: re-raised from a handler; other handler errors keep the original line
        """
        def safe_parse(req_str):
            try:
                return next(parse(req_str))
            except Exception:
                # unparsable (or empty) line: keep it as a raw Requirement
                return Requirement(req_str)

        parsed_requirements = tuple(
            map(
                MarkerRequirement,
                [safe_parse(line) for line in (requirements.splitlines()
                                               if isinstance(requirements, six.text_type) else requirements)]
            )
        )
        if not parsed_requirements:
            # return the original requirements just in case
            return requirements

        def replace_one(i, req):
            # type: (int, MarkerRequirement) -> Optional[Text]
            try:
                return self._replace_one(req)
            except FatalSpecsResolutionError:
                warning('could not resolve python wheel replacement for {}'.format(req))
                raise
            except Exception:
                warning('could not resolve python wheel replacement for \"{}\", '
                        'using original requirements line: {}'.format(req, i))
                return None

        new_requirements = tuple(replace_one(i, req) for i, req in enumerate(parsed_requirements))
        conda = is_conda(self.config)
        # lines without a replacement fall back to the original (markers dropped for conda)
        result = map(
            lambda x, y: (x if x is not None else y.tostr(markers=not conda)),
            new_requirements,
            parsed_requirements
        )
        if not conda:
            result = map(self.translator.translate, result)
        return join_lines(result)

    def post_install(self, session):
        """Run every handler's post_install hook; the first failure is printed and re-raised."""
        for h in self.handlers:
            try:
                h.post_install(session)
            except Exception as ex:
                print('RequirementsManager handler {} raised exception: {}'.format(h, ex))
                raise

    def replace_back(self, requirements):
        """Pass ``requirements`` through every handler's replace_back (best effort: failures are skipped)."""
        for h in self.handlers:
            try:
                requirements = h.replace_back(requirements)
            except Exception:
                # deliberately best-effort: a failing handler leaves the requirements unchanged
                pass
        return requirements

    @staticmethod
    def get_cuda_version(config):  # type: (ConfigTree) -> (Text, Text)
        """
        Detect the CUDA and cuDNN versions.

        Values already set in agent.cuda_version / agent.cudnn_version are used when both exist.
        Otherwise the CUDA version is taken (in order) from Windows CUDA_PATH_V* environment
        variables, ``nvcc --version`` and ``nvidia-smi``, and the cuDNN version from the cudnn.h
        header located relative to ``nvcc``. Every probe failure is ignored.

        :return: (cuda_version, cudnn_version), each passed through normalize_cuda_version
            (0 is used for a version that was not found)
        """
        # we assume os.environ already updated the config['agent.cuda_version'] & config['agent.cudnn_version']
        cuda_version = config['agent.cuda_version']
        cudnn_version = config['agent.cudnn_version']
        if cuda_version and cudnn_version:
            return normalize_cuda_version(cuda_version), normalize_cuda_version(cudnn_version)

        if not cuda_version and is_windows_platform():
            try:
                # e.g. CUDA_PATH_V10_1 -> 101
                cuda_vers = [int(k.replace('CUDA_PATH_V', '').replace('_', '')) for k in os.environ.keys()
                             if k.startswith('CUDA_PATH_V')]
                cuda_vers = max(cuda_vers)
                if cuda_vers > 40:
                    cuda_version = cuda_vers
            except Exception:
                pass

        if not cuda_version:
            try:
                try:
                    nvcc = 'nvcc.exe' if is_windows_platform() else 'nvcc'
                    if is_windows_platform() and 'CUDA_PATH' in os.environ:
                        # NOTE(review): nvcc.exe may live under %CUDA_PATH%\bin - confirm this path
                        nvcc = os.path.join(os.environ['CUDA_PATH'], nvcc)

                    output = Argv(nvcc, '--version').get_output()
                except OSError:
                    raise CudaNotFound('nvcc not found')
                # e.g. "release 10.1, V10.1.243" -> '101' (integer math, two digit majors supported)
                match = re.search(r'release (\d+)\.(\d+)', output)
                cuda_version = Text(int(match.group(1)) * 10 + int(match.group(2)))
            except Exception:
                pass

        if not cuda_version:
            try:
                try:
                    output = Argv('nvidia-smi',).get_output()
                except OSError:
                    raise CudaNotFound('nvidia-smi not found')
                # e.g. "CUDA Version: 11.2" -> '112'
                match = re.search(r'CUDA Version: ([0-9]+)\.([0-9]+)', output)
                cuda_version = Text(int(match.group(1)) * 10 + int(match.group(2)))
            except Exception:
                pass

        if not cudnn_version:
            try:
                cuda_lib = which('nvcc')
                # <cuda root>/bin/nvcc -> <cuda root>/include/cudnn.h
                # NOTE(review): newer cuDNN releases may define the version in cudnn_version.h instead
                cudnn_h = path.join(path.dirname(path.dirname(cuda_lib)), 'include', 'cudnn.h')

                cudnn_major, cudnn_minor = None, None
                try:
                    include_file = open(cudnn_h)
                except (IOError, OSError):
                    raise CudaNotFound('Could not read cudnn.h')
                with include_file:
                    for line in include_file:
                        if 'CUDNN_MAJOR' in line:
                            cudnn_major = line.split()[-1]
                        if 'CUDNN_MINOR' in line:
                            cudnn_minor = line.split()[-1]
                        if cudnn_major and cudnn_minor:
                            break
                cudnn_version = cudnn_major + (cudnn_minor or '0')
            except Exception:
                pass

        return (normalize_cuda_version(cuda_version or 0),
                normalize_cuda_version(cudnn_version or 0))