mirror of
https://github.com/clearml/clearml-agent
synced 2025-05-01 10:54:11 +00:00
Add support for PyTorch new extra_index_url repo support. We will find the correct index url based on the cuda version, and let pip do the rest.
This commit is contained in:
parent
70771b12a9
commit
efb06891a8
@ -13,7 +13,9 @@ import attr
|
|||||||
import requests
|
import requests
|
||||||
|
|
||||||
import six
|
import six
|
||||||
from .requirements import SimpleSubstitution, FatalSpecsResolutionError, SimpleVersion, MarkerRequirement
|
from .requirements import (
|
||||||
|
SimpleSubstitution, FatalSpecsResolutionError, SimpleVersion, MarkerRequirement,
|
||||||
|
compare_version_rules, )
|
||||||
from ...external.requirements_parser.requirement import Requirement
|
from ...external.requirements_parser.requirement import Requirement
|
||||||
|
|
||||||
OS_TO_WHEEL_NAME = {"linux": "linux_x86_64", "windows": "win_amd64"}
|
OS_TO_WHEEL_NAME = {"linux": "linux_x86_64", "windows": "win_amd64"}
|
||||||
@ -169,6 +171,10 @@ class PytorchRequirement(SimpleSubstitution):
|
|||||||
name = "torch"
|
name = "torch"
|
||||||
packages = ("torch", "torchvision", "torchaudio", "torchcsprng", "torchtext")
|
packages = ("torch", "torchvision", "torchaudio", "torchcsprng", "torchtext")
|
||||||
|
|
||||||
|
extra_index_url_template = 'https://download.pytorch.org/whl/cu{}/'
|
||||||
|
nightly_extra_index_url_template = 'https://download.pytorch.org/whl/nightly/cu{}/'
|
||||||
|
torch_index_url_lookup = {}
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
os_name = kwargs.pop("os_override", None)
|
os_name = kwargs.pop("os_override", None)
|
||||||
super(PytorchRequirement, self).__init__(*args, **kwargs)
|
super(PytorchRequirement, self).__init__(*args, **kwargs)
|
||||||
@ -183,6 +189,13 @@ class PytorchRequirement(SimpleSubstitution):
|
|||||||
self.exceptions = []
|
self.exceptions = []
|
||||||
self._original_req = []
|
self._original_req = []
|
||||||
# allow override pytorch lookup pages
|
# allow override pytorch lookup pages
|
||||||
|
if self.config.get("agent.package_manager.extra_index_url_template", None):
|
||||||
|
self.extra_index_url_template = \
|
||||||
|
self.config.get("agent.package_manager.extra_index_url_template", None)
|
||||||
|
if self.config.get("agent.package_manager.nightly_extra_index_url_template", None):
|
||||||
|
self.nightly_extra_index_url_template = \
|
||||||
|
self.config.get("agent.package_manager.nightly_extra_index_url_template", None)
|
||||||
|
# allow override pytorch lookup pages
|
||||||
if self.config.get("agent.package_manager.torch_page", None):
|
if self.config.get("agent.package_manager.torch_page", None):
|
||||||
SimplePytorchRequirement.page_lookup_template = \
|
SimplePytorchRequirement.page_lookup_template = \
|
||||||
self.config.get("agent.package_manager.torch_page", None)
|
self.config.get("agent.package_manager.torch_page", None)
|
||||||
@ -381,7 +394,8 @@ class PytorchRequirement(SimpleSubstitution):
|
|||||||
print('Trying PyTorch CUDA version {} support'.format(torch_url_key))
|
print('Trying PyTorch CUDA version {} support'.format(torch_url_key))
|
||||||
|
|
||||||
# fix broken pytorch setuptools incompatibility
|
# fix broken pytorch setuptools incompatibility
|
||||||
if closest_matched_version and SimpleVersion.compare_versions(closest_matched_version, "<", "1.11.0"):
|
if req.name == "torch" and closest_matched_version and \
|
||||||
|
SimpleVersion.compare_versions(closest_matched_version, "<", "1.11.0"):
|
||||||
self._fix_setuptools = "setuptools < 59"
|
self._fix_setuptools = "setuptools < 59"
|
||||||
|
|
||||||
if not url:
|
if not url:
|
||||||
@ -461,6 +475,36 @@ class PytorchRequirement(SimpleSubstitution):
|
|||||||
return self.match_version(req, base).replace(" ", "\n")
|
return self.match_version(req, base).replace(" ", "\n")
|
||||||
|
|
||||||
def replace(self, req):
|
def replace(self, req):
|
||||||
|
# check if package is already installed with system packages
|
||||||
|
self.validate_python_version()
|
||||||
|
|
||||||
|
# try to check if we can just use the new index URL, if we do not we will revert to old method
|
||||||
|
try:
|
||||||
|
extra_index_url = self.get_torch_index_url(self.cuda_version)
|
||||||
|
if extra_index_url:
|
||||||
|
# check if the torch version cannot be above 1.11 , we need to fix setup tools
|
||||||
|
try:
|
||||||
|
if req.name == "torch" and not compare_version_rules(req.specs, [(">=", "1.11.0")]):
|
||||||
|
self._fix_setuptools = "setuptools < 59"
|
||||||
|
except Exception: # noqa
|
||||||
|
pass
|
||||||
|
# now we just need to add the correct extra index url for the cuda version
|
||||||
|
self.set_add_install_extra_index(extra_index_url[0])
|
||||||
|
|
||||||
|
if req.specs and len(req.specs) == 1 and req.specs[0][0] == "==":
|
||||||
|
# remove any +cu extension and let pip resolve that
|
||||||
|
line = "{} {}".format(req.name, req.format_specs(max_num_parts=3))
|
||||||
|
if req.marker:
|
||||||
|
line += " ; {}".format(req.marker)
|
||||||
|
else:
|
||||||
|
# return the original line
|
||||||
|
line = req.line
|
||||||
|
|
||||||
|
return line
|
||||||
|
|
||||||
|
except Exception: # noqa
|
||||||
|
pass
|
||||||
|
|
||||||
try:
|
try:
|
||||||
new_req = self._replace(req)
|
new_req = self._replace(req)
|
||||||
if new_req:
|
if new_req:
|
||||||
@ -556,6 +600,51 @@ class PytorchRequirement(SimpleSubstitution):
|
|||||||
return MarkerRequirement(Requirement.parse(self._fix_setuptools))
|
return MarkerRequirement(Requirement.parse(self._fix_setuptools))
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_torch_index_url(cls, cuda_version, nightly=False):
|
||||||
|
# noinspection PyBroadException
|
||||||
|
try:
|
||||||
|
cuda = int(cuda_version)
|
||||||
|
except Exception:
|
||||||
|
cuda = 0
|
||||||
|
|
||||||
|
if nightly:
|
||||||
|
for c in range(cuda, max(-1, cuda-15), -1):
|
||||||
|
# then try the nightly builds, it might be there...
|
||||||
|
torch_url = cls.nightly_extra_index_url_template.format(c)
|
||||||
|
# noinspection PyBroadException
|
||||||
|
try:
|
||||||
|
if requests.get(torch_url, timeout=10).ok:
|
||||||
|
print('Torch nightly CUDA {} index page found'.format(c))
|
||||||
|
cls.torch_index_url_lookup[c] = torch_url
|
||||||
|
return cls.torch_index_url_lookup[c], c
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return
|
||||||
|
|
||||||
|
# first check if key is valid
|
||||||
|
if cuda in cls.torch_index_url_lookup:
|
||||||
|
return cls.torch_index_url_lookup[cuda], cuda
|
||||||
|
|
||||||
|
# then try a new cuda version page
|
||||||
|
for c in range(cuda, max(-1, cuda-15), -1):
|
||||||
|
torch_url = cls.extra_index_url_template.format(c)
|
||||||
|
# noinspection PyBroadException
|
||||||
|
try:
|
||||||
|
if requests.get(torch_url, timeout=10).ok:
|
||||||
|
print('Torch CUDA {} index page found'.format(c))
|
||||||
|
cls.torch_index_url_lookup[c] = torch_url
|
||||||
|
return cls.torch_index_url_lookup[c], c
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
keys = sorted(cls.torch_index_url_lookup.keys(), reverse=True)
|
||||||
|
for k in keys:
|
||||||
|
if k <= cuda:
|
||||||
|
return cls.torch_index_url_lookup[k], k
|
||||||
|
# return default - zero
|
||||||
|
return cls.torch_index_url_lookup[0], 0
|
||||||
|
|
||||||
MAP = {
|
MAP = {
|
||||||
"windows": {
|
"windows": {
|
||||||
"cuda100": {
|
"cuda100": {
|
||||||
|
@ -100,7 +100,8 @@ class MarkerRequirement(object):
|
|||||||
return ','.join(starmap(operator.add, self.specs))
|
return ','.join(starmap(operator.add, self.specs))
|
||||||
|
|
||||||
op, version = self.specs[0]
|
op, version = self.specs[0]
|
||||||
for v in self._sub_versions_pep440:
|
# noinspection PyProtectedMember
|
||||||
|
for v in SimpleVersion._sub_versions_pep440:
|
||||||
version = version.replace(v, '.')
|
version = version.replace(v, '.')
|
||||||
if num_parts:
|
if num_parts:
|
||||||
version = (version.strip('.').split('.') + ['0'] * num_parts)[:max_num_parts]
|
version = (version.strip('.').split('.') + ['0'] * num_parts)[:max_num_parts]
|
||||||
@ -364,7 +365,7 @@ def compare_version_rules(specs_a, specs_b):
|
|||||||
# specs_a/b are a list of tuples: [('==', '1.2.3'), ] or [('>=', '1.2'), ('<', '1.3')]
|
# specs_a/b are a list of tuples: [('==', '1.2.3'), ] or [('>=', '1.2'), ('<', '1.3')]
|
||||||
# section definition:
|
# section definition:
|
||||||
class Section(object):
|
class Section(object):
|
||||||
def __init__(self, left=None, left_eq=False, right=None, right_eq=False):
|
def __init__(self, left="-999999999", left_eq=False, right="999999999", right_eq=False):
|
||||||
self.left, self.left_eq, self.right, self.right_eq = left, left_eq, right, right_eq
|
self.left, self.left_eq, self.right, self.right_eq = left, left_eq, right, right_eq
|
||||||
# first create a list of in/out sections for each spec
|
# first create a list of in/out sections for each spec
|
||||||
# >, >= are left rule
|
# >, >= are left rule
|
||||||
@ -436,6 +437,11 @@ class RequirementSubstitution(object):
|
|||||||
|
|
||||||
_pip_extra_index_url = PIP_EXTRA_INDICES
|
_pip_extra_index_url = PIP_EXTRA_INDICES
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def set_add_install_extra_index(cls, extra_index_url):
|
||||||
|
if extra_index_url not in cls._pip_extra_index_url:
|
||||||
|
cls._pip_extra_index_url.append(extra_index_url)
|
||||||
|
|
||||||
def __init__(self, session):
|
def __init__(self, session):
|
||||||
# type: (Session) -> ()
|
# type: (Session) -> ()
|
||||||
self._session = session
|
self._session = session
|
||||||
|
Loading…
Reference in New Issue
Block a user