From efb06891a8b72e3f8b175db6b77003086e1012dc Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Sun, 23 Oct 2022 12:59:29 +0300 Subject: [PATCH] Add support for PyTorch new extra_index_url repo support. We will find the correct index url based on the cuda version, and let pip do the rest. --- clearml_agent/helper/package/pytorch.py | 93 +++++++++++++++++++- clearml_agent/helper/package/requirements.py | 10 ++- 2 files changed, 99 insertions(+), 4 deletions(-) diff --git a/clearml_agent/helper/package/pytorch.py b/clearml_agent/helper/package/pytorch.py index 096746b..41ce7d8 100644 --- a/clearml_agent/helper/package/pytorch.py +++ b/clearml_agent/helper/package/pytorch.py @@ -13,7 +13,9 @@ import attr import requests import six -from .requirements import SimpleSubstitution, FatalSpecsResolutionError, SimpleVersion, MarkerRequirement +from .requirements import ( + SimpleSubstitution, FatalSpecsResolutionError, SimpleVersion, MarkerRequirement, + compare_version_rules, ) from ...external.requirements_parser.requirement import Requirement OS_TO_WHEEL_NAME = {"linux": "linux_x86_64", "windows": "win_amd64"} @@ -169,6 +171,10 @@ class PytorchRequirement(SimpleSubstitution): name = "torch" packages = ("torch", "torchvision", "torchaudio", "torchcsprng", "torchtext") + extra_index_url_template = 'https://download.pytorch.org/whl/cu{}/' + nightly_extra_index_url_template = 'https://download.pytorch.org/whl/nightly/cu{}/' + torch_index_url_lookup = {} + def __init__(self, *args, **kwargs): os_name = kwargs.pop("os_override", None) super(PytorchRequirement, self).__init__(*args, **kwargs) @@ -183,6 +189,13 @@ class PytorchRequirement(SimpleSubstitution): self.exceptions = [] self._original_req = [] # allow override pytorch lookup pages + if self.config.get("agent.package_manager.extra_index_url_template", None): + self.extra_index_url_template = \ + self.config.get("agent.package_manager.extra_index_url_template", None) + if 
self.config.get("agent.package_manager.nightly_extra_index_url_template", None): + self.nightly_extra_index_url_template = \ + self.config.get("agent.package_manager.nightly_extra_index_url_template", None) + # allow override pytorch lookup pages if self.config.get("agent.package_manager.torch_page", None): SimplePytorchRequirement.page_lookup_template = \ self.config.get("agent.package_manager.torch_page", None) @@ -381,7 +394,8 @@ class PytorchRequirement(SimpleSubstitution): print('Trying PyTorch CUDA version {} support'.format(torch_url_key)) # fix broken pytorch setuptools incompatibility - if closest_matched_version and SimpleVersion.compare_versions(closest_matched_version, "<", "1.11.0"): + if req.name == "torch" and closest_matched_version and \ + SimpleVersion.compare_versions(closest_matched_version, "<", "1.11.0"): self._fix_setuptools = "setuptools < 59" if not url: @@ -461,6 +475,36 @@ class PytorchRequirement(SimpleSubstitution): return self.match_version(req, base).replace(" ", "\n") def replace(self, req): + # check if package is already installed with system packages + self.validate_python_version() + + # try to check if we can just use the new index URL, if we do not we will revert to old method + try: + extra_index_url = self.get_torch_index_url(self.cuda_version) + if extra_index_url: + # check if the torch version cannot be above 1.11 , we need to fix setup tools + try: + if req.name == "torch" and not compare_version_rules(req.specs, [(">=", "1.11.0")]): + self._fix_setuptools = "setuptools < 59" + except Exception: # noqa + pass + # now we just need to add the correct extra index url for the cuda version + self.set_add_install_extra_index(extra_index_url[0]) + + if req.specs and len(req.specs) == 1 and req.specs[0][0] == "==": + # remove any +cu extension and let pip resolve that + line = "{} {}".format(req.name, req.format_specs(max_num_parts=3)) + if req.marker: + line += " ; {}".format(req.marker) + else: + # return the original line + 
@classmethod
def get_torch_index_url(cls, cuda_version, nightly=False):
    """
    Find the PyTorch pip index URL that best matches a CUDA version.

    Index pages are probed over HTTP, starting at the requested CUDA version and
    walking downwards (at most 15 consecutive values, never below 0) until a page
    answers with an OK status. Every page found is stored in
    ``cls.torch_index_url_lookup`` keyed by the CUDA version it was found for.

    :param cuda_version: CUDA version as an int or int-like string; it is formatted
        as-is into the URL template (117 -> ".../cu117/"). A value that cannot be
        converted with int() is treated as 0.
    :param nightly: when True only ``cls.nightly_extra_index_url_template`` is probed
        (no cache lookup and no fallback to previously found pages);
        otherwise ``cls.extra_index_url_template`` is used.
    :return: tuple ``(index_url, matched_cuda_version)``, or None when no index page
        could be found.
    """
    # noinspection PyBroadException
    try:
        cuda = int(cuda_version)
    except Exception:
        cuda = 0

    def probe(url_template, label):
        # try the requested version first, then older ones
        for candidate in range(cuda, max(-1, cuda - 15), -1):
            torch_url = url_template.format(candidate)
            # noinspection PyBroadException
            try:
                if requests.get(torch_url, timeout=10).ok:
                    print('Torch {}CUDA {} index page found'.format(label, candidate))
                    # NOTE(review): nightly hits land in the same table as stable ones,
                    # so a later non-nightly call may return a nightly URL - confirm intended
                    cls.torch_index_url_lookup[candidate] = torch_url
                    return torch_url, candidate
            except Exception:
                # best effort: a failing request only means "try the next version"
                pass
        return None

    if nightly:
        return probe(cls.nightly_extra_index_url_template, 'nightly ')

    # first check if we already resolved this exact version
    if cuda in cls.torch_index_url_lookup:
        return cls.torch_index_url_lookup[cuda], cuda

    # then try to find a new index page
    found = probe(cls.extra_index_url_template, '')
    if found:
        return found

    # nothing new was found, reuse the closest known version that is not newer than requested
    for known in sorted(cls.torch_index_url_lookup, reverse=True):
        if known <= cuda:
            return cls.torch_index_url_lookup[known], known

    # default entry (key 0) if we have one; with an empty table report "not found"
    # instead of raising KeyError, the same way the nightly branch does
    if 0 in cls.torch_index_url_lookup:
        return cls.torch_index_url_lookup[0], 0
    return None
@classmethod
def set_add_install_extra_index(cls, extra_index_url):
    """
    Add a pip extra index URL to the class-level ``_pip_extra_index_url`` list.

    The list is modified in place, so every class sharing the same list object
    sees the new entry. A URL that is already listed is not added twice, and an
    empty / None value is ignored.

    :param extra_index_url: index URL (string) to add
    """
    if not extra_index_url:
        # nothing to register, do not pollute the index list with None / ""
        return
    if extra_index_url not in cls._pip_extra_index_url:
        cls._pip_extra_index_url.append(extra_index_url)