mirror of
https://github.com/clearml/clearml-agent
synced 2025-01-31 17:16:51 +00:00
791 lines
29 KiB
Python
791 lines
29 KiB
Python
from __future__ import absolute_import, unicode_literals
|
|
|
|
import operator
|
|
import os
|
|
import re
|
|
from abc import ABCMeta, abstractmethod
|
|
from copy import deepcopy, copy
|
|
from itertools import chain, starmap
|
|
from operator import itemgetter
|
|
from os import path
|
|
from typing import Text, List, Type, Optional, Tuple, Dict
|
|
|
|
from pathlib2 import Path
|
|
from pyhocon import ConfigTree
|
|
|
|
import six
|
|
import logging
|
|
from clearml_agent.definitions import PIP_EXTRA_INDICES
|
|
from clearml_agent.helper.base import (
|
|
warning, is_conda, which, join_lines, is_windows_platform,
|
|
convert_cuda_version_to_int_10_base_str, )
|
|
from clearml_agent.helper.process import Argv, PathLike
|
|
from clearml_agent.helper.gpu.gpustat import get_driver_cuda_version
|
|
from clearml_agent.session import Session, normalize_cuda_version
|
|
from clearml_agent.external.requirements_parser import parse
|
|
from clearml_agent.external.requirements_parser.requirement import Requirement
|
|
|
|
from .translator import RequirementsTranslator
|
|
|
|
|
|
class SpecsResolutionError(Exception):
|
|
pass
|
|
|
|
|
|
class FatalSpecsResolutionError(Exception):
|
|
pass
|
|
|
|
|
|
@six.python_2_unicode_compatible
|
|
class MarkerRequirement(object):
|
|
|
|
# if True pip version above 20.x and with support for "package @ scheme://link"
|
|
# default is True
|
|
pip_new_version = True
|
|
|
|
def __init__(self, req): # type: (Requirement) -> None
|
|
self.req = req
|
|
|
|
@property
|
|
def marker(self):
|
|
match = re.search(r';\s*(.*)', self.req.line)
|
|
if match:
|
|
return match.group(1)
|
|
return None
|
|
|
|
def tostr(self, markers=True):
|
|
if not self.uri:
|
|
parts = [self.name or self.line]
|
|
|
|
if self.extras:
|
|
parts.append('[{0}]'.format(','.join(sorted(self.extras))))
|
|
|
|
if self.specifier:
|
|
parts.append(self.format_specs())
|
|
elif self.vcs:
|
|
# leave the line as is, let pip handle it
|
|
if self.line:
|
|
return self.line
|
|
else:
|
|
# let's build the line manually
|
|
parts = [
|
|
self.uri,
|
|
'@{}'.format(self.revision) if self.revision else '',
|
|
'#subdirectory={}'.format(self.subdirectory) if self.subdirectory else ''
|
|
]
|
|
elif self.pip_new_version and self.uri and self.name and self.line and self.local_file:
|
|
# package @ file:///example.com/somewheel.whl
|
|
# leave the line as is, let pip handle it
|
|
return self.line
|
|
else:
|
|
parts = [self.uri]
|
|
|
|
if markers and self.marker:
|
|
parts.append('; {0}'.format(self.marker))
|
|
|
|
return ''.join(parts)
|
|
|
|
def clone(self):
|
|
return MarkerRequirement(copy(self.req))
|
|
|
|
__str__ = tostr
|
|
|
|
def __repr__(self):
|
|
return '{self.__class__.__name__}[{self}]'.format(self=self)
|
|
|
|
def format_specs(self, num_parts=None, max_num_parts=None):
|
|
max_num_parts = max_num_parts or num_parts
|
|
if max_num_parts is None or not self.specs:
|
|
return ','.join(starmap(operator.add, self.specs))
|
|
|
|
op, version = self.specs[0]
|
|
for v in self._sub_versions_pep440:
|
|
version = version.replace(v, '.')
|
|
if num_parts:
|
|
version = (version.strip('.').split('.') + ['0'] * num_parts)[:max_num_parts]
|
|
else:
|
|
version = version.strip('.').split('.')[:max_num_parts]
|
|
return op+'.'.join(version)
|
|
|
|
def __getattr__(self, item):
|
|
return getattr(self.req, item)
|
|
|
|
@property
|
|
def specs(self): # type: () -> List[Tuple[Text, Text]]
|
|
return self.req.specs
|
|
|
|
@specs.setter
|
|
def specs(self, value): # type: (List[Tuple[Text, Text]]) -> None
|
|
self.req.specs = value
|
|
|
|
def fix_specs(self):
|
|
def solve_by(func, op_is, specs):
|
|
return func([(op, version) for op, version in specs if op == op_is])
|
|
|
|
def solve_equal(specs):
|
|
if len(set(version for _, version in self.specs)) > 1:
|
|
raise SpecsResolutionError('more than one "==" spec: {}'.format(specs))
|
|
return specs
|
|
greater = solve_by(lambda specs: [max(specs, key=itemgetter(1))], '<=', self.specs)
|
|
smaller = solve_by(lambda specs: [min(specs, key=itemgetter(1))], '>=', self.specs)
|
|
equal = solve_by(solve_equal, '==', self.specs)
|
|
if equal:
|
|
self.specs = equal
|
|
else:
|
|
self.specs = greater + smaller
|
|
|
|
def compare_version(self, requested_version, op=None, num_parts=3):
|
|
"""
|
|
compare the requested version with the one we have in the spec,
|
|
If the requested version is 1.2.3 the self.spec should be 1.2.3*
|
|
If the requested version is 1.2 the self.spec should be 1.2*
|
|
etc.
|
|
|
|
:param str requested_version:
|
|
:param str op: '==', '>', '>=', '<=', '<', '~='
|
|
:param int num_parts: number of parts to compare
|
|
:return: True if we answer the requested version
|
|
"""
|
|
# if we have no specific version, we cannot compare, so assume it's okay
|
|
if not self.specs:
|
|
return True
|
|
|
|
version = self.specs[0][1]
|
|
op = (op or self.specs[0][0]).strip()
|
|
|
|
return SimpleVersion.compare_versions(
|
|
version_a=requested_version, op=op, version_b=version, num_parts=num_parts)
|
|
|
|
def remove_local_file_ref(self):
|
|
if not self.local_file or self.vcs or self.editable or self.path:
|
|
return False
|
|
parts = re.split(r"@\s*{}".format(self.req.uri), self.req.line)
|
|
# if we did not find anything do nothing
|
|
if len(parts) < 2:
|
|
return False
|
|
self.req.line = ''.join(parts).strip()
|
|
self.req.uri = None
|
|
self.req.local_file = False
|
|
return True
|
|
|
|
def validate_local_file_ref(self):
|
|
# if local file does not exist, remove the reference to it
|
|
if self.vcs or self.editable or self.path or not self.local_file or not self.name or \
|
|
not self.uri or not self.uri.startswith("file://"):
|
|
return
|
|
local_path = Path(self.uri[len("file://"):])
|
|
if not local_path.exists():
|
|
line = self.line
|
|
if self.remove_local_file_ref():
|
|
# print warning
|
|
logging.getLogger(__name__).warning(
|
|
'Local file not found [{}], references removed'.format(line))
|
|
|
|
|
|
class SimpleVersion:
|
|
_sub_versions_pep440 = ['a', 'b', 'rc', '.post', '.dev', '+', ]
|
|
VERSION_PATTERN = r"""
|
|
v?
|
|
(?:
|
|
(?:(?P<epoch>[0-9]+)!)? # epoch
|
|
(?P<release>[0-9]+(?:\.[0-9]+)*) # release segment
|
|
(?P<pre> # pre-release
|
|
[-_\.]?
|
|
(?P<pre_l>(a|b|c|rc|alpha|beta|pre|preview))
|
|
[-_\.]?
|
|
(?P<pre_n>[0-9]+)?
|
|
)?
|
|
(?P<post> # post release
|
|
(?:-(?P<post_n1>[0-9]+))
|
|
|
|
|
(?:
|
|
[-_\.]?
|
|
(?P<post_l>post|rev|r)
|
|
[-_\.]?
|
|
(?P<post_n2>[0-9]+)?
|
|
)
|
|
)?
|
|
(?P<dev> # dev release
|
|
[-_\.]?
|
|
(?P<dev_l>dev)
|
|
[-_\.]?
|
|
(?P<dev_n>[0-9]+)?
|
|
)?
|
|
)
|
|
(?:\+(?P<local>[a-z0-9]+(?:[-_\.][a-z0-9]+)*))? # local version
|
|
"""
|
|
_local_version_separators = re.compile(r"[\._-]")
|
|
_regex = re.compile(r"^\s*" + VERSION_PATTERN + r"\s*$", re.VERBOSE | re.IGNORECASE)
|
|
|
|
@classmethod
|
|
def compare_versions(cls, version_a, op, version_b, ignore_sub_versions=True, num_parts=3):
|
|
"""
|
|
Compare two versions based on the op operator
|
|
returns bool(version_a op version_b)
|
|
Notice: Ignores a/b/rc/post/dev markers on the version
|
|
|
|
:param str version_a:
|
|
:param str op: '==', '===', '>', '>=', '<=', '<', '~='
|
|
:param str version_b:
|
|
:param bool ignore_sub_versions: if true compare only major.minor.patch
|
|
(ignore a/b/rc/post/dev in the comparison)
|
|
:param int num_parts: number of parts to compare, split by . (dot)
|
|
:return bool: version_a op version_b
|
|
"""
|
|
|
|
if not version_b:
|
|
return True
|
|
|
|
if not num_parts:
|
|
num_parts = max(len(version_a.split('.')), len(version_b.split('.')), )
|
|
|
|
if op == '~=':
|
|
num_parts = len(version_b.split('.')) - 1
|
|
num_parts = max(num_parts, 2)
|
|
op = '=='
|
|
ignore_sub_versions = True
|
|
elif op == '===':
|
|
op = '=='
|
|
|
|
try:
|
|
version_a_key = cls._get_match_key(cls._regex.search(version_a), num_parts, ignore_sub_versions)
|
|
version_b_key = cls._get_match_key(cls._regex.search(version_b), num_parts, ignore_sub_versions)
|
|
except:
|
|
# revert to string based
|
|
for v in cls._sub_versions_pep440:
|
|
version_a = version_a.replace(v, '.')
|
|
version_b = version_b.replace(v, '.')
|
|
|
|
version_a = (version_a.strip('.').split('.') + ['0'] * num_parts)[:num_parts]
|
|
version_b = (version_b.strip('.').split('.') + ['0'] * num_parts)[:num_parts]
|
|
version_a_key = ''
|
|
version_b_key = ''
|
|
for i in range(num_parts):
|
|
pad = '{:0>%d}.' % max([9, 1 + len(version_a[i]), 1 + len(version_b[i])])
|
|
version_a_key += pad.format(version_a[i])
|
|
version_b_key += pad.format(version_b[i])
|
|
|
|
if op == '==':
|
|
return version_a_key == version_b_key
|
|
if op == '<=':
|
|
return version_a_key <= version_b_key
|
|
if op == '>=':
|
|
return version_a_key >= version_b_key
|
|
if op == '>':
|
|
return version_a_key > version_b_key
|
|
if op == '<':
|
|
return version_a_key < version_b_key
|
|
raise ValueError('Unrecognized comparison operator [{}]'.format(op))
|
|
|
|
@classmethod
|
|
def max_version(cls, version_a, version_b):
|
|
return version_a if cls.compare_versions(
|
|
version_a=version_a, op='>=', version_b=version_b, num_parts=None) else version_b
|
|
|
|
@classmethod
|
|
def min_version(cls, version_a, version_b):
|
|
return version_a if cls.compare_versions(
|
|
version_a=version_a, op='<=', version_b=version_b, num_parts=None) else version_b
|
|
|
|
@staticmethod
|
|
def _parse_letter_version(
|
|
letter, # type: str
|
|
number, # type: Union[str, bytes, SupportsInt]
|
|
):
|
|
# type: (...) -> Optional[Tuple[str, int]]
|
|
|
|
if letter:
|
|
# We consider there to be an implicit 0 in a pre-release if there is
|
|
# not a numeral associated with it.
|
|
if number is None:
|
|
number = 0
|
|
|
|
# We normalize any letters to their lower case form
|
|
letter = letter.lower()
|
|
|
|
# We consider some words to be alternate spellings of other words and
|
|
# in those cases we want to normalize the spellings to our preferred
|
|
# spelling.
|
|
if letter == "alpha":
|
|
letter = "a"
|
|
elif letter == "beta":
|
|
letter = "b"
|
|
elif letter in ["c", "pre", "preview"]:
|
|
letter = "rc"
|
|
elif letter in ["rev", "r"]:
|
|
letter = "post"
|
|
|
|
return letter, int(number)
|
|
if not letter and number:
|
|
# We assume if we are given a number, but we are not given a letter
|
|
# then this is using the implicit post release syntax (e.g. 1.0-1)
|
|
letter = "post"
|
|
|
|
return letter, int(number)
|
|
|
|
return ()
|
|
|
|
@staticmethod
|
|
def _get_match_key(match, num_parts, ignore_sub_versions):
|
|
if ignore_sub_versions:
|
|
return (0, tuple(int(i) for i in match.group("release").split(".")[:num_parts]),
|
|
(), (), (), (),)
|
|
return (
|
|
int(match.group("epoch")) if match.group("epoch") else 0,
|
|
tuple(int(i) for i in match.group("release").split(".")[:num_parts]),
|
|
SimpleVersion._parse_letter_version(match.group("pre_l"), match.group("pre_n")),
|
|
SimpleVersion._parse_letter_version(
|
|
match.group("post_l"), match.group("post_n1") or match.group("post_n2")
|
|
),
|
|
SimpleVersion._parse_letter_version(match.group("dev_l"), match.group("dev_n")),
|
|
SimpleVersion._parse_local_version(match.group("local")),
|
|
)
|
|
|
|
@staticmethod
|
|
def _parse_local_version(local):
|
|
# type: (str) -> Optional[LocalType]
|
|
"""
|
|
Takes a string like abc.1.twelve and turns it into ("abc", 1, "twelve").
|
|
"""
|
|
if local is not None:
|
|
return tuple(
|
|
part.lower() if not part.isdigit() else int(part)
|
|
for part in SimpleVersion._local_version_separators.split(local)
|
|
)
|
|
return ()
|
|
|
|
|
|
def compare_version_rules(specs_a, specs_b):
|
|
# specs_a/b are a list of tuples: [('==', '1.2.3'), ] or [('>=', '1.2'), ('<', '1.3')]
|
|
# section definition:
|
|
class Section(object):
|
|
def __init__(self, left=None, left_eq=False, right=None, right_eq=False):
|
|
self.left, self.left_eq, self.right, self.right_eq = left, left_eq, right, right_eq
|
|
# first create a list of in/out sections for each spec
|
|
# >, >= are left rule
|
|
# <, <= are right rule
|
|
# ~= x.y.z is converted to: >= x.y and < x.y+1
|
|
# ==/=== are converted to: >= and <=
|
|
# != x.y.z will split a section into: left < x.y.z and right > x.y.z
|
|
def create_section(specs):
|
|
section = Section()
|
|
for op, v in specs:
|
|
a = section
|
|
if op == '>':
|
|
a.left = v
|
|
a.left_eq = False
|
|
elif op == '>=':
|
|
a.left = v
|
|
a.left_eq = True
|
|
elif op == '<':
|
|
a.right = v
|
|
a.right_eq = False
|
|
elif op == '<=':
|
|
a.right = v
|
|
a.right_eq = True
|
|
elif op == '==':
|
|
a.left = v
|
|
a.left_eq = True
|
|
a.right = v
|
|
a.right_eq = True
|
|
elif op == '~=':
|
|
new_v = v.split('.')
|
|
a_left = '.'.join(new_v[:-1])
|
|
a.left = a_left if not a.left else SimpleVersion.max_version(a_left, a.left)
|
|
a.left_eq = True
|
|
a_right = '.'.join(new_v[:-2] + [str(int(new_v[-2])+1)])
|
|
a.right = a_right if not a.right else SimpleVersion.min_version(a_right, a.right)
|
|
a.right_eq = False if a.right == a_right else a.right_eq
|
|
|
|
return section
|
|
|
|
section_a = create_section(specs_a)
|
|
section_b = create_section(specs_b)
|
|
i = Section()
|
|
# then we have a list of sections for spec A/B
|
|
if section_a.left == section_b.left:
|
|
i.left = section_a.left
|
|
i.left_eq = section_a.left_eq and section_b.left_eq
|
|
else:
|
|
i.left = SimpleVersion.max_version(section_a.left, section_b.left)
|
|
i.left_eq = section_a.left_eq if i.left == section_a.left else section_b.left_eq
|
|
if section_a.right == section_b.right:
|
|
i.right = section_a.right
|
|
i.right_eq = section_a.right_eq and section_b.right_eq
|
|
else:
|
|
i.right = SimpleVersion.min_version(section_a.right, section_b.right)
|
|
i.right_eq = section_a.right_eq if i.right == section_a.right else section_b.right_eq
|
|
|
|
# return true if any section from A intersects a section from B
|
|
valid = True
|
|
valid &= SimpleVersion.compare_versions(
|
|
version_a=i.left, op='<=' if i.left_eq else '<', version_b=i.right, num_parts=None)
|
|
valid &= SimpleVersion.compare_versions(
|
|
version_a=i.right, op='>=' if i.left_eq else '>', version_b=i.left, num_parts=None)
|
|
|
|
return valid
|
|
|
|
|
|
@six.add_metaclass(ABCMeta)
|
|
class RequirementSubstitution(object):
|
|
|
|
_pip_extra_index_url = PIP_EXTRA_INDICES
|
|
|
|
def __init__(self, session):
|
|
# type: (Session) -> ()
|
|
self._session = session
|
|
self.config = session.config # type: ConfigTree
|
|
self.suffix = '.post{config[agent.cuda_version]}.dev{config[agent.cudnn_version]}'.format(config=self.config)
|
|
self.package_manager = self.config['agent.package_manager.type']
|
|
self._is_already_installed_cb = None
|
|
|
|
@abstractmethod
|
|
def match(self, req): # type: (MarkerRequirement) -> bool
|
|
"""
|
|
Returns whether a requirement needs to be modified by this substitution.
|
|
"""
|
|
pass
|
|
|
|
@abstractmethod
|
|
def replace(self, req): # type: (MarkerRequirement) -> Text
|
|
"""
|
|
Replace a requirement
|
|
"""
|
|
pass
|
|
|
|
def set_is_already_installed_cb(self, cb):
|
|
self._is_already_installed_cb = cb
|
|
|
|
def is_already_installed(self, req):
|
|
if not self._is_already_installed_cb:
|
|
return False
|
|
# noinspection PyBroadException
|
|
try:
|
|
return self._is_already_installed_cb(req)
|
|
except BaseException as ex:
|
|
# debug could not resolve something
|
|
print("Warning: Requirements post install callback exception (check if package installed): {}".format(ex))
|
|
return False
|
|
|
|
def post_scan_add_req(self): # type: () -> Optional[MarkerRequirement]
|
|
"""
|
|
Allows the RequirementSubstitution to add an extra line/requirements after
|
|
the initial requirements scan is completed.
|
|
Called only once per requirements.txt object
|
|
"""
|
|
return None
|
|
|
|
def post_install(self, session):
|
|
pass
|
|
|
|
@classmethod
|
|
def get_pip_version(cls, package):
|
|
output = Argv(
|
|
'pip',
|
|
'search',
|
|
package,
|
|
*(chain.from_iterable(('-i', x) for x in cls._pip_extra_index_url))
|
|
).get_output()
|
|
# ad-hoc pattern to duplicate the behavior of the old code
|
|
return re.search(r'{} \((\d+\.\d+\.[^.]+)'.format(package), output).group(1)
|
|
|
|
@property
|
|
def cuda_version(self):
|
|
return convert_cuda_version_to_int_10_base_str(self.config['agent.cuda_version'])
|
|
|
|
@property
|
|
def cudnn_version(self):
|
|
return self.config['agent.cudnn_version']
|
|
|
|
|
|
class SimpleSubstitution(RequirementSubstitution):
|
|
|
|
@property
|
|
@abstractmethod
|
|
def name(self):
|
|
pass
|
|
|
|
def match(self, req): # type: (MarkerRequirement) -> bool
|
|
return (self.name == req.name or (
|
|
req.uri and
|
|
re.match(r'https?://', req.uri) and
|
|
self.name in req.uri
|
|
))
|
|
|
|
def replace(self, req): # type: (MarkerRequirement) -> Text
|
|
"""
|
|
Replace a requirement
|
|
:raises: ValueError if version is pre-release
|
|
"""
|
|
if req.uri:
|
|
return re.sub(
|
|
r'({})(.*?)(-cp)'.format(self.name),
|
|
r'\1\2{}\3'.format(self.suffix),
|
|
req.uri,
|
|
count=1)
|
|
|
|
if req.specs:
|
|
_, version_number = req.specs[0]
|
|
# assert packaging_version.parse(version_number)
|
|
else:
|
|
version_number = self.get_pip_version(self.name)
|
|
|
|
req.specs = [('==', version_number + self.suffix)]
|
|
return Text(req)
|
|
|
|
def replace_back(self, list_of_requirements): # type: (Dict) -> Dict
|
|
"""
|
|
:param list_of_requirements: {'pip': ['a==1.0', ]}
|
|
:return: {'pip': ['a==1.0', ]}
|
|
"""
|
|
return list_of_requirements
|
|
|
|
|
|
@six.add_metaclass(ABCMeta)
|
|
class CudaSensitiveSubstitution(SimpleSubstitution):
|
|
|
|
def match(self, req): # type: (MarkerRequirement) -> bool
|
|
return self.cuda_version and self.cudnn_version and \
|
|
super(CudaSensitiveSubstitution, self).match(req)
|
|
|
|
|
|
class CudaNotFound(Exception):
|
|
pass
|
|
|
|
|
|
class RequirementsManager(object):
|
|
|
|
def __init__(self, session, base_interpreter=None):
|
|
# type: (Session, PathLike) -> ()
|
|
self._session = session
|
|
self.config = deepcopy(session.config) # type: ConfigTree
|
|
self.handlers = [] # type: List[RequirementSubstitution]
|
|
agent = self.config['agent']
|
|
self.active = not agent.get('cpu_only', False)
|
|
self.found_cuda = False
|
|
if self.active:
|
|
try:
|
|
agent['cuda_version'], agent['cudnn_version'] = self.get_cuda_version(self.config)
|
|
self.found_cuda = True
|
|
except Exception:
|
|
# if we have a cuda version, it is good enough (we dont have to have cudnn version)
|
|
if agent.get('cuda_version'):
|
|
self.found_cuda = True
|
|
pip_cache_dir = Path(self.config["agent.pip_download_cache.path"]).expanduser() / (
|
|
'cu'+agent['cuda_version'] if self.found_cuda else 'cpu')
|
|
self.translator = RequirementsTranslator(session, interpreter=base_interpreter,
|
|
cache_dir=pip_cache_dir.as_posix())
|
|
self._base_interpreter = base_interpreter
|
|
self._cwd = None
|
|
self._installed_parsed_packages = set()
|
|
|
|
def register(self, cls): # type: (Type[RequirementSubstitution]) -> None
|
|
self.handlers.append(cls(self._session))
|
|
|
|
def set_cwd(self, cwd):
|
|
self._cwd = str(cwd) if cwd else None
|
|
|
|
def _replace_one(self, req): # type: (MarkerRequirement) -> Optional[Text]
|
|
match = re.search(r';\s*(.*)', Text(req))
|
|
if match:
|
|
req.markers = match.group(1).split(',')
|
|
if not self.active:
|
|
return None
|
|
for handler in self.handlers:
|
|
if handler.match(req):
|
|
return handler.replace(req)
|
|
return None
|
|
|
|
def replace(self, requirements): # type: (Text) -> Text
|
|
parsed_requirements = self.parse_requirements_section_to_marker_requirements(
|
|
requirements=requirements, cwd=self._cwd)
|
|
|
|
if not parsed_requirements:
|
|
# return the original requirements just in case
|
|
return requirements
|
|
|
|
def replace_one(i, req):
|
|
# type: (int, MarkerRequirement) -> Optional[Text]
|
|
try:
|
|
return self._replace_one(req)
|
|
except FatalSpecsResolutionError:
|
|
warning('could not resolve python wheel replacement for {}'.format(req))
|
|
raise
|
|
except Exception:
|
|
warning('could not resolve python wheel replacement for \"{}\", '
|
|
'using original requirements line: {}'.format(req, i))
|
|
return None
|
|
|
|
new_requirements = tuple(replace_one(i, req) for i, req in enumerate(parsed_requirements))
|
|
conda = is_conda(self.config)
|
|
result = map(
|
|
lambda x, y: (x if x is not None else y.tostr(markers=not conda)),
|
|
new_requirements,
|
|
parsed_requirements
|
|
)
|
|
if not conda:
|
|
result = map(self.translator.translate, result)
|
|
|
|
result = list(result)
|
|
# add post scan add requirements call back
|
|
for h in self.handlers:
|
|
req = h.post_scan_add_req()
|
|
if req:
|
|
result.append(req.tostr())
|
|
|
|
return join_lines(result)
|
|
|
|
def post_install(self, session, package_manager=None):
|
|
if package_manager:
|
|
self.update_installed_packages_state(package_manager.freeze())
|
|
for h in self.handlers:
|
|
try:
|
|
h.post_install(session)
|
|
except Exception as ex:
|
|
print('RequirementsManager handler {} raised exception: {}'.format(h, ex))
|
|
raise
|
|
|
|
def replace_back(self, requirements):
|
|
if self.translator:
|
|
requirements = self.translator.replace_back(requirements)
|
|
|
|
for h in self.handlers:
|
|
try:
|
|
requirements = h.replace_back(requirements)
|
|
except Exception:
|
|
pass
|
|
return requirements
|
|
|
|
def get_interpreter(self):
|
|
return self._base_interpreter
|
|
|
|
def update_installed_packages_state(self, requirements):
|
|
"""
|
|
Updates internal Installed Packages objects, so that later we can detect
|
|
if we already have a pre-installed package
|
|
:param requirements: is the output of a freeze() call, i.e. dict {'pip': "package==version"}
|
|
"""
|
|
requirements = requirements if not isinstance(requirements, dict) else requirements.get("pip")
|
|
self._installed_parsed_packages = self.parse_requirements_section_to_marker_requirements(
|
|
requirements=requirements, cwd=self._cwd)
|
|
for h in self.handlers:
|
|
h.set_is_already_installed_cb(self._callback_is_already_installed)
|
|
|
|
def _callback_is_already_installed(self, req):
|
|
for p in (self._installed_parsed_packages or []):
|
|
if p.name != req.name:
|
|
continue
|
|
# if this is version control package, only return true of both installed and requests specify commit ID
|
|
if req.vcs:
|
|
return p.vcs and req.revision and req.revision == p.revision
|
|
|
|
if not req.specs and not p.specs:
|
|
return True
|
|
|
|
# return if this is the same version
|
|
return req.specs and p.specs and req.compare_version(p, op="==")
|
|
|
|
return False
|
|
|
|
@staticmethod
|
|
def get_cuda_version(config): # type: (ConfigTree) -> (Text, Text)
|
|
# we assume os.environ already updated the config['agent.cuda_version'] & config['agent.cudnn_version']
|
|
cuda_version = config['agent.cuda_version']
|
|
cudnn_version = config['agent.cudnn_version']
|
|
if cuda_version and cudnn_version:
|
|
return normalize_cuda_version(cuda_version), normalize_cuda_version(cudnn_version)
|
|
|
|
if not cuda_version:
|
|
cuda_version = get_driver_cuda_version()
|
|
|
|
if not cuda_version and is_windows_platform():
|
|
try:
|
|
cuda_vers = [int(k.replace('CUDA_PATH_V', '').replace('_', '')) for k in os.environ.keys()
|
|
if k.startswith('CUDA_PATH_V')]
|
|
cuda_vers = max(cuda_vers)
|
|
if cuda_vers > 40:
|
|
cuda_version = cuda_vers
|
|
except:
|
|
pass
|
|
|
|
if not cuda_version:
|
|
try:
|
|
try:
|
|
nvcc = 'nvcc.exe' if is_windows_platform() else 'nvcc'
|
|
if is_windows_platform() and 'CUDA_PATH' in os.environ:
|
|
nvcc = os.path.join(os.environ['CUDA_PATH'], nvcc)
|
|
|
|
output = Argv(nvcc, '--version').get_output()
|
|
except OSError:
|
|
raise CudaNotFound('nvcc not found')
|
|
match = re.search(r'release (.{3})', output).group(1)
|
|
cuda_version = Text(int(float(match) * 10))
|
|
except:
|
|
pass
|
|
|
|
if not cuda_version:
|
|
try:
|
|
try:
|
|
output = Argv('nvidia-smi',).get_output()
|
|
except OSError:
|
|
raise CudaNotFound('nvcc not found')
|
|
match = re.search(r'CUDA Version: ([0-9]+).([0-9]+)', output)
|
|
match = match.group(1)+'.'+match.group(2)
|
|
cuda_version = Text(int(float(match) * 10))
|
|
except:
|
|
pass
|
|
|
|
if not cudnn_version:
|
|
try:
|
|
cuda_lib = which('nvcc')
|
|
if is_windows_platform:
|
|
cudnn_h = path.sep.join(cuda_lib.split(path.sep)[:-2] + ['include', 'cudnn.h'])
|
|
else:
|
|
cudnn_h = path.join(path.sep, *(cuda_lib.split(path.sep)[:-2] + ['include', 'cudnn.h']))
|
|
|
|
cudnn_major, cudnn_minor = None, None
|
|
try:
|
|
include_file = open(cudnn_h)
|
|
except OSError:
|
|
raise CudaNotFound('Could not read cudnn.h')
|
|
with include_file:
|
|
for line in include_file:
|
|
if 'CUDNN_MAJOR' in line:
|
|
cudnn_major = line.split()[-1]
|
|
if 'CUDNN_MINOR' in line:
|
|
cudnn_minor = line.split()[-1]
|
|
if cudnn_major and cudnn_minor:
|
|
break
|
|
cudnn_version = cudnn_major + (cudnn_minor or '0')
|
|
except:
|
|
pass
|
|
|
|
return (normalize_cuda_version(cuda_version or 0),
|
|
normalize_cuda_version(cudnn_version or 0))
|
|
|
|
@staticmethod
|
|
def parse_requirements_section_to_marker_requirements(requirements, cwd=None):
|
|
def safe_parse(req_str):
|
|
# noinspection PyBroadException
|
|
try:
|
|
return list(parse(req_str, cwd=cwd))
|
|
except Exception as ex:
|
|
return [Requirement(req_str)]
|
|
|
|
def create_req(x):
|
|
r = MarkerRequirement(x)
|
|
r.validate_local_file_ref()
|
|
return r
|
|
|
|
if not requirements:
|
|
return tuple()
|
|
|
|
parsed_requirements = tuple(
|
|
map(
|
|
create_req,
|
|
[r for line in (requirements.splitlines() if isinstance(requirements, str) else requirements)
|
|
for r in safe_parse(line)]
|
|
)
|
|
)
|
|
return parsed_requirements
|