from __future__ import absolute_import, unicode_literals

import operator
import os
import re
from abc import ABCMeta, abstractmethod
from copy import deepcopy, copy
from itertools import chain, starmap
from operator import itemgetter
from os import path
from typing import Text, List, Type, Optional, Tuple, Dict

from pathlib2 import Path
from pyhocon import ConfigTree
import six

from clearml_agent.definitions import PIP_EXTRA_INDICES
from clearml_agent.helper.base import warning, is_conda, which, join_lines, is_windows_platform
from clearml_agent.helper.process import Argv, PathLike
from clearml_agent.session import Session, normalize_cuda_version
from clearml_agent.external.requirements_parser import parse
from clearml_agent.external.requirements_parser.requirement import Requirement

from .translator import RequirementsTranslator


class SpecsResolutionError(Exception):
    """Raised when a requirement's version specifiers cannot be reconciled (see MarkerRequirement.fix_specs)."""
    pass


class FatalSpecsResolutionError(Exception):
    """
    Raised by a substitution handler when the failure must abort the whole replacement:
    RequirementsManager.replace re-raises it instead of falling back to the original line.
    """
    pass


@six.python_2_unicode_compatible
class MarkerRequirement(object):
    """
    Wrapper around a parsed ``Requirement`` adding environment-marker handling and
    version-specifier helpers. Attributes not defined here are delegated to the wrapped
    requirement object (see ``__getattr__``).
    """

    # if True pip version above 20.x and with support for "package @ scheme://link"
    # default is True
    pip_new_version = True

    def __init__(self, req):
        # type: (Requirement) -> None
        self.req = req

    @property
    def marker(self):
        """Return the text following ';' in the original requirement line, or None if there is none."""
        match = re.search(r';\s*(.*)', self.req.line)
        if match:
            return match.group(1)
        return None

    def tostr(self, markers=True):
        """
        Render the requirement as a pip-style line.

        * no URI: ``name[extras]<specs>`` (falls back to the raw line if there is no name)
        * VCS URI: the original line if available, otherwise ``uri@revision#subdirectory=...``
        * local file URI (with ``pip_new_version``): the original line as-is
        * any other URI: the URI itself

        :param bool markers: when True, append ``; <marker>`` if the original line had one
        :return: the rendered requirement line
        """
        if not self.uri:
            parts = [self.name or self.line]
            if self.extras:
                parts.append('[{0}]'.format(','.join(sorted(self.extras))))
            if self.specifier:
                parts.append(self.format_specs())
        elif self.vcs:
            # leave the line as is, let pip handle it
            if self.line:
                return self.line
            else:
                # let's build the line manually
                parts = [
                    self.uri,
                    '@{}'.format(self.revision) if self.revision else '',
                    '#subdirectory={}'.format(self.subdirectory) if self.subdirectory else ''
                ]
        elif self.pip_new_version and self.uri and self.name and self.line and self.local_file:
            # package @ file:///example.com/somewheel.whl
            # leave the line as is, let pip handle it
            return self.line
        else:
            parts = [self.uri]

        if markers and self.marker:
            parts.append('; {0}'.format(self.marker))

        return ''.join(parts)

    def clone(self):
        """Return a new MarkerRequirement wrapping a shallow copy of the underlying requirement."""
        return MarkerRequirement(copy(self.req))

    __str__ = tostr

    def __repr__(self):
        return '{self.__class__.__name__}[{self}]'.format(self=self)

    def format_specs(self, num_parts=None, max_num_parts=None):
        """
        Render the version specifiers.

        Without ``num_parts``/``max_num_parts`` all specifiers are joined with ','.
        Otherwise only the first specifier is rendered: the sub-version markers listed in
        ``SimpleVersion._sub_versions_pep440`` are replaced by '.', the version is padded with
        '0' parts (``num_parts`` of them) and truncated to ``max_num_parts`` parts.

        :param int num_parts: number of zero parts to pad with (also the default for max_num_parts)
        :param int max_num_parts: maximum number of version parts to keep
        :return: the formatted specifier string
        """
        max_num_parts = max_num_parts or num_parts
        if max_num_parts is None or not self.specs:
            return ','.join(starmap(operator.add, self.specs))

        op, version = self.specs[0]
        # the marker list lives on SimpleVersion; MarkerRequirement itself never defined it, so
        # looking it up on self fell through __getattr__ to the wrapped requirement and failed
        for v in SimpleVersion._sub_versions_pep440:
            version = version.replace(v, '.')
        if num_parts:
            version = (version.strip('.').split('.') + ['0'] * num_parts)[:max_num_parts]
        else:
            version = version.strip('.').split('.')[:max_num_parts]
        return op + '.'.join(version)

    def __getattr__(self, item):
        # Only invoked when normal lookup fails. Guard 'req' itself: on an instance created without
        # __init__ (e.g. by copy.copy) looking up self.req would otherwise re-enter __getattr__ forever.
        if item == 'req':
            raise AttributeError(item)
        return getattr(self.req, item)

    @property
    def specs(self):
        # type: () -> List[Tuple[Text, Text]]
        return self.req.specs

    @specs.setter
    def specs(self, value):
        # type: (List[Tuple[Text, Text]]) -> None
        self.req.specs = value

    def fix_specs(self):
        """
        Collapse the version specifiers into a minimal set, in place.

        If any "==" specifiers exist they are kept (they must all name the same version).
        Otherwise the highest "<=" and the lowest ">=" specifier are kept (versions compared
        as plain strings); specifiers with any other operator are dropped.

        :raises SpecsResolutionError: if the "==" specifiers name different versions
        """
        def solve_by(func, op_is, specs):
            matching = [(op, version) for op, version in specs if op == op_is]
            # no specifier with this operator -> nothing to resolve (max()/min() of [] would raise)
            return func(matching) if matching else []

        def solve_equal(specs):
            # only the "==" specifiers passed in must agree, not every specifier of the requirement
            if len(set(version for _, version in specs)) > 1:
                raise SpecsResolutionError('more than one "==" spec: {}'.format(specs))
            return specs

        greater = solve_by(lambda specs: [max(specs, key=itemgetter(1))], '<=', self.specs)
        smaller = solve_by(lambda specs: [min(specs, key=itemgetter(1))], '>=', self.specs)
        equal = solve_by(solve_equal, '==', self.specs)
        if equal:
            self.specs = equal
        else:
            self.specs = greater + smaller

    def compare_version(self, requested_version, op=None, num_parts=3):
        """
        Compare the requested version with the version of the first specifier,
        i.e. return ``requested_version <op> spec_version``.
        If the requested version is 1.2.3 the self.spec should be 1.2.3*
        If the requested version is 1.2 the self.spec should be 1.2*
        etc.

        :param str requested_version:
        :param str op: '==', '>', '>=', '<=', '<', '~=' (defaults to the first specifier's operator)
        :param int num_parts: number of parts to compare
        :return: True if we answer the requested version (also True when there are no specifiers)
        """
        # if we have no specific version, we cannot compare, so assume it's okay
        if not self.specs:
            return True

        version = self.specs[0][1]
        op = (op or self.specs[0][0]).strip()

        return SimpleVersion.compare_versions(
            version_a=requested_version, op=op, version_b=version, num_parts=num_parts)


class SimpleVersion(object):
    """
    Minimal PEP 440 style version parsing and comparison with no external dependencies.
    The regular expression below appears to be adapted from ``packaging.version``.
    """

    _sub_versions_pep440 = ['a', 'b', 'rc', '.post', '.dev', '+', ]
    VERSION_PATTERN = r"""
        v?
        (?:
            (?:(?P<epoch>[0-9]+)!)?                           # epoch
            (?P<release>[0-9]+(?:\.[0-9]+)*)                  # release segment
            (?P<pre>                                          # pre-release
                [-_\.]?
                (?P<pre_l>(a|b|c|rc|alpha|beta|pre|preview))
                [-_\.]?
                (?P<pre_n>[0-9]+)?
            )?
            (?P<post>                                         # post release
                (?:-(?P<post_n1>[0-9]+))
                |
                (?:
                    [-_\.]?
                    (?P<post_l>post|rev|r)
                    [-_\.]?
                    (?P<post_n2>[0-9]+)?
                )
            )?
            (?P<dev>                                          # dev release
                [-_\.]?
                (?P<dev_l>dev)
                [-_\.]?
                (?P<dev_n>[0-9]+)?
            )?
        )
        (?:\+(?P<local>[a-z0-9]+(?:[-_\.][a-z0-9]+)*))?       # local version
    """
    _local_version_separators = re.compile(r"[\._-]")
    _regex = re.compile(r"^\s*" + VERSION_PATTERN + r"\s*$", re.VERBOSE | re.IGNORECASE)

    # comparison operator -> function applied to the two comparison keys
    _comparison_ops = {
        '==': operator.eq,
        '<=': operator.le,
        '>=': operator.ge,
        '>': operator.gt,
        '<': operator.lt,
    }

    @classmethod
    def compare_versions(cls, version_a, op, version_b, ignore_sub_versions=True, num_parts=3):
        """
        Compare two versions based on the op operator
        returns bool(version_a op version_b)
        Notice: Ignores a/b/rc/post/dev markers on the version

        :param str version_a:
        :param str op: '==', '===', '>', '>=', '<=', '<', '~='
        :param str version_b:
        :param bool ignore_sub_versions: if true compare only major.minor.patch
            (ignore a/b/rc/post/dev in the comparison)
        :param int num_parts: number of parts to compare, split by . (dot);
            shorter release segments are zero padded, so 1.2 == 1.2.0
        :return bool: version_a op version_b (True if version_b is empty)
        :raises ValueError: on an unrecognized comparison operator
        """
        if not version_b:
            return True

        if op == '~=':
            num_parts = max(num_parts, 2)
            op = '=='
            ignore_sub_versions = True
        elif op == '===':
            op = '=='

        try:
            version_a_key = cls._get_match_key(cls._regex.search(version_a), num_parts, ignore_sub_versions)
            version_b_key = cls._get_match_key(cls._regex.search(version_b), num_parts, ignore_sub_versions)
        except Exception:
            # e.g. the regex did not match (search() returned None): revert to string based
            for v in cls._sub_versions_pep440:
                version_a = version_a.replace(v, '.')
                version_b = version_b.replace(v, '.')

            version_a = (version_a.strip('.').split('.') + ['0'] * num_parts)[:num_parts]
            version_b = (version_b.strip('.').split('.') + ['0'] * num_parts)[:num_parts]
            version_a_key = ''
            version_b_key = ''
            for i in range(num_parts):
                # left-pad every part with zeros to a common width so string order == numeric order
                pad = '{:0>%d}.' % max([9, 1 + len(version_a[i]), 1 + len(version_b[i])])
                version_a_key += pad.format(version_a[i])
                version_b_key += pad.format(version_b[i])

        compare = cls._comparison_ops.get(op)
        if compare is None:
            raise ValueError('Unrecognized comparison operator [{}]'.format(op))
        return compare(version_a_key, version_b_key)

    @staticmethod
    def _parse_letter_version(
        letter,  # type: Optional[Text]
        number,  # type: Optional[Text]
    ):
        # type: (...) -> Tuple
        """
        Normalize a (letter, number) pre/post/dev pair, e.g. ("alpha", None) -> ("a", 0).
        A number without a letter is treated as an implicit post release.
        Returns an empty tuple when there is neither.
        """
        if letter:
            # We consider there to be an implicit 0 in a pre-release if there is
            # not a numeral associated with it.
            if number is None:
                number = 0

            # We normalize any letters to their lower case form
            letter = letter.lower()

            # We consider some words to be alternate spellings of other words and
            # in those cases we want to normalize the spellings to our preferred
            # spelling.
            if letter == "alpha":
                letter = "a"
            elif letter == "beta":
                letter = "b"
            elif letter in ["c", "pre", "preview"]:
                letter = "rc"
            elif letter in ["rev", "r"]:
                letter = "post"

            return letter, int(number)

        if not letter and number:
            # We assume if we are given a number, but we are not given a letter
            # then this is using the implicit post release syntax (e.g. 1.0-1)
            letter = "post"

            return letter, int(number)

        return ()

    @staticmethod
    def _get_match_key(match, num_parts, ignore_sub_versions):
        """
        Build a comparable tuple from a VERSION_PATTERN match:
        (epoch, release, pre, post, dev, local).

        The release segment is zero padded and truncated to ``num_parts`` parts (unless
        ``num_parts`` is None). With ``ignore_sub_versions`` every element other than the
        release is neutral. Raises AttributeError when ``match`` is None.
        """
        release = [int(i) for i in match.group("release").split(".")]
        if num_parts is not None:
            # zero pad before truncating (PEP 440: 1.2 == 1.2.0), consistent with the
            # string-based fallback in compare_versions
            release = (release + [0] * num_parts)[:num_parts]
        release = tuple(release)

        if ignore_sub_versions:
            return 0, release, (), (), (), ()

        return (
            int(match.group("epoch")) if match.group("epoch") else 0,
            release,
            SimpleVersion._parse_letter_version(match.group("pre_l"), match.group("pre_n")),
            SimpleVersion._parse_letter_version(
                match.group("post_l"), match.group("post_n1") or match.group("post_n2")
            ),
            SimpleVersion._parse_letter_version(match.group("dev_l"), match.group("dev_n")),
            SimpleVersion._parse_local_version(match.group("local")),
        )

    @staticmethod
    def _parse_local_version(local):
        # type: (Optional[Text]) -> Tuple
        """
        Takes a string like abc.1.twelve and turns it into ("abc", 1, "twelve").
        Returns an empty tuple when ``local`` is None.
        """
        if local is not None:
            return tuple(
                part.lower() if not part.isdigit() else int(part)
                for part in SimpleVersion._local_version_separators.split(local)
            )
        return ()


@six.add_metaclass(ABCMeta)
class RequirementSubstitution(object):
    """
    Base class for handlers that rewrite individual requirement lines.
    Subclasses implement ``match`` (should this handler touch the line?) and ``replace``.
    """

    _pip_extra_index_url = PIP_EXTRA_INDICES

    def __init__(self, session):
        # type: (Session) -> ()
        self._session = session
        self.config = session.config  # type: ConfigTree
        # version suffix of the form ".post<cuda_version>.dev<cudnn_version>"
        self.suffix = '.post{config[agent.cuda_version]}.dev{config[agent.cudnn_version]}'.format(config=self.config)
        self.package_manager = self.config['agent.package_manager.type']

    @abstractmethod
    def match(self, req):
        # type: (MarkerRequirement) -> bool
        """
        Returns whether a requirement needs to be modified by this substitution.
        """
        pass

    @abstractmethod
    def replace(self, req):
        # type: (MarkerRequirement) -> Text
        """
        Replace a requirement
        """
        pass

    def post_scan_add_req(self):
        # type: () -> Optional[MarkerRequirement]
        """
        Allows the RequirementSubstitution to add an extra line/requirements after
        the initial requirements scan is completed.
        Called only once per requirements.txt object
        """
        return None

    def post_install(self, session):
        """Hook called by RequirementsManager.post_install; no-op by default."""
        pass

    @classmethod
    def get_pip_version(cls, package):
        """
        Look up a version of ``package`` by parsing ``pip search`` output (extra indices included).

        NOTE(review): ``pip search`` depends on the index supporting the search API;
        verify it still works against the configured indices.

        :raises ValueError: if no "<package> (x.y.z" entry is found in the output
        """
        output = Argv(
            'pip',
            'search',
            package,
            *(chain.from_iterable(('-i', x) for x in cls._pip_extra_index_url))
        ).get_output()
        # ad-hoc pattern to duplicate the behavior of the old code
        # (the package name is escaped so characters such as '.' match literally)
        match = re.search(r'{} \((\d+\.\d+\.[^.]+)'.format(re.escape(package)), output)
        if not match:
            raise ValueError('could not find version of "{}" in pip search output'.format(package))
        return match.group(1)

    @property
    def cuda_version(self):
        return self.config['agent.cuda_version']

    @property
    def cudnn_version(self):
        return self.config['agent.cudnn_version']


class SimpleSubstitution(RequirementSubstitution):
    """
    Substitution keyed on a single package ``name``: matches by name (or an http(s) URI
    containing the name) and appends ``self.suffix`` to the version.
    """

    @property
    @abstractmethod
    def name(self):
        pass

    def match(self, req):
        # type: (MarkerRequirement) -> bool
        return (self.name == req.name or (
            req.uri and
            re.match(r'https?://', req.uri) and
            self.name in req.uri
        ))

    def replace(self, req):
        # type: (MarkerRequirement) -> Text
        """
        Replace a requirement.

        For a URI requirement, insert ``self.suffix`` into the URI before the first "-cp"
        following the package name. Otherwise pin "==<version><suffix>", taking the version
        from the first specifier or, when there is none, from ``get_pip_version``.
        """
        if req.uri:
            return re.sub(
                r'({})(.*?)(-cp)'.format(re.escape(self.name)),
                r'\1\2{}\3'.format(self.suffix),
                req.uri,
                count=1)

        if req.specs:
            _, version_number = req.specs[0]
        else:
            version_number = self.get_pip_version(self.name)

        req.specs = [('==', version_number + self.suffix)]
        return Text(req)

    def replace_back(self, list_of_requirements):
        # type: (Dict) -> Dict
        """
        :param list_of_requirements: {'pip': ['a==1.0', ]}
        :return: {'pip': ['a==1.0', ]}
        """
        return list_of_requirements


@six.add_metaclass(ABCMeta)
class CudaSensitiveSubstitution(SimpleSubstitution):
    """SimpleSubstitution that only matches when both CUDA and cuDNN versions are configured."""

    def match(self, req):
        # type: (MarkerRequirement) -> bool
        return bool(self.cuda_version and self.cudnn_version and
                    super(CudaSensitiveSubstitution, self).match(req))


class CudaNotFound(Exception):
    """Raised internally when a CUDA detection tool (nvcc / nvidia-smi) cannot be executed."""
    pass


class RequirementsManager(object):
    """
    Rewrites a requirements listing by running every registered RequirementSubstitution over
    each line, then (for non-conda setups) through a RequirementsTranslator.
    """

    def __init__(self, session, base_interpreter=None):
        # type: (Session, PathLike) -> ()
        self._session = session
        self.config = deepcopy(session.config)  # type: ConfigTree
        self.handlers = []  # type: List[RequirementSubstitution]
        agent = self.config['agent']
        self.active = not agent.get('cpu_only', False)
        self.found_cuda = False
        if self.active:
            try:
                agent['cuda_version'], agent['cudnn_version'] = self.get_cuda_version(self.config)
                self.found_cuda = True
            except Exception:
                # if we have a cuda version, it is good enough (we dont have to have cudnn version)
                if agent.get('cuda_version'):
                    self.found_cuda = True
        # format() rather than '+': a cuda_version coming straight from the config may be an int
        pip_cache_dir = Path(self.config["agent.pip_download_cache.path"]).expanduser() / (
            'cu{}'.format(agent['cuda_version']) if self.found_cuda else 'cpu')
        self.translator = RequirementsTranslator(session, interpreter=base_interpreter,
                                                 cache_dir=pip_cache_dir.as_posix())

    def register(self, cls):
        # type: (Type[RequirementSubstitution]) -> None
        """Instantiate ``cls`` with the session and add it to the handler list."""
        self.handlers.append(cls(self._session))

    def _replace_one(self, req):
        # type: (MarkerRequirement) -> Optional[Text]
        """Return the replacement from the first matching handler, or None if none applies / inactive."""
        match = re.search(r';\s*(.*)', Text(req))
        if match:
            req.markers = match.group(1).split(',')
        if not self.active:
            return None
        for handler in self.handlers:
            if handler.match(req):
                return handler.replace(req)
        return None

    def replace(self, requirements):
        # type: (Text) -> Text
        """
        Apply the registered handlers to each requirement.

        Lines no handler replaces are re-rendered via ``MarkerRequirement.tostr`` (markers
        omitted for conda). For non-conda setups every line then goes through the translator.
        Finally each handler's ``post_scan_add_req`` result, if any, is appended.

        :param requirements: text with one requirement per line, or an iterable of lines
        :return: the joined requirement lines (the input itself if it contained nothing)
        :raises FatalSpecsResolutionError: re-raised from a handler
        """
        def safe_parse(req_str):
            try:
                return next(parse(req_str))
            except Exception:
                # unparsable (or empty) line: keep it as a raw Requirement
                return Requirement(req_str)

        parsed_requirements = tuple(
            map(
                MarkerRequirement,
                [safe_parse(line) for line in (requirements.splitlines()
                                               if isinstance(requirements, six.text_type) else requirements)]
            )
        )
        if not parsed_requirements:
            # return the original requirements just in case
            return requirements

        def replace_one(i, req):
            # type: (int, MarkerRequirement) -> Optional[Text]
            try:
                return self._replace_one(req)
            except FatalSpecsResolutionError:
                warning('could not resolve python wheel replacement for {}'.format(req))
                raise
            except Exception:
                warning('could not resolve python wheel replacement for \"{}\", '
                        'using original requirements line: {}'.format(req, i))
                return None

        new_requirements = tuple(replace_one(i, req) for i, req in enumerate(parsed_requirements))
        conda = is_conda(self.config)
        result = map(
            lambda x, y: (x if x is not None else y.tostr(markers=not conda)),
            new_requirements,
            parsed_requirements
        )
        if not conda:
            result = map(self.translator.translate, result)

        result = list(result)
        # add post scan add requirements call back
        for h in self.handlers:
            req = h.post_scan_add_req()
            if req:
                result.append(req.tostr())

        return join_lines(result)

    def post_install(self, session):
        """Run every handler's post_install hook; a failing hook is reported and re-raised."""
        for h in self.handlers:
            try:
                h.post_install(session)
            except Exception as ex:
                print('RequirementsManager handler {} raised exception: {}'.format(h, ex))
                raise

    def replace_back(self, requirements):
        """Undo replacements via the translator and each handler; handler failures are ignored."""
        if self.translator:
            requirements = self.translator.replace_back(requirements)

        for h in self.handlers:
            try:
                requirements = h.replace_back(requirements)
            except Exception:
                pass
        return requirements

    @staticmethod
    def _cuda_version_from_match(match):
        # type: (Optional[re.Match]) -> Optional[Text]
        """Turn a (major, minor) regex match into the compact form, e.g. ("11", "2") -> "112"."""
        if not match:
            return None
        return Text(int(match.group(1)) * 10 + int(match.group(2)))

    @staticmethod
    def _detect_cuda_from_windows_env():
        # type: () -> Optional[int]
        """Return the highest version encoded in CUDA_PATH_V<major>_<minor> env vars (e.g. 112), or None."""
        versions = [int(k.replace('CUDA_PATH_V', '').replace('_', ''))
                    for k in os.environ.keys() if k.startswith('CUDA_PATH_V')]
        if not versions:
            return None
        best = max(versions)
        return best if best > 40 else None

    @staticmethod
    def _detect_cuda_from_nvcc():
        # type: () -> Optional[Text]
        """
        Parse "release <major>.<minor>" from ``nvcc --version``.

        :raises CudaNotFound: if nvcc cannot be executed
        """
        nvcc = 'nvcc.exe' if is_windows_platform() else 'nvcc'
        if is_windows_platform() and 'CUDA_PATH' in os.environ:
            # CUDA_PATH is expected to be the toolkit root with nvcc under "bin";
            # keep the previous "<CUDA_PATH>/nvcc.exe" location as a fallback
            candidate = os.path.join(os.environ['CUDA_PATH'], 'bin', nvcc)
            nvcc = candidate if os.path.isfile(candidate) else os.path.join(os.environ['CUDA_PATH'], nvcc)
        try:
            output = Argv(nvcc, '--version').get_output()
        except OSError:
            raise CudaNotFound('nvcc not found')
        # capture major and minor separately: a fixed 3-char slice turned "11.2" into "11."
        return RequirementsManager._cuda_version_from_match(re.search(r'release (\d+)\.(\d+)', output))

    @staticmethod
    def _detect_cuda_from_nvidia_smi():
        # type: () -> Optional[Text]
        """
        Parse "CUDA Version: <major>.<minor>" from ``nvidia-smi`` output.

        :raises CudaNotFound: if nvidia-smi cannot be executed
        """
        try:
            output = Argv('nvidia-smi',).get_output()
        except OSError:
            raise CudaNotFound('nvidia-smi not found')
        return RequirementsManager._cuda_version_from_match(
            re.search(r'CUDA Version: ([0-9]+)\.([0-9]+)', output))

    @staticmethod
    def _detect_cudnn_from_headers():
        # type: () -> Optional[Text]
        """
        Read CUDNN_MAJOR / CUDNN_MINOR from the cuDNN headers in the "include" folder next to the
        folder containing nvcc. Returns e.g. "76" for 7.6 (minor defaults to "0"), or None.
        """
        nvcc_path = which('nvcc')
        if not nvcc_path:
            return None
        include_dir = path.join(path.dirname(path.dirname(nvcc_path)), 'include')
        define = re.compile(r'^\s*#\s*define\s+CUDNN_(MAJOR|MINOR)\s+(\d+)')
        # cuDNN 8+ keeps the version macros in cudnn_version.h, older releases in cudnn.h
        for header in ('cudnn_version.h', 'cudnn.h'):
            found = {}
            try:
                with open(path.join(include_dir, header)) as include_file:
                    for line in include_file:
                        match = define.match(line)
                        if match:
                            found.setdefault(match.group(1), match.group(2))
                        if 'MAJOR' in found and 'MINOR' in found:
                            break
            except (IOError, OSError):
                continue
            if found.get('MAJOR'):
                return found['MAJOR'] + found.get('MINOR', '0')
        return None

    @staticmethod
    def get_cuda_version(config):
        # type: (ConfigTree) -> (Text, Text)
        """
        Determine the CUDA and cuDNN versions.

        Configured values are used when both are set. Otherwise missing values are probed
        (best-effort, failures ignored): CUDA from CUDA_PATH_V* env vars (Windows only),
        ``nvcc --version`` and ``nvidia-smi``; cuDNN from the headers next to nvcc.

        :return: (cuda_version, cudnn_version) passed through normalize_cuda_version;
            an undetected value is normalized from 0
        """
        # we assume os.environ already updated the config['agent.cuda_version'] & config['agent.cudnn_version']
        cuda_version = config['agent.cuda_version']
        cudnn_version = config['agent.cudnn_version']
        if cuda_version and cudnn_version:
            return normalize_cuda_version(cuda_version), normalize_cuda_version(cudnn_version)

        if not cuda_version:
            probes = []
            if is_windows_platform():
                probes.append(RequirementsManager._detect_cuda_from_windows_env)
            probes.extend((RequirementsManager._detect_cuda_from_nvcc,
                           RequirementsManager._detect_cuda_from_nvidia_smi))
            for probe in probes:
                try:
                    detected = probe()
                except Exception:
                    # detection is best-effort: a failing probe just moves on to the next one
                    continue
                if detected:
                    cuda_version = detected
                    break

        if not cudnn_version:
            try:
                cudnn_version = RequirementsManager._detect_cudnn_from_headers() or cudnn_version
            except Exception:
                # best-effort: keep whatever the configuration provided
                pass

        return (normalize_cuda_version(cuda_version or 0),
                normalize_cuda_version(cudnn_version or 0))