# Mirror-page residue (not part of the module): mirrored copy,
# synced 2025-02-12 07:38:04 +00:00 - 815 lines, 30 KiB
from __future__ import absolute_import, unicode_literals
import operator
import os
import re
from abc import ABCMeta, abstractmethod
from copy import deepcopy, copy
from itertools import chain, starmap
from operator import itemgetter
from os import path
from typing import Text, List, Type, Optional, Tuple, Dict
from pathlib2 import Path
from clearml_agent.external.pyhocon import ConfigTree
import six
from six.moves.urllib.parse import unquote
import logging
from clearml_agent.definitions import PIP_EXTRA_INDICES
from clearml_agent.helper.base import (
warning, is_conda, which, join_lines, is_windows_platform,
convert_cuda_version_to_int_10_base_str, )
from clearml_agent.helper.process import Argv, PathLike
from clearml_agent.helper.gpu.gpustat import get_driver_cuda_version
from clearml_agent.session import Session, normalize_cuda_version
from clearml_agent.external.requirements_parser import parse
from clearml_agent.external.requirements_parser.requirement import Requirement
from .translator import RequirementsTranslator
class SpecsResolutionError(Exception):
    """
    Raised when the version specs of a requirement cannot be reconciled,
    e.g. by MarkerRequirement.fix_specs() for conflicting "==" specs.
    """
class FatalSpecsResolutionError(Exception):
    """
    Specs resolution failure that RequirementsManager.replace() handles in a
    dedicated except clause, ahead of its generic Exception handler.
    """
class MarkerRequirement(object):
    """Wrapper around a parsed Requirement object, which is kept in ``self.req``."""

    # True when pip is newer than 20.x and supports "package @ scheme://link" lines
    # (True is the default)
    pip_new_version = True

    def __init__(self, req):  # type: (Requirement) -> None
        """
        :param req: the parsed requirement to wrap
        """
        self.req = req
@property
def marker(self):
    """
    Text that follows the first ';' of the requirement line (leading whitespace
    dropped), or None when the line contains no ';'.

    Declared as a property: the rest of the class reads it as ``self.marker``.
    """
    found = re.search(r';\s*(.*)', self.req.line)
    return found.group(1) if found else None
# Render the requirement as a single requirement line (the class also uses this as __str__).
# :param markers: when True and a marker exists, '; <marker>' is appended to the line
# :return: the requirement line
# NOTE(review): this listing lost its indentation and, apparently, several lines of this
# method; the notes below flag the spots that look incomplete - restore from upstream.
def tostr(self, markers=True):
if not self.uri:
# plain requirement (no uri): start from the name, or from the raw line if there is no name
parts = [self.name or self.line]
# NOTE(review): no statement follows either of the next two `if` lines here
# (presumably they append the extras and the version specs) - confirm against upstream
if self.extras:
if self.specifier:
elif self.vcs:
# leave the line as is, let pip handle it
if self.line:
return self.line
# let's build the line manually
# NOTE(review): the first element(s) and the closing bracket of this list, and the
# `else:` that should introduce it, are not present in this listing
parts = [
'@{}'.format(self.revision) if self.revision else '',
'#subdirectory={}'.format(self.subdirectory) if self.subdirectory else ''
elif self.pip_new_version and self.uri and self.name and self.line and self.local_file:
# package @ file:///example.com/somewheel.whl
# leave the line as is, let pip handle it
return self.line
# NOTE(review): presumably the final `else:` branch (any other uri) - keyword not visible here
parts = [self.uri]
if markers and self.marker:
parts.append('; {0}'.format(self.marker))
return ''.join(parts)
def clone(self):
    """
    Return a new wrapper around a shallow copy of the wrapped requirement.

    :return: a new instance of the same class as ``self``
    """
    # type(self) instead of the hard-coded class name, so subclasses clone to themselves
    return type(self)(copy(self.req))
# str(requirement) is the rendered requirement line
__str__ = tostr


def __repr__(self):
    """Debug representation: the class name followed by the rendered line in brackets."""
    return '{self.__class__.__name__}[{self}]'.format(self=self)
def format_specs(self, num_parts=None, max_num_parts=None):
    """
    Render the version specs of this requirement as text.

    Without any part limit (or without specs) all specs are joined with ','
    (e.g. ">=1.0,<2.0"). Otherwise only the FIRST spec is rendered, with its
    version cut to ``max_num_parts`` dot separated parts.

    :param num_parts: when set, the version is first padded with '0' parts
    :param max_num_parts: maximum number of version parts to keep (defaults to num_parts)
    :return: the specs text
    """
    max_num_parts = max_num_parts or num_parts
    if max_num_parts is None or not self.specs:
        return ','.join(starmap(operator.add, self.specs))

    op, version = self.specs[0]
    # turn the a/b/rc/post/dev/local separators into plain dots before splitting
    # noinspection PyProtectedMember
    for v in SimpleVersion._sub_versions_pep440:
        version = version.replace(v, '.')

    parts = version.strip('.').split('.')
    if num_parts:
        # pad so that short versions still yield the requested number of parts
        parts = parts + ['0'] * num_parts
    # the two cases are exclusive: the unpadded form must not run on the padded list
    return op + '.'.join(parts[:max_num_parts])
def __getattr__(self, item):
    """
    Forward attribute lookups that failed on the wrapper to the wrapped requirement.

    :raises AttributeError: if the wrapped requirement itself is not set yet
    """
    # __getattr__ only runs after normal lookup failed. If 'req' is what is missing
    # (e.g. while copy/pickle rebuild the object without calling __init__), reading
    # self.req below would re-enter this method without end.
    if item == 'req':
        raise AttributeError(item)
    return getattr(self.req, item)
@property
def specs(self):  # type: () -> List[Tuple[Text, Text]]
    """
    Version specs of the wrapped requirement, as (operator, version) pairs.

    An explicit property (not just __getattr__ forwarding) so that assigning
    ``self.specs = ...`` writes through to the wrapped requirement.
    """
    return self.req.specs


@specs.setter
def specs(self, value):  # type: (List[Tuple[Text, Text]]) -> None
    self.req.specs = value
def fix_specs(self):
    """
    Reduce ``self.specs`` to a minimal list.

    If any "==" spec exists, only the "==" specs are kept. Otherwise one "<="
    spec (the largest version) and one ">=" spec (the smallest version) are
    kept, in that order; specs with other operators are dropped.

    :raises SpecsResolutionError: if the "==" specs name more than one version
    """
    def select(op_is):
        return [(op, version) for op, version in self.specs if op == op_is]

    upper, lower, equal = select('<='), select('>='), select('==')

    # only the "==" specs can conflict with each other; other operators must not
    # trigger this error
    if len(set(version for _, version in equal)) > 1:
        raise SpecsResolutionError('more than one "==" spec: {}'.format(equal))

    if equal:
        self.specs = equal
        return

    # NOTE(review): versions are compared as plain strings and the loosest bound is
    # kept, as before - confirm this is the intended resolution.
    # max()/min() raise on an empty list, so a missing bound yields no spec.
    greater = [max(upper, key=itemgetter(1))] if upper else []
    smaller = [min(lower, key=itemgetter(1))] if lower else []
    self.specs = greater + smaller
def compare_version(self, requested_version, op=None, num_parts=3):
    """
    Compare the requested version with the one we have in the spec
    (only the first spec is used).
    If the requested version is 1.2.3 the self.spec should be 1.2.3*
    If the requested version is 1.2 the self.spec should be 1.2*

    :param str requested_version: version text to test
    :param str op: '==', '>', '>=', '<=', '<', '~=' (defaults to the operator of the spec)
    :param int num_parts: number of parts to compare
    :return: True if we answer the requested version
    """
    # if we have no specific version, we cannot compare, so assume it's okay
    if not self.specs:
        return True

    first_spec = self.specs[0]
    spec_version = first_spec[1]
    chosen_op = (op or first_spec[0]).strip()

    return SimpleVersion.compare_versions(
        version_a=requested_version, op=chosen_op, version_b=spec_version, num_parts=num_parts)
def remove_local_file_ref(self):
    """
    Strip the "@ <uri>" local file reference from the requirement line.

    Nothing is changed for vcs / editable / path requirements, for requirements
    that are not local files, or when the reference is not found in the line.

    :return: True if the reference was removed, False otherwise
    """
    if not self.local_file or self.vcs or self.editable or self.path:
        return False

    uri = self.req.uri
    if not uri:
        return False

    # the uri is literal text: escape it so regex metacharacters that are common in
    # wheel file names ('+', '.', '(' ...) cannot break or distort the match
    parts = re.split(r"@\s*{}".format(re.escape(uri)), self.req.line)
    # if we did not find anything do nothing
    if len(parts) < 2:
        return False

    self.req.line = ''.join(parts).strip()
    self.req.uri = None
    self.req.local_file = False
    return True
# Drop the "@ file://..." reference of this requirement when the referenced file is
# missing on disk; the path is tried as written and again after URL-unquoting the uri.
# NOTE(review): this listing lost its indentation and, apparently, some lines of this
# method; the notes below flag the spots that look incomplete - restore from upstream.
def validate_local_file_ref(self):
# if local file does not exist, remove the reference to it
if self.vcs or self.editable or self.path or not self.local_file or not self.name or \
not self.uri or not self.uri.startswith("file://"):
# NOTE(review): no statement follows the guard above in this listing (presumably an
# early `return` for requirements that are not named "file://" references)
local_path = Path(self.uri[len("file://"):])
if not local_path.exists():
# second attempt: the uri may be percent-encoded
local_path = Path(unquote(self.uri)[len("file://"):])
if not local_path.exists():
# keep the original line for the message, remove_local_file_ref() rewrites it
line = self.line
if self.remove_local_file_ref():
# print warning
# NOTE(review): the call that receives the message below is missing from this listing
'Local file not found [{}], references removed'.format(line))
class SimpleVersion:
    """
    Small PEP 440 style version comparison helper.

    Versions are parsed with the PEP 440 reference pattern; text that does not
    match it is compared with a zero-padded, string based fallback.
    """
    _sub_versions_pep440 = ['a', 'b', 'rc', '.post', '.dev', '+', ]

    # The PEP 440 reference regular expression (used with re.VERBOSE | re.IGNORECASE).
    # The named groups are the ones read by _get_match_key().
    VERSION_PATTERN = r"""
        v?
        (?:
            (?:(?P<epoch>[0-9]+)!)?                           # epoch
            (?P<release>[0-9]+(?:\.[0-9]+)*)                  # release segment
            (?P<pre>                                          # pre-release
                [-_\.]?
                (?P<pre_l>(a|b|c|rc|alpha|beta|pre|preview))
                [-_\.]?
                (?P<pre_n>[0-9]+)?
            )?
            (?P<post>                                         # post release
                (?:-(?P<post_n1>[0-9]+))
                |
                (?:
                    [-_\.]?
                    (?P<post_l>post|rev|r)
                    [-_\.]?
                    (?P<post_n2>[0-9]+)?
                )
            )?
            (?P<dev>                                          # dev release
                [-_\.]?
                (?P<dev_l>dev)
                [-_\.]?
                (?P<dev_n>[0-9]+)?
            )?
        )
        (?:\+(?P<local>[a-z0-9]+(?:[-_\.][a-z0-9]+)*))?       # local version
    """

    _local_version_separators = re.compile(r"[\._-]")
    _regex = re.compile(r"^\s*" + VERSION_PATTERN + r"\s*$", re.VERBOSE | re.IGNORECASE)

    @classmethod
    def compare_versions(cls, version_a, op, version_b, ignore_sub_versions=True, num_parts=3):
        """
        Compare two versions based on the op operator
        returns bool(version_a op version_b)
        Notice: Ignores a/b/rc/post/dev markers on the version

        '~=' is evaluated as '==' on all but the last part of version_b (at least
        two parts); '===' is evaluated as '=='.

        :param str version_a:
        :param str op: '==', '===', '>', '>=', '<=', '<', '~=', '!='
        :param str version_b: an empty version_b always yields True
        :param bool ignore_sub_versions: if true compare only major.minor.patch
            (ignore a/b/rc/post/dev in the comparison)
        :param int num_parts: number of parts to compare, split by . (dot);
            None/0 means "as many parts as the longer version has"
        :return bool: version_a op version_b
        :raises ValueError: on an unrecognized operator
        """
        if not version_b:
            return True
        if not num_parts:
            num_parts = max(len(version_a.split('.')), len(version_b.split('.')), )

        if op == '~=':
            num_parts = len(version_b.split('.')) - 1
            num_parts = max(num_parts, 2)
            op = '=='
            ignore_sub_versions = True
        elif op == '===':
            op = '=='

        try:
            version_a_key = cls._get_match_key(cls._regex.search(version_a), num_parts, ignore_sub_versions)
            version_b_key = cls._get_match_key(cls._regex.search(version_b), num_parts, ignore_sub_versions)
        except Exception:
            # revert to string based (e.g. the text is not a PEP 440 version)
            for v in cls._sub_versions_pep440:
                version_a = version_a.replace(v, '.')
                version_b = version_b.replace(v, '.')

            version_a = (version_a.strip('.').split('.') + ['0'] * num_parts)[:num_parts]
            version_b = (version_b.strip('.').split('.') + ['0'] * num_parts)[:num_parts]
            version_a_key = ''
            version_b_key = ''
            for i in range(num_parts):
                # left-pad every part with zeros to a common width so that plain
                # string comparison orders the parts numerically
                pad = '{:0>%d}.' % max([9, 1 + len(version_a[i]), 1 + len(version_b[i])])
                version_a_key += pad.format(version_a[i])
                version_b_key += pad.format(version_b[i])

        if op == '==':
            return version_a_key == version_b_key
        if op == '<=':
            return version_a_key <= version_b_key
        if op == '>=':
            return version_a_key >= version_b_key
        if op == '>':
            return version_a_key > version_b_key
        if op == '<':
            return version_a_key < version_b_key
        if op == '!=':
            return version_a_key != version_b_key
        raise ValueError('Unrecognized comparison operator [{}]'.format(op))

    @classmethod
    def max_version(cls, version_a, version_b):
        """Return the larger of the two versions (version_a when they compare equal)."""
        return version_a if cls.compare_versions(
            version_a=version_a, op='>=', version_b=version_b, num_parts=None) else version_b

    @classmethod
    def min_version(cls, version_a, version_b):
        """Return the smaller of the two versions (version_a when they compare equal)."""
        return version_a if cls.compare_versions(
            version_a=version_a, op='<=', version_b=version_b, num_parts=None) else version_b

    @staticmethod
    def _parse_letter_version(
            letter,  # type: str
            number,  # type: Union[str, bytes, SupportsInt]
    ):
        # type: (...) -> Optional[Tuple[str, int]]
        """Normalize a (letter, number) version segment; returns () when both are missing."""
        if letter:
            # We consider there to be an implicit 0 in a pre-release if there is
            # not a numeral associated with it.
            if number is None:
                number = 0

            # We normalize any letters to their lower case form
            letter = letter.lower()

            # We consider some words to be alternate spellings of other words and
            # in those cases we want to normalize the spellings to our preferred
            # spelling.
            if letter == "alpha":
                letter = "a"
            elif letter == "beta":
                letter = "b"
            elif letter in ["c", "pre", "preview"]:
                letter = "rc"
            elif letter in ["rev", "r"]:
                letter = "post"

            return letter, int(number)
        if not letter and number:
            # We assume if we are given a number, but we are not given a letter
            # then this is using the implicit post release syntax (e.g. 1.0-1)
            letter = "post"

            return letter, int(number)

        return ()

    @staticmethod
    def _get_match_key(match, num_parts, ignore_sub_versions):
        """
        Build a comparable tuple from a ``_regex`` match.

        :raises AttributeError: if ``match`` is None (text is not a PEP 440 version);
            compare_versions() relies on this to switch to its string fallback
        """
        release = tuple(int(i) for i in match.group("release").split(".")[:num_parts])
        # pad with zeros so that "1.2" and "1.2.0" yield the same key (a shorter tuple
        # would otherwise compare as smaller than its zero-extended form)
        release += (0,) * (num_parts - len(release))

        if ignore_sub_versions:
            return (0, release, (), (), (), (),)
        # NOTE(review): an absent pre-release segment is () here, which sorts BEFORE
        # any pre-release tuple - pre-release ordering is therefore not PEP 440 exact
        return (
            int(match.group("epoch")) if match.group("epoch") else 0,
            release,
            SimpleVersion._parse_letter_version(match.group("pre_l"), match.group("pre_n")),
            SimpleVersion._parse_letter_version(
                match.group("post_l"), match.group("post_n1") or match.group("post_n2")
            ),
            SimpleVersion._parse_letter_version(match.group("dev_l"), match.group("dev_n")),
            SimpleVersion._parse_local_version(match.group("local")),
        )

    @staticmethod
    def _parse_local_version(local):
        # type: (str) -> Optional[LocalType]
        """
        Takes a string like abc.1.twelve and turns it into ("abc", 1, "twelve").
        """
        if local is not None:
            return tuple(
                part.lower() if not part.isdigit() else int(part)
                for part in SimpleVersion._local_version_separators.split(local)
            )
        return ()
def compare_version_rules(specs_a, specs_b):
    """
    Check whether two spec lists can be satisfied by a common version.

    Each spec list is reduced to one [left, right] version section; the two
    sections are intersected and the result is tested for being non-empty.

    :param specs_a: list of (operator, version) tuples,
        e.g. [('==', '1.2.3'), ] or [('>=', '1.2'), ('<', '1.3')]
    :param specs_b: same structure as specs_a
    :return: True if the two sections intersect
    """
    # section definition:
    class Section(object):
        def __init__(self, left="-999999999", left_eq=False, right="999999999", right_eq=False):
            self.left, self.left_eq, self.right, self.right_eq = left, left_eq, right, right_eq

    # first create a list of in/out sections for each spec
    # >, >= are left rule
    # <, <= are right rule
    # ~= x.y.z is converted to: >= x.y and < x.y+1
    # ==/=== are converted to: >= and <=
    # NOTE(review): '!=' is not handled below (it would have to split a section into
    # left < x.y.z and right > x.y.z); such specs are ignored
    def create_section(specs):
        section = Section()
        for op, v in specs:
            a = section
            if op == '>':
                a.left = v
                a.left_eq = False
            elif op == '>=':
                a.left = v
                a.left_eq = True
            elif op == '<':
                a.right = v
                a.right_eq = False
            elif op == '<=':
                a.right = v
                a.right_eq = True
            elif op in ('==', '==='):
                a.left = v
                a.left_eq = True
                a.right = v
                a.right_eq = True
            elif op == '~=':
                new_v = v.split('.')
                a_left = '.'.join(new_v[:-1])
                a.left = a_left if not a.left else SimpleVersion.max_version(a_left, a.left)
                a.left_eq = True
                a_right = '.'.join(new_v[:-2] + [str(int(new_v[-2])+1)])
                a.right = a_right if not a.right else SimpleVersion.min_version(a_right, a.right)
                a.right_eq = False if a.right == a_right else a.right_eq

        return section

    section_a = create_section(specs_a)
    section_b = create_section(specs_b)
    i = Section()

    # intersect: the larger left bound and the smaller right bound win; on equal
    # bounds the edge is inclusive only if it is inclusive on both sides
    if section_a.left == section_b.left:
        i.left = section_a.left
        i.left_eq = section_a.left_eq and section_b.left_eq
    else:
        i.left = SimpleVersion.max_version(section_a.left, section_b.left)
        i.left_eq = section_a.left_eq if i.left == section_a.left else section_b.left_eq
    if section_a.right == section_b.right:
        i.right = section_a.right
        i.right_eq = section_a.right_eq and section_b.right_eq
    else:
        i.right = SimpleVersion.min_version(section_a.right, section_b.right)
        i.right_eq = section_a.right_eq if i.right == section_a.right else section_b.right_eq

    # the intersection is non-empty if left < right, or if left == right and BOTH
    # edges are inclusive (e.g. ">=1.2" and "<1.2" must not count as overlapping)
    inclusive = i.left_eq and i.right_eq
    valid = True
    valid &= SimpleVersion.compare_versions(
        version_a=i.left, op='<=' if inclusive else '<', version_b=i.right, num_parts=None)
    valid &= SimpleVersion.compare_versions(
        version_a=i.right, op='>=' if inclusive else '>', version_b=i.left, num_parts=None)
    return valid
# Base class for requirement substitutions: a handler that decides whether a parsed
# requirement should be rewritten (match) and produces the replacement line (replace).
# NOTE(review): this listing lost its indentation and, apparently, some lines; the notes
# below flag the spots that look incomplete - restore from upstream.
class RequirementSubstitution(object):
# extra pip index URLs shared at class level, seeded from PIP_EXTRA_INDICES
_pip_extra_index_url = PIP_EXTRA_INDICES
# takes `cls`, so presumably declared as a classmethod - the decorator is not visible here
def set_add_install_extra_index(cls, extra_index_url):
if extra_index_url not in cls._pip_extra_index_url:
# NOTE(review): the body of the `if` above is missing from this listing (presumably it
# adds extra_index_url to cls._pip_extra_index_url) - confirm against upstream
def __init__(self, session):
# type: (Session) -> ()
self._session = session
self.config = session.config # type: ConfigTree
# version suffix built from the configured cuda / cudnn versions
self.suffix = '.post{config[agent.cuda_version]}.dev{config[agent.cudnn_version]}'.format(config=self.config)
self.package_manager = self.config['agent.package_manager.type']
# optional callable consulted by is_already_installed(); set via set_is_already_installed_cb()
self._is_already_installed_cb = None
def match(self, req):  # type: (MarkerRequirement) -> bool
    """
    Returns whether a requirement needs to be modified by this substitution.

    The base implementation has no logic of its own (it returns None);
    concrete substitutions provide the actual check.
    """
def replace(self, req):  # type: (MarkerRequirement) -> Text
    """
    Replace a requirement

    The base implementation has no logic of its own (it returns None);
    concrete substitutions return the replacement requirement line.
    """
def set_is_already_installed_cb(self, cb):
    """
    Register the callable that is_already_installed() consults.

    :param cb: callable taking a requirement; None disables the check
    """
    self._is_already_installed_cb = cb
def is_already_installed(self, req):
    """
    Ask the registered callback whether ``req`` is already installed.

    :param req: the requirement to check
    :return: the callback's answer; False when no callback is registered or
        when the callback raises (a warning is printed in that case)
    """
    if not self._is_already_installed_cb:
        return False
    # noinspection PyBroadException
    try:
        return self._is_already_installed_cb(req)
    except Exception as ex:
        # Exception rather than BaseException: a failing callback must not abort the
        # flow, but KeyboardInterrupt / SystemExit must still get through
        # debug could not resolve something
        print("Warning: Requirements post install callback exception (check if package installed): {}".format(ex))
    return False
def post_scan_add_req(self):  # type: () -> Optional[MarkerRequirement]
    """
    Allows the RequirementSubstitution to add an extra line/requirements after
    the initial requirements scan is completed.
    Called only once per requirements.txt object

    :return: the requirement(s) to add; the base implementation adds nothing (None)
    """
    return None
def post_install(self, session):
    """
    Post-install hook of the substitution.

    The base implementation does nothing.

    :param session: the session object handed over by the caller
    """
# Return the version text found for `package` in the output of an external command.
# Takes `cls`, so presumably declared as a classmethod - the decorator is not visible here.
# NOTE(review): raises AttributeError if the pattern below does not match the output.
def get_pip_version(cls, package):
# NOTE(review): the leading arguments of this Argv(...) call, its closing bracket and
# the call that produces the output are missing from this listing
output = Argv(
# one "-i <url>" pair per configured extra index url
*(chain.from_iterable(('-i', x) for x in cls._pip_extra_index_url))
# ad-hoc pattern to duplicate the behavior of the old code
return re.search(r'{} \((\d+\.\d+\.[^.]+)'.format(package), output).group(1)
@property
def cuda_version(self):
    """
    The configured ``agent.cuda_version``, passed through
    convert_cuda_version_to_int_10_base_str().

    Declared as a property: the file reads it as ``self.cuda_version``.
    """
    return convert_cuda_version_to_int_10_base_str(self.config['agent.cuda_version'])
@property
def cudnn_version(self):
    """
    The configured ``agent.cudnn_version`` value, as stored in the config.

    Declared as a property: the file reads it as ``self.cudnn_version``.
    """
    return self.config['agent.cudnn_version']
# Substitution that is keyed on a single package name (see `name`).
class SimpleSubstitution(RequirementSubstitution):
# The package name this substitution handles. It is read as an attribute (`self.name`)
# by the methods of this class, so it is presumably a property - the decorator is not
# visible here.
# NOTE(review): the body of this definition is missing from this listing (presumably
# left for subclasses to provide) - confirm against upstream
def name(self):
def match(self, req):  # type: (MarkerRequirement) -> bool
    """
    A requirement matches when its name equals ``self.name``, or when it is an
    http(s) link whose uri contains ``self.name``.

    :param req: the parsed requirement to check
    :return: True if this substitution should handle the requirement
    """
    if self.name == req.name:
        return True
    # bool(): return the documented type instead of a leaked None / Match object
    return bool(
        req.uri and
        re.match(r'https?://', req.uri) and
        self.name in req.uri
    )
# Build the replacement line for a matched requirement: the requirement is pinned with
# '==' to its version plus self.suffix, and rendered back to text.
# NOTE(review): this listing lost its indentation, the docstring quotes and, apparently,
# some lines of this method; the notes below flag the spots that look incomplete.
def replace(self, req): # type: (MarkerRequirement) -> Text
# (the next two lines are the docstring text; its quotes are missing from this listing)
Replace a requirement
:raises: ValueError if version is pre-release
if req.uri:
# NOTE(review): the arguments and closing bracket of this re.sub( call are missing
return re.sub(
if req.specs:
# only the first spec is used
_, version_number = req.specs[0]
# assert packaging_version.parse(version_number)
# NOTE(review): presumably the `else:` branch (no specs given) - keyword not visible here
version_number = self.get_pip_version(self.name)
req.specs = [('==', version_number + self.suffix)]
return Text(req)
def replace_back(self, list_of_requirements):  # type: (Dict) -> Dict
    """
    Hook for mapping substituted requirements back; this implementation
    returns its input unchanged.

    :param list_of_requirements: {'pip': ['a==1.0', ]}
    :return: {'pip': ['a==1.0', ]}
    """
    return list_of_requirements
class CudaSensitiveSubstitution(SimpleSubstitution):
    """Substitution that is only considered when a cuda and a cudnn version are available."""

    def match(self, req):  # type: (MarkerRequirement) -> bool
        """
        Falsy unless both ``self.cuda_version`` and ``self.cudnn_version`` are
        truthy; otherwise the result of the parent class check.
        """
        cuda = self.cuda_version
        if not cuda:
            return cuda
        cudnn = self.cudnn_version
        if not cudnn:
            return cudnn
        return super(CudaSensitiveSubstitution, self).match(req)
class CudaNotFound(Exception):
    """Raised by the cuda / cudnn detection code when a required tool or file is unavailable."""
# Holds the registered requirement substitutions (handlers) and applies them to
# requirements text.
# NOTE(review): this listing lost its indentation and, apparently, some lines; the notes
# below flag the spots that look incomplete - restore from upstream.
class RequirementsManager(object):
def __init__(self, session, base_interpreter=None):
# type: (Session, PathLike) -> ()
self._session = session
# deep copy: the 'agent' section of this config is modified below
self.config = deepcopy(session.config) # type: ConfigTree
self.handlers = [] # type: List[RequirementSubstitution]
agent = self.config['agent']
# substitutions are applied unless agent.cpu_only is set
self.active = not agent.get('cpu_only', False)
self.found_cuda = False
if self.active:
# NOTE(review): the `try:` that pairs with the `except Exception:` below is missing
# from this listing
agent['cuda_version'], agent['cudnn_version'] = self.get_cuda_version(self.config)
self.found_cuda = True
except Exception:
# if we have a cuda version, it is good enough (we dont have to have cudnn version)
if agent.get('cuda_version'):
self.found_cuda = True
# download cache folder: 'cu<cuda_version>' when cuda was found, 'cpu' otherwise
pip_cache_dir = Path(self.config["agent.pip_download_cache.path"]).expanduser() / (
'cu'+agent['cuda_version'] if self.found_cuda else 'cpu')
# NOTE(review): the remaining arguments and the closing bracket of this call are missing
# from this listing (pip_cache_dir is presumably passed here)
self.translator = RequirementsTranslator(session, interpreter=base_interpreter,
self._base_interpreter = base_interpreter
# working directory used when parsing requirements; see set_cwd()
self._cwd = None
# parsed installed packages; replaced by update_installed_packages_state()
self._installed_parsed_packages = set()
# Register a RequirementSubstitution class with this manager.
# NOTE(review): the body of this method is missing from this listing (presumably it
# instantiates `cls` and adds the instance to self.handlers) - confirm against upstream
def register(self, cls): # type: (Type[RequirementSubstitution]) -> None
def set_cwd(self, cwd):
    """
    Set the working directory used when parsing requirements.

    :param cwd: a path (stored as text); any falsy value clears it
    """
    if cwd:
        self._cwd = str(cwd)
    else:
        self._cwd = None
def _replace_one(self, req):  # type: (MarkerRequirement) -> Optional[Text]
    """
    Run a single requirement through the registered handlers.

    The text after the first ';' of the rendered requirement is split on ','
    and stored on ``req.markers``.

    :return: the replacement line of the first handler that matches, or None
        when the manager is inactive or no handler matches
    """
    marker_match = re.search(r';\s*(.*)', Text(req))
    if marker_match:
        req.markers = marker_match.group(1).split(',')

    if self.active:
        for handler in self.handlers:
            if handler.match(req):
                return handler.replace(req)
    return None
# Parse the requirements text, let the registered handlers rewrite individual
# requirements, add the handlers' post-scan requirements and return the joined lines.
# If nothing could be parsed the input text is returned unchanged.
# NOTE(review): this listing lost its indentation and, apparently, several lines of this
# method; the notes below flag the spots that look incomplete - restore from upstream.
def replace(self, requirements): # type: (Text) -> Text
parsed_requirements = self.parse_requirements_section_to_marker_requirements(
requirements=requirements, cwd=self._cwd)
if not parsed_requirements:
# return the original requirements just in case
return requirements
# per-requirement wrapper: a failing handler must not fail the whole list
def replace_one(i, req):
# type: (int, MarkerRequirement) -> Optional[Text]
# NOTE(review): the `try:` that pairs with the `except` clauses below is missing from
# this listing
return self._replace_one(req)
except FatalSpecsResolutionError:
warning('could not resolve python wheel replacement for {}'.format(req))
# NOTE(review): nothing else is visible in this handler (upstream may re-raise here)
except Exception:
warning('could not resolve python wheel replacement for \"{}\", '
'using original requirements line: {}'.format(req, i))
return None
# None entries mean "keep the original requirement line"
new_requirements = tuple(replace_one(i, req) for i, req in enumerate(parsed_requirements))
conda = is_conda(self.config)
# NOTE(review): the iterables passed to this map( call (presumably new_requirements and
# parsed_requirements) and its closing bracket are missing from this listing
result = map(
lambda x, y: (x if x is not None else y.tostr(markers=not conda)),
if not conda:
result = map(self.translator.translate, result)
result = list(result)
# add post scan add requirements call back
# names already present in the result; built on first use
double_req_set = None
for h in self.handlers:
reqs = h.post_scan_add_req()
if reqs:
if double_req_set is None:
def safe_parse_name(line):
# NOTE(review): the `try:` that pairs with the `except` below is missing from this listing
return Requirement.parse(line).name
except: # noqa
return None
double_req_set = set([safe_parse_name(r) for r in result if r])
# a handler may return a single requirement or a tuple/list of them
for r in (reqs if isinstance(reqs, (tuple, list)) else [reqs]):
if r and (not r.name or r.name not in double_req_set):
# NOTE(review): the body of the `if` above is missing from this listing (presumably it
# appends the new requirement to result)
elif r:
print("SKIPPING additional auto installed package: \"{}\"".format(r))
return join_lines(result)
# Post-install step over all registered handlers; an exception raised for one handler
# is printed and does not stop the loop.
# NOTE(review): this listing lost its indentation and, apparently, some lines of this
# method; the notes below flag the spots that look incomplete - restore from upstream.
def post_install(self, session, package_manager=None):
# NOTE(review): the body of this `if` is missing from this listing
if package_manager:
for h in self.handlers:
# NOTE(review): the `try:` and the per-handler call (presumably h.post_install(session))
# are missing from this listing
except Exception as ex:
print('RequirementsManager handler {} raised exception: {}'.format(h, ex))
def replace_back(self, requirements):
    """
    Pass ``requirements`` through the translator's and every handler's
    replace_back(), in that order, and return the result.

    A handler that raises is skipped (best effort): its failure is printed and
    the requirements continue unchanged to the next handler.

    :param requirements: the requirements structure to map back
    :return: the requirements after all replace_back() steps
    """
    if self.translator:
        requirements = self.translator.replace_back(requirements)

    for h in self.handlers:
        try:
            requirements = h.replace_back(requirements)
        except Exception as ex:
            # keep going, but do not hide the failure
            print('RequirementsManager handler {} raised exception: {}'.format(h, ex))

    return requirements
def get_interpreter(self):
    """Return the base interpreter this manager was created with."""
    interpreter = self._base_interpreter
    return interpreter
# NOTE(review): this listing lost its indentation, the docstring quotes and, apparently,
# the last line(s) of this method - restore from upstream.
def update_installed_packages_state(self, requirements):
# (the next three lines are the docstring text; its quotes are missing from this listing)
Updates internal Installed Packages objects, so that later we can detect
if we already have a pre-installed package
:param requirements: is the output of a freeze() call, i.e. dict {'pip': "package==version"}
# a dict input is reduced to its "pip" entry; anything else is used as is
requirements = requirements if not isinstance(requirements, dict) else requirements.get("pip")
self._installed_parsed_packages = self.parse_requirements_section_to_marker_requirements(
requirements=requirements, cwd=self._cwd)
# NOTE(review): the body of this loop is missing from this listing (presumably it hands
# self._callback_is_already_installed to every handler) - confirm against upstream
for h in self.handlers:
def _callback_is_already_installed(self, req):
    """
    Check ``req`` against the parsed installed packages.

    Only the first installed package with the same name is considered.

    :param req: the requested requirement
    :return: True if the installed package answers the request, False otherwise
        (also when no package with that name is installed)
    """
    for p in (self._installed_parsed_packages or []):
        if p.name != req.name:
            continue

        # if this is version control package, only return true of both installed and requests specify commit ID
        if req.vcs:
            return bool(p.vcs and req.revision and req.revision == p.revision)

        if not req.specs and not p.specs:
            return True

        # only one side is versioned: cannot be called "the same version"
        if not (req.specs and p.specs):
            return False

        # return if this is the same version
        # compare_version() takes a version string, so pass the installed version
        # itself (first spec of the installed package), not the requirement object
        return bool(req.compare_version(p.specs[0][1], op="=="))

    return False
@staticmethod
def get_cuda_version(config):  # type: (ConfigTree) -> (Text, Text)
    """
    Resolve the cuda and cudnn versions.

    Configured values win. A missing cuda version is looked up, in order, from
    the driver, the CUDA_PATH_V* environment variables (Windows only),
    ``nvcc --version`` and ``nvidia-smi``. A missing cudnn version is read from
    the cudnn.h that sits next to nvcc. Every probe is best effort: a probe that
    fails is skipped and the next one is tried.

    NOTE(review): the try/except scaffolding was rebuilt from the visible
    ``except`` clauses and the fallback order of the probes - confirm against
    the upstream file.

    :param config: configuration holding 'agent.cuda_version' / 'agent.cudnn_version'
    :return: (cuda_version, cudnn_version), both passed through
        normalize_cuda_version(); 0 is used for a version that was not found
    """
    # we assume os.environ already updated the config['agent.cuda_version'] & config['agent.cudnn_version']
    cuda_version = config['agent.cuda_version']
    cudnn_version = config['agent.cudnn_version']
    if cuda_version and cudnn_version:
        return normalize_cuda_version(cuda_version), normalize_cuda_version(cudnn_version)

    if not cuda_version:
        cuda_version = get_driver_cuda_version()

    if not cuda_version and is_windows_platform():
        try:
            cuda_vers = [int(k.replace('CUDA_PATH_V', '').replace('_', '')) for k in os.environ.keys()
                         if k.startswith('CUDA_PATH_V')]
            # no CUDA_PATH_V* variable at all: nothing to pick from
            if cuda_vers:
                cuda_vers = max(cuda_vers)
                if cuda_vers > 40:
                    cuda_version = cuda_vers
        except ValueError:
            # a CUDA_PATH_V* variable whose suffix is not a number
            pass

    if not cuda_version:
        # noinspection PyBroadException
        try:
            try:
                nvcc = 'nvcc.exe' if is_windows_platform() else 'nvcc'
                if is_windows_platform() and 'CUDA_PATH' in os.environ:
                    nvcc = os.path.join(os.environ['CUDA_PATH'], nvcc)

                output = Argv(nvcc, '--version').get_output()
            except OSError:
                raise CudaNotFound('nvcc not found')
            match = re.search(r'release (.{3})', output)
            if match:
                cuda_version = Text(int(float(match.group(1)) * 10))
        except Exception:
            # best effort: fall through to the nvidia-smi probe
            pass

    if not cuda_version:
        # noinspection PyBroadException
        try:
            try:
                output = Argv('nvidia-smi',).get_output()
            except OSError:
                raise CudaNotFound('nvidia-smi not found')
            match = re.search(r'CUDA Version: ([0-9]+).([0-9]+)', output)
            if match:
                match = match.group(1)+'.'+match.group(2)
                cuda_version = Text(int(float(match) * 10))
        except Exception:
            # best effort: cuda version stays unknown
            pass

    if not cudnn_version:
        # noinspection PyBroadException
        try:
            cuda_lib = which('nvcc')
            # is_windows_platform must be CALLED here; the bare function object is always truthy
            if is_windows_platform():
                cudnn_h = path.sep.join(cuda_lib.split(path.sep)[:-2] + ['include', 'cudnn.h'])
            else:
                cudnn_h = path.join(path.sep, *(cuda_lib.split(path.sep)[:-2] + ['include', 'cudnn.h']))

            cudnn_major, cudnn_minor = None, None
            try:
                include_file = open(cudnn_h)
            except (IOError, OSError):
                raise CudaNotFound('Could not read cudnn.h')
            with include_file:
                for line in include_file:
                    if 'CUDNN_MAJOR' in line:
                        cudnn_major = line.split()[-1]
                    if 'CUDNN_MINOR' in line:
                        cudnn_minor = line.split()[-1]
                    if cudnn_major and cudnn_minor:
                        break
            # without a major version there is nothing to report
            if cudnn_major:
                cudnn_version = cudnn_major + (cudnn_minor or '0')
        except Exception:
            # best effort: nvcc not on the path, header missing or unreadable
            pass

    return (normalize_cuda_version(cuda_version or 0),
            normalize_cuda_version(cudnn_version or 0))
# Parse requirements (one text blob, or an iterable of lines) into a tuple of parsed
# requirement objects; an empty input yields an empty tuple.
# Takes no `self` and is called elsewhere in the file as
# self.parse_requirements_section_to_marker_requirements(requirements=..., cwd=...), so it
# is presumably a staticmethod - the decorator is not visible here.
# NOTE(review): this listing lost its indentation and, apparently, some lines; the notes
# below flag the spots that look incomplete - restore from upstream.
def parse_requirements_section_to_marker_requirements(requirements, cwd=None):
def safe_parse(req_str):
# noinspection PyBroadException
# NOTE(review): the `try:` that pairs with the `except` below is missing from this listing
return list(parse(req_str, cwd=cwd))
except Exception as ex:
# the parser rejected the line: wrap the raw text instead
return [Requirement(req_str)]
def create_req(x):
r = MarkerRequirement(x)
return r
if not requirements:
return tuple()
# NOTE(review): create_req is defined above but never applied in the visible lines, and
# the tuple( call below is never closed - lines are apparently missing here (presumably
# a map(create_req, ...) wrapper around the list)
parsed_requirements = tuple(
[r for line in (requirements.splitlines() if isinstance(requirements, str) else requirements)
for r in safe_parse(line)]
return parsed_requirements