Fix PyTorch support to ignore minor versions when looking for package to install or to download

This commit is contained in:
allegroai 2020-03-20 10:48:48 +02:00
parent 98a983d9a2
commit 5ef627165c
5 changed files with 210 additions and 27 deletions

View File

@ -5,7 +5,6 @@ future>=0.16.0
humanfriendly>=2.1 humanfriendly>=2.1
jsonmodels>=2.2 jsonmodels>=2.2
jsonschema>=2.6.0 jsonschema>=2.6.0
packaging>=16.0
pathlib2>=2.3.0 pathlib2>=2.3.0
psutil>=3.4.2 psutil>=3.4.2
pyhocon>=0.3.38 pyhocon>=0.3.38

View File

@ -4,7 +4,7 @@ from time import sleep
import requests import requests
import json import json
from threading import Thread from threading import Thread
from packaging import version as packaging_version from .package.requirements import SimpleVersion
from ..version import __version__ from ..version import __version__
__check_update_thread = None __check_update_thread = None
@ -30,11 +30,11 @@ def _check_new_version_available():
return None return None
trains_answer = update_server_releases.get("trains-agent", {}) trains_answer = update_server_releases.get("trains-agent", {})
latest_version = trains_answer.get("version") latest_version = trains_answer.get("version")
cur_version = packaging_version.parse(cur_version) cur_version = cur_version
latest_version = packaging_version.parse(latest_version or '') latest_version = latest_version or ''
if cur_version >= latest_version: if SimpleVersion.compare_versions(cur_version, '>=', latest_version):
return None return None
patch_upgrade = latest_version.major == cur_version.major and latest_version.minor == cur_version.minor patch_upgrade = True # latest_version.major == cur_version.major and latest_version.minor == cur_version.minor
return str(latest_version), patch_upgrade, trains_answer.get("description").split("\r\n") return str(latest_version), patch_upgrade, trains_answer.get("description").split("\r\n")

View File

@ -14,13 +14,13 @@ import yaml
from time import time from time import time
from attr import attrs, attrib, Factory from attr import attrs, attrib, Factory
from pathlib2 import Path from pathlib2 import Path
from packaging import version as packaging_version
from requirements import parse from requirements import parse
from requirements.requirement import Requirement from requirements.requirement import Requirement
from trains_agent.errors import CommandFailedError from trains_agent.errors import CommandFailedError
from trains_agent.helper.base import rm_tree, NonStrictAttrs, select_for_platform, is_windows_platform from trains_agent.helper.base import rm_tree, NonStrictAttrs, select_for_platform, is_windows_platform
from trains_agent.helper.process import Argv, Executable, DEVNULL, CommandSequence, PathLike from trains_agent.helper.process import Argv, Executable, DEVNULL, CommandSequence, PathLike
from trains_agent.helper.package.requirements import SimpleVersion
from trains_agent.session import Session from trains_agent.session import Session
from .base import PackageManager from .base import PackageManager
from .pip_api.venv import VirtualenvPip from .pip_api.venv import VirtualenvPip
@ -59,7 +59,7 @@ class CondaAPI(PackageManager):
A programmatic interface for controlling conda A programmatic interface for controlling conda
""" """
MINIMUM_VERSION = packaging_version.parse("4.3.30") MINIMUM_VERSION = "4.3.30"
def __init__(self, session, path, python, requirements_manager): def __init__(self, session, path, python, requirements_manager):
# type: (Session, PathLike, float, RequirementsManager) -> None # type: (Session, PathLike, float, RequirementsManager) -> None
@ -93,7 +93,7 @@ class CondaAPI(PackageManager):
) )
) )
self.conda_version = self.get_conda_version(output) self.conda_version = self.get_conda_version(output)
if packaging_version.parse(self.conda_version) < self.MINIMUM_VERSION: if SimpleVersion.compare_versions(self.conda_version, '<', self.MINIMUM_VERSION):
raise CommandFailedError( raise CommandFailedError(
"conda version '{}' is smaller than minimum supported conda version '{}'".format( "conda version '{}' is smaller than minimum supported conda version '{}'".format(
self.conda_version, self.MINIMUM_VERSION self.conda_version, self.MINIMUM_VERSION

View File

@ -10,11 +10,9 @@ from typing import Text
import attr import attr
import requests import requests
from packaging import version as packaging_version
from packaging.specifiers import SpecifierSet
import six import six
from .requirements import SimpleSubstitution, FatalSpecsResolutionError from .requirements import SimpleSubstitution, FatalSpecsResolutionError, SimpleVersion
OS_TO_WHEEL_NAME = {"linux": "linux_x86_64", "windows": "win_amd64"} OS_TO_WHEEL_NAME = {"linux": "linux_x86_64", "windows": "win_amd64"}
@ -156,8 +154,7 @@ class PytorchRequirement(SimpleSubstitution):
self.os = os_name or self.get_platform() self.os = os_name or self.get_platform()
self.cuda = "cuda{}".format(self.cuda_version).lower() self.cuda = "cuda{}".format(self.cuda_version).lower()
self.python_version_string = str(self.config["agent.default_python"]) self.python_version_string = str(self.config["agent.default_python"])
self.python_major_minor_str = '.'.join(packaging_version.parse( self.python_major_minor_str = '.'.join(self.python_version_string.split('.')[:2])
self.python_version_string).base_version.split('.')[:2])
if '.' not in self.python_major_minor_str: if '.' not in self.python_major_minor_str:
raise PytorchResolutionError( raise PytorchResolutionError(
"invalid python version {!r} defined in configuration file, key 'agent.default_python': " "invalid python version {!r} defined in configuration file, key 'agent.default_python': "
@ -222,7 +219,6 @@ class PytorchRequirement(SimpleSubstitution):
platform_wheel = "win" if self.get_platform() == "windows" else self.get_platform() platform_wheel = "win" if self.get_platform() == "windows" else self.get_platform()
py_ver = self.python_major_minor_str.replace('.', '') py_ver = self.python_major_minor_str.replace('.', '')
url = None url = None
spec = SpecifierSet(req.format_specs())
last_v = None last_v = None
# search for our package # search for our package
for l in links_parser.links: for l in links_parser.links:
@ -234,10 +230,11 @@ class PytorchRequirement(SimpleSubstitution):
# version (ignore +cpu +cu92 etc. + is %2B in the file link) # version (ignore +cpu +cu92 etc. + is %2B in the file link)
# version ignore .postX suffix (treat as regular version) # version ignore .postX suffix (treat as regular version)
try: try:
v = packaging_version.parse(parts[1].split('%')[0].split('+')[0]) v = str(parts[1].split('%')[0].split('+')[0])
except Exception: except Exception:
continue continue
if v not in spec or (last_v and last_v > v): if not req.compare_version(v) or \
(last_v and SimpleVersion.compare_versions(last_v, '>', v, ignore_sub_versions=False)):
continue continue
if not parts[2].endswith(py_ver): if not parts[2].endswith(py_ver):
continue continue
@ -307,20 +304,17 @@ class PytorchRequirement(SimpleSubstitution):
@staticmethod @staticmethod
def match_version(req, options): def match_version(req, options):
versioned_options = sorted( versioned_options = sorted(
((packaging_version.parse(fix_version(key)), value) for key, value in options.items()), ((fix_version(key), value) for key, value in options.items()),
key=itemgetter(0), key=itemgetter(0),
reverse=True, reverse=True,
) )
req.specs = [(op, fix_version(version)) for op, version in req.specs] req.specs = [(op, fix_version(version)) for op, version in req.specs]
if req.specs:
specs = SpecifierSet(req.format_specs())
else:
specs = None
try: try:
return next( return next(
replacement replacement
for version, replacement in versioned_options for version, replacement in versioned_options
if not specs or version in specs if req.compare_version(version)
) )
except StopIteration: except StopIteration:
raise PytorchResolutionError( raise PytorchResolutionError(

View File

@ -10,7 +10,6 @@ from operator import itemgetter
from os import path from os import path
from typing import Text, List, Type, Optional, Tuple, Dict from typing import Text, List, Type, Optional, Tuple, Dict
from packaging import version as packaging_version
from pathlib2 import Path from pathlib2 import Path
from pyhocon import ConfigTree from pyhocon import ConfigTree
from requirements import parse from requirements import parse
@ -69,9 +68,20 @@ class MarkerRequirement(object):
def __repr__(self): def __repr__(self):
return '{self.__class__.__name__}[{self}]'.format(self=self) return '{self.__class__.__name__}[{self}]'.format(self=self)
def format_specs(self): def format_specs(self, num_parts=None, max_num_parts=None):
max_num_parts = max_num_parts or num_parts
if max_num_parts is None or not self.specs:
return ','.join(starmap(operator.add, self.specs)) return ','.join(starmap(operator.add, self.specs))
op, version = self.specs[0]
for v in self._sub_versions_pep440:
version = version.replace(v, '.')
if num_parts:
version = (version.strip('.').split('.') + ['0'] * num_parts)[:max_num_parts]
else:
version = version.strip('.').split('.')[:max_num_parts]
return op+'.'.join(version)
def __getattr__(self, item): def __getattr__(self, item):
return getattr(self.req, item) return getattr(self.req, item)
@ -99,6 +109,186 @@ class MarkerRequirement(object):
else: else:
self.specs = greater + smaller self.specs = greater + smaller
def compare_version(self, requested_version, op=None, num_parts=3):
"""
compare the requested version with the one we have in the spec,
If the requested version is 1.2.3 the self.spec should be 1.2.3*
If the requested version is 1.2 the self.spec should be 1.2*
etc.
:param str requested_version:
:param str op: '==', '>', '>=', '<=', '<', '~='
:param int num_parts: number of parts to compare
:return: True if we answer the requested version
"""
# if we have no specific version, we cannot compare, so assume it's okay
if not self.specs:
return True
version = self.specs[0][1]
op = (op or self.specs[0][0]).strip()
return SimpleVersion.compare_versions(requested_version, op, version)
class SimpleVersion:
_sub_versions_pep440 = ['a', 'b', 'rc', '.post', '.dev', '+', ]
VERSION_PATTERN = r"""
v?
(?:
(?:(?P<epoch>[0-9]+)!)? # epoch
(?P<release>[0-9]+(?:\.[0-9]+)*) # release segment
(?P<pre> # pre-release
[-_\.]?
(?P<pre_l>(a|b|c|rc|alpha|beta|pre|preview))
[-_\.]?
(?P<pre_n>[0-9]+)?
)?
(?P<post> # post release
(?:-(?P<post_n1>[0-9]+))
|
(?:
[-_\.]?
(?P<post_l>post|rev|r)
[-_\.]?
(?P<post_n2>[0-9]+)?
)
)?
(?P<dev> # dev release
[-_\.]?
(?P<dev_l>dev)
[-_\.]?
(?P<dev_n>[0-9]+)?
)?
)
(?:\+(?P<local>[a-z0-9]+(?:[-_\.][a-z0-9]+)*))? # local version
"""
_local_version_separators = re.compile(r"[\._-]")
_regex = re.compile(r"^\s*" + VERSION_PATTERN + r"\s*$", re.VERBOSE | re.IGNORECASE)
@classmethod
def compare_versions(cls, version_a, op, version_b, ignore_sub_versions=True):
"""
Compare two versions based on the op operator
returns bool(version_a op version_b)
Notice: Ignores a/b/rc/post/dev markers on the version
:param str version_a:
:param str op: '==', '===', '>', '>=', '<=', '<', '~='
:param str version_b:
:param bool ignore_sub_versions: if true compare only major.minor.patch
(ignore a/b/rc/post/dev in the comparison)
:return bool: version_a op version_b
"""
if not version_b:
return True
num_parts = 3
if op == '~=':
num_parts = max(num_parts, 2)
op = '=='
ignore_sub_versions = True
elif op == '===':
op = '=='
try:
version_a_key = cls._get_match_key(cls._regex.search(version_a), num_parts, ignore_sub_versions)
version_b_key = cls._get_match_key(cls._regex.search(version_b), num_parts, ignore_sub_versions)
except:
# revert to string based
for v in cls._sub_versions_pep440:
version_a = version_a.replace(v, '.')
version_b = version_b.replace(v, '.')
version_a = (version_a.strip('.').split('.') + ['0'] * num_parts)[:num_parts]
version_b = (version_b.strip('.').split('.') + ['0'] * num_parts)[:num_parts]
version_a_key = ''
version_b_key = ''
for i in range(num_parts):
pad = '{:0>%d}.' % max([9, 1 + len(version_a[i]), 1 + len(version_b[i])])
version_a_key += pad.format(version_a[i])
version_b_key += pad.format(version_b[i])
if op == '==':
return version_a_key == version_b_key
if op == '<=':
return version_a_key <= version_b_key
if op == '>=':
return version_a_key >= version_b_key
if op == '>':
return version_a_key > version_b_key
if op == '<':
return version_a_key < version_b_key
raise ValueError('Unrecognized comparison operator [{}]'.format(op))
@staticmethod
def _parse_letter_version(
letter, # type: str
number, # type: Union[str, bytes, SupportsInt]
):
# type: (...) -> Optional[Tuple[str, int]]
if letter:
# We consider there to be an implicit 0 in a pre-release if there is
# not a numeral associated with it.
if number is None:
number = 0
# We normalize any letters to their lower case form
letter = letter.lower()
# We consider some words to be alternate spellings of other words and
# in those cases we want to normalize the spellings to our preferred
# spelling.
if letter == "alpha":
letter = "a"
elif letter == "beta":
letter = "b"
elif letter in ["c", "pre", "preview"]:
letter = "rc"
elif letter in ["rev", "r"]:
letter = "post"
return letter, int(number)
if not letter and number:
# We assume if we are given a number, but we are not given a letter
# then this is using the implicit post release syntax (e.g. 1.0-1)
letter = "post"
return letter, int(number)
return ()
@staticmethod
def _get_match_key(match, num_parts, ignore_sub_versions):
if ignore_sub_versions:
return (0, tuple(int(i) for i in match.group("release").split(".")[:num_parts]),
(), (), (), (),)
return (
int(match.group("epoch")) if match.group("epoch") else 0,
tuple(int(i) for i in match.group("release").split(".")[:num_parts]),
SimpleVersion._parse_letter_version(match.group("pre_l"), match.group("pre_n")),
SimpleVersion._parse_letter_version(
match.group("post_l"), match.group("post_n1") or match.group("post_n2")
),
SimpleVersion._parse_letter_version(match.group("dev_l"), match.group("dev_n")),
SimpleVersion._parse_local_version(match.group("local")),
)
@staticmethod
def _parse_local_version(local):
# type: (str) -> Optional[LocalType]
"""
Takes a string like abc.1.twelve and turns it into ("abc", 1, "twelve").
"""
if local is not None:
return tuple(
part.lower() if not part.isdigit() else int(part)
for part in SimpleVersion._local_version_separators.split(local)
)
return ()
@six.add_metaclass(ABCMeta) @six.add_metaclass(ABCMeta)
class RequirementSubstitution(object): class RequirementSubstitution(object):
@ -177,7 +367,7 @@ class SimpleSubstitution(RequirementSubstitution):
if req.specs: if req.specs:
_, version_number = req.specs[0] _, version_number = req.specs[0]
assert packaging_version.parse(version_number) # assert packaging_version.parse(version_number)
else: else:
version_number = self.get_pip_version(self.name) version_number = self.get_pip_version(self.name)