mirror of
https://github.com/clearml/clearml-agent
synced 2025-06-26 18:16:15 +00:00
Fix PyTorch support to ignore minor versions when looking for package to install or to download
This commit is contained in:
parent
98a983d9a2
commit
5ef627165c
@ -5,7 +5,6 @@ future>=0.16.0
|
|||||||
humanfriendly>=2.1
|
humanfriendly>=2.1
|
||||||
jsonmodels>=2.2
|
jsonmodels>=2.2
|
||||||
jsonschema>=2.6.0
|
jsonschema>=2.6.0
|
||||||
packaging>=16.0
|
|
||||||
pathlib2>=2.3.0
|
pathlib2>=2.3.0
|
||||||
psutil>=3.4.2
|
psutil>=3.4.2
|
||||||
pyhocon>=0.3.38
|
pyhocon>=0.3.38
|
||||||
|
@ -4,7 +4,7 @@ from time import sleep
|
|||||||
import requests
|
import requests
|
||||||
import json
|
import json
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
from packaging import version as packaging_version
|
from .package.requirements import SimpleVersion
|
||||||
from ..version import __version__
|
from ..version import __version__
|
||||||
|
|
||||||
__check_update_thread = None
|
__check_update_thread = None
|
||||||
@ -30,11 +30,11 @@ def _check_new_version_available():
|
|||||||
return None
|
return None
|
||||||
trains_answer = update_server_releases.get("trains-agent", {})
|
trains_answer = update_server_releases.get("trains-agent", {})
|
||||||
latest_version = trains_answer.get("version")
|
latest_version = trains_answer.get("version")
|
||||||
cur_version = packaging_version.parse(cur_version)
|
cur_version = cur_version
|
||||||
latest_version = packaging_version.parse(latest_version or '')
|
latest_version = latest_version or ''
|
||||||
if cur_version >= latest_version:
|
if SimpleVersion.compare_versions(cur_version, '>=', latest_version):
|
||||||
return None
|
return None
|
||||||
patch_upgrade = latest_version.major == cur_version.major and latest_version.minor == cur_version.minor
|
patch_upgrade = True # latest_version.major == cur_version.major and latest_version.minor == cur_version.minor
|
||||||
return str(latest_version), patch_upgrade, trains_answer.get("description").split("\r\n")
|
return str(latest_version), patch_upgrade, trains_answer.get("description").split("\r\n")
|
||||||
|
|
||||||
|
|
||||||
|
@ -14,13 +14,13 @@ import yaml
|
|||||||
from time import time
|
from time import time
|
||||||
from attr import attrs, attrib, Factory
|
from attr import attrs, attrib, Factory
|
||||||
from pathlib2 import Path
|
from pathlib2 import Path
|
||||||
from packaging import version as packaging_version
|
|
||||||
from requirements import parse
|
from requirements import parse
|
||||||
from requirements.requirement import Requirement
|
from requirements.requirement import Requirement
|
||||||
|
|
||||||
from trains_agent.errors import CommandFailedError
|
from trains_agent.errors import CommandFailedError
|
||||||
from trains_agent.helper.base import rm_tree, NonStrictAttrs, select_for_platform, is_windows_platform
|
from trains_agent.helper.base import rm_tree, NonStrictAttrs, select_for_platform, is_windows_platform
|
||||||
from trains_agent.helper.process import Argv, Executable, DEVNULL, CommandSequence, PathLike
|
from trains_agent.helper.process import Argv, Executable, DEVNULL, CommandSequence, PathLike
|
||||||
|
from trains_agent.helper.package.requirements import SimpleVersion
|
||||||
from trains_agent.session import Session
|
from trains_agent.session import Session
|
||||||
from .base import PackageManager
|
from .base import PackageManager
|
||||||
from .pip_api.venv import VirtualenvPip
|
from .pip_api.venv import VirtualenvPip
|
||||||
@ -59,7 +59,7 @@ class CondaAPI(PackageManager):
|
|||||||
A programmatic interface for controlling conda
|
A programmatic interface for controlling conda
|
||||||
"""
|
"""
|
||||||
|
|
||||||
MINIMUM_VERSION = packaging_version.parse("4.3.30")
|
MINIMUM_VERSION = "4.3.30"
|
||||||
|
|
||||||
def __init__(self, session, path, python, requirements_manager):
|
def __init__(self, session, path, python, requirements_manager):
|
||||||
# type: (Session, PathLike, float, RequirementsManager) -> None
|
# type: (Session, PathLike, float, RequirementsManager) -> None
|
||||||
@ -93,7 +93,7 @@ class CondaAPI(PackageManager):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
self.conda_version = self.get_conda_version(output)
|
self.conda_version = self.get_conda_version(output)
|
||||||
if packaging_version.parse(self.conda_version) < self.MINIMUM_VERSION:
|
if SimpleVersion.compare_versions(self.conda_version, '<', self.MINIMUM_VERSION):
|
||||||
raise CommandFailedError(
|
raise CommandFailedError(
|
||||||
"conda version '{}' is smaller than minimum supported conda version '{}'".format(
|
"conda version '{}' is smaller than minimum supported conda version '{}'".format(
|
||||||
self.conda_version, self.MINIMUM_VERSION
|
self.conda_version, self.MINIMUM_VERSION
|
||||||
|
@ -10,11 +10,9 @@ from typing import Text
|
|||||||
|
|
||||||
import attr
|
import attr
|
||||||
import requests
|
import requests
|
||||||
from packaging import version as packaging_version
|
|
||||||
from packaging.specifiers import SpecifierSet
|
|
||||||
|
|
||||||
import six
|
import six
|
||||||
from .requirements import SimpleSubstitution, FatalSpecsResolutionError
|
from .requirements import SimpleSubstitution, FatalSpecsResolutionError, SimpleVersion
|
||||||
|
|
||||||
OS_TO_WHEEL_NAME = {"linux": "linux_x86_64", "windows": "win_amd64"}
|
OS_TO_WHEEL_NAME = {"linux": "linux_x86_64", "windows": "win_amd64"}
|
||||||
|
|
||||||
@ -156,8 +154,7 @@ class PytorchRequirement(SimpleSubstitution):
|
|||||||
self.os = os_name or self.get_platform()
|
self.os = os_name or self.get_platform()
|
||||||
self.cuda = "cuda{}".format(self.cuda_version).lower()
|
self.cuda = "cuda{}".format(self.cuda_version).lower()
|
||||||
self.python_version_string = str(self.config["agent.default_python"])
|
self.python_version_string = str(self.config["agent.default_python"])
|
||||||
self.python_major_minor_str = '.'.join(packaging_version.parse(
|
self.python_major_minor_str = '.'.join(self.python_version_string.split('.')[:2])
|
||||||
self.python_version_string).base_version.split('.')[:2])
|
|
||||||
if '.' not in self.python_major_minor_str:
|
if '.' not in self.python_major_minor_str:
|
||||||
raise PytorchResolutionError(
|
raise PytorchResolutionError(
|
||||||
"invalid python version {!r} defined in configuration file, key 'agent.default_python': "
|
"invalid python version {!r} defined in configuration file, key 'agent.default_python': "
|
||||||
@ -222,7 +219,6 @@ class PytorchRequirement(SimpleSubstitution):
|
|||||||
platform_wheel = "win" if self.get_platform() == "windows" else self.get_platform()
|
platform_wheel = "win" if self.get_platform() == "windows" else self.get_platform()
|
||||||
py_ver = self.python_major_minor_str.replace('.', '')
|
py_ver = self.python_major_minor_str.replace('.', '')
|
||||||
url = None
|
url = None
|
||||||
spec = SpecifierSet(req.format_specs())
|
|
||||||
last_v = None
|
last_v = None
|
||||||
# search for our package
|
# search for our package
|
||||||
for l in links_parser.links:
|
for l in links_parser.links:
|
||||||
@ -234,10 +230,11 @@ class PytorchRequirement(SimpleSubstitution):
|
|||||||
# version (ignore +cpu +cu92 etc. + is %2B in the file link)
|
# version (ignore +cpu +cu92 etc. + is %2B in the file link)
|
||||||
# version ignore .postX suffix (treat as regular version)
|
# version ignore .postX suffix (treat as regular version)
|
||||||
try:
|
try:
|
||||||
v = packaging_version.parse(parts[1].split('%')[0].split('+')[0])
|
v = str(parts[1].split('%')[0].split('+')[0])
|
||||||
except Exception:
|
except Exception:
|
||||||
continue
|
continue
|
||||||
if v not in spec or (last_v and last_v > v):
|
if not req.compare_version(v) or \
|
||||||
|
(last_v and SimpleVersion.compare_versions(last_v, '>', v, ignore_sub_versions=False)):
|
||||||
continue
|
continue
|
||||||
if not parts[2].endswith(py_ver):
|
if not parts[2].endswith(py_ver):
|
||||||
continue
|
continue
|
||||||
@ -307,20 +304,17 @@ class PytorchRequirement(SimpleSubstitution):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def match_version(req, options):
|
def match_version(req, options):
|
||||||
versioned_options = sorted(
|
versioned_options = sorted(
|
||||||
((packaging_version.parse(fix_version(key)), value) for key, value in options.items()),
|
((fix_version(key), value) for key, value in options.items()),
|
||||||
key=itemgetter(0),
|
key=itemgetter(0),
|
||||||
reverse=True,
|
reverse=True,
|
||||||
)
|
)
|
||||||
req.specs = [(op, fix_version(version)) for op, version in req.specs]
|
req.specs = [(op, fix_version(version)) for op, version in req.specs]
|
||||||
if req.specs:
|
|
||||||
specs = SpecifierSet(req.format_specs())
|
|
||||||
else:
|
|
||||||
specs = None
|
|
||||||
try:
|
try:
|
||||||
return next(
|
return next(
|
||||||
replacement
|
replacement
|
||||||
for version, replacement in versioned_options
|
for version, replacement in versioned_options
|
||||||
if not specs or version in specs
|
if req.compare_version(version)
|
||||||
)
|
)
|
||||||
except StopIteration:
|
except StopIteration:
|
||||||
raise PytorchResolutionError(
|
raise PytorchResolutionError(
|
||||||
|
@ -10,7 +10,6 @@ from operator import itemgetter
|
|||||||
from os import path
|
from os import path
|
||||||
from typing import Text, List, Type, Optional, Tuple, Dict
|
from typing import Text, List, Type, Optional, Tuple, Dict
|
||||||
|
|
||||||
from packaging import version as packaging_version
|
|
||||||
from pathlib2 import Path
|
from pathlib2 import Path
|
||||||
from pyhocon import ConfigTree
|
from pyhocon import ConfigTree
|
||||||
from requirements import parse
|
from requirements import parse
|
||||||
@ -69,9 +68,20 @@ class MarkerRequirement(object):
|
|||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return '{self.__class__.__name__}[{self}]'.format(self=self)
|
return '{self.__class__.__name__}[{self}]'.format(self=self)
|
||||||
|
|
||||||
def format_specs(self):
|
def format_specs(self, num_parts=None, max_num_parts=None):
|
||||||
|
max_num_parts = max_num_parts or num_parts
|
||||||
|
if max_num_parts is None or not self.specs:
|
||||||
return ','.join(starmap(operator.add, self.specs))
|
return ','.join(starmap(operator.add, self.specs))
|
||||||
|
|
||||||
|
op, version = self.specs[0]
|
||||||
|
for v in self._sub_versions_pep440:
|
||||||
|
version = version.replace(v, '.')
|
||||||
|
if num_parts:
|
||||||
|
version = (version.strip('.').split('.') + ['0'] * num_parts)[:max_num_parts]
|
||||||
|
else:
|
||||||
|
version = version.strip('.').split('.')[:max_num_parts]
|
||||||
|
return op+'.'.join(version)
|
||||||
|
|
||||||
def __getattr__(self, item):
|
def __getattr__(self, item):
|
||||||
return getattr(self.req, item)
|
return getattr(self.req, item)
|
||||||
|
|
||||||
@ -99,6 +109,186 @@ class MarkerRequirement(object):
|
|||||||
else:
|
else:
|
||||||
self.specs = greater + smaller
|
self.specs = greater + smaller
|
||||||
|
|
||||||
|
def compare_version(self, requested_version, op=None, num_parts=3):
|
||||||
|
"""
|
||||||
|
compare the requested version with the one we have in the spec,
|
||||||
|
If the requested version is 1.2.3 the self.spec should be 1.2.3*
|
||||||
|
If the requested version is 1.2 the self.spec should be 1.2*
|
||||||
|
etc.
|
||||||
|
|
||||||
|
:param str requested_version:
|
||||||
|
:param str op: '==', '>', '>=', '<=', '<', '~='
|
||||||
|
:param int num_parts: number of parts to compare
|
||||||
|
:return: True if we answer the requested version
|
||||||
|
"""
|
||||||
|
# if we have no specific version, we cannot compare, so assume it's okay
|
||||||
|
if not self.specs:
|
||||||
|
return True
|
||||||
|
|
||||||
|
version = self.specs[0][1]
|
||||||
|
op = (op or self.specs[0][0]).strip()
|
||||||
|
|
||||||
|
return SimpleVersion.compare_versions(requested_version, op, version)
|
||||||
|
|
||||||
|
|
||||||
|
class SimpleVersion:
|
||||||
|
_sub_versions_pep440 = ['a', 'b', 'rc', '.post', '.dev', '+', ]
|
||||||
|
VERSION_PATTERN = r"""
|
||||||
|
v?
|
||||||
|
(?:
|
||||||
|
(?:(?P<epoch>[0-9]+)!)? # epoch
|
||||||
|
(?P<release>[0-9]+(?:\.[0-9]+)*) # release segment
|
||||||
|
(?P<pre> # pre-release
|
||||||
|
[-_\.]?
|
||||||
|
(?P<pre_l>(a|b|c|rc|alpha|beta|pre|preview))
|
||||||
|
[-_\.]?
|
||||||
|
(?P<pre_n>[0-9]+)?
|
||||||
|
)?
|
||||||
|
(?P<post> # post release
|
||||||
|
(?:-(?P<post_n1>[0-9]+))
|
||||||
|
|
|
||||||
|
(?:
|
||||||
|
[-_\.]?
|
||||||
|
(?P<post_l>post|rev|r)
|
||||||
|
[-_\.]?
|
||||||
|
(?P<post_n2>[0-9]+)?
|
||||||
|
)
|
||||||
|
)?
|
||||||
|
(?P<dev> # dev release
|
||||||
|
[-_\.]?
|
||||||
|
(?P<dev_l>dev)
|
||||||
|
[-_\.]?
|
||||||
|
(?P<dev_n>[0-9]+)?
|
||||||
|
)?
|
||||||
|
)
|
||||||
|
(?:\+(?P<local>[a-z0-9]+(?:[-_\.][a-z0-9]+)*))? # local version
|
||||||
|
"""
|
||||||
|
_local_version_separators = re.compile(r"[\._-]")
|
||||||
|
_regex = re.compile(r"^\s*" + VERSION_PATTERN + r"\s*$", re.VERBOSE | re.IGNORECASE)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def compare_versions(cls, version_a, op, version_b, ignore_sub_versions=True):
|
||||||
|
"""
|
||||||
|
Compare two versions based on the op operator
|
||||||
|
returns bool(version_a op version_b)
|
||||||
|
Notice: Ignores a/b/rc/post/dev markers on the version
|
||||||
|
|
||||||
|
:param str version_a:
|
||||||
|
:param str op: '==', '===', '>', '>=', '<=', '<', '~='
|
||||||
|
:param str version_b:
|
||||||
|
:param bool ignore_sub_versions: if true compare only major.minor.patch
|
||||||
|
(ignore a/b/rc/post/dev in the comparison)
|
||||||
|
:return bool: version_a op version_b
|
||||||
|
"""
|
||||||
|
|
||||||
|
if not version_b:
|
||||||
|
return True
|
||||||
|
num_parts = 3
|
||||||
|
|
||||||
|
if op == '~=':
|
||||||
|
num_parts = max(num_parts, 2)
|
||||||
|
op = '=='
|
||||||
|
ignore_sub_versions = True
|
||||||
|
elif op == '===':
|
||||||
|
op = '=='
|
||||||
|
|
||||||
|
try:
|
||||||
|
version_a_key = cls._get_match_key(cls._regex.search(version_a), num_parts, ignore_sub_versions)
|
||||||
|
version_b_key = cls._get_match_key(cls._regex.search(version_b), num_parts, ignore_sub_versions)
|
||||||
|
except:
|
||||||
|
# revert to string based
|
||||||
|
for v in cls._sub_versions_pep440:
|
||||||
|
version_a = version_a.replace(v, '.')
|
||||||
|
version_b = version_b.replace(v, '.')
|
||||||
|
|
||||||
|
version_a = (version_a.strip('.').split('.') + ['0'] * num_parts)[:num_parts]
|
||||||
|
version_b = (version_b.strip('.').split('.') + ['0'] * num_parts)[:num_parts]
|
||||||
|
version_a_key = ''
|
||||||
|
version_b_key = ''
|
||||||
|
for i in range(num_parts):
|
||||||
|
pad = '{:0>%d}.' % max([9, 1 + len(version_a[i]), 1 + len(version_b[i])])
|
||||||
|
version_a_key += pad.format(version_a[i])
|
||||||
|
version_b_key += pad.format(version_b[i])
|
||||||
|
|
||||||
|
if op == '==':
|
||||||
|
return version_a_key == version_b_key
|
||||||
|
if op == '<=':
|
||||||
|
return version_a_key <= version_b_key
|
||||||
|
if op == '>=':
|
||||||
|
return version_a_key >= version_b_key
|
||||||
|
if op == '>':
|
||||||
|
return version_a_key > version_b_key
|
||||||
|
if op == '<':
|
||||||
|
return version_a_key < version_b_key
|
||||||
|
raise ValueError('Unrecognized comparison operator [{}]'.format(op))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _parse_letter_version(
|
||||||
|
letter, # type: str
|
||||||
|
number, # type: Union[str, bytes, SupportsInt]
|
||||||
|
):
|
||||||
|
# type: (...) -> Optional[Tuple[str, int]]
|
||||||
|
|
||||||
|
if letter:
|
||||||
|
# We consider there to be an implicit 0 in a pre-release if there is
|
||||||
|
# not a numeral associated with it.
|
||||||
|
if number is None:
|
||||||
|
number = 0
|
||||||
|
|
||||||
|
# We normalize any letters to their lower case form
|
||||||
|
letter = letter.lower()
|
||||||
|
|
||||||
|
# We consider some words to be alternate spellings of other words and
|
||||||
|
# in those cases we want to normalize the spellings to our preferred
|
||||||
|
# spelling.
|
||||||
|
if letter == "alpha":
|
||||||
|
letter = "a"
|
||||||
|
elif letter == "beta":
|
||||||
|
letter = "b"
|
||||||
|
elif letter in ["c", "pre", "preview"]:
|
||||||
|
letter = "rc"
|
||||||
|
elif letter in ["rev", "r"]:
|
||||||
|
letter = "post"
|
||||||
|
|
||||||
|
return letter, int(number)
|
||||||
|
if not letter and number:
|
||||||
|
# We assume if we are given a number, but we are not given a letter
|
||||||
|
# then this is using the implicit post release syntax (e.g. 1.0-1)
|
||||||
|
letter = "post"
|
||||||
|
|
||||||
|
return letter, int(number)
|
||||||
|
|
||||||
|
return ()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_match_key(match, num_parts, ignore_sub_versions):
|
||||||
|
if ignore_sub_versions:
|
||||||
|
return (0, tuple(int(i) for i in match.group("release").split(".")[:num_parts]),
|
||||||
|
(), (), (), (),)
|
||||||
|
return (
|
||||||
|
int(match.group("epoch")) if match.group("epoch") else 0,
|
||||||
|
tuple(int(i) for i in match.group("release").split(".")[:num_parts]),
|
||||||
|
SimpleVersion._parse_letter_version(match.group("pre_l"), match.group("pre_n")),
|
||||||
|
SimpleVersion._parse_letter_version(
|
||||||
|
match.group("post_l"), match.group("post_n1") or match.group("post_n2")
|
||||||
|
),
|
||||||
|
SimpleVersion._parse_letter_version(match.group("dev_l"), match.group("dev_n")),
|
||||||
|
SimpleVersion._parse_local_version(match.group("local")),
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _parse_local_version(local):
|
||||||
|
# type: (str) -> Optional[LocalType]
|
||||||
|
"""
|
||||||
|
Takes a string like abc.1.twelve and turns it into ("abc", 1, "twelve").
|
||||||
|
"""
|
||||||
|
if local is not None:
|
||||||
|
return tuple(
|
||||||
|
part.lower() if not part.isdigit() else int(part)
|
||||||
|
for part in SimpleVersion._local_version_separators.split(local)
|
||||||
|
)
|
||||||
|
return ()
|
||||||
|
|
||||||
|
|
||||||
@six.add_metaclass(ABCMeta)
|
@six.add_metaclass(ABCMeta)
|
||||||
class RequirementSubstitution(object):
|
class RequirementSubstitution(object):
|
||||||
@ -177,7 +367,7 @@ class SimpleSubstitution(RequirementSubstitution):
|
|||||||
|
|
||||||
if req.specs:
|
if req.specs:
|
||||||
_, version_number = req.specs[0]
|
_, version_number = req.specs[0]
|
||||||
assert packaging_version.parse(version_number)
|
# assert packaging_version.parse(version_number)
|
||||||
else:
|
else:
|
||||||
version_number = self.get_pip_version(self.name)
|
version_number = self.get_pip_version(self.name)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user