From 1afa3a39147aa5f29bedf47ce668bf973ee4c575 Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Sun, 6 Dec 2020 12:15:12 +0200 Subject: [PATCH] Add torchcsprng and torchtext to PyTorch resolving. Improve debug prints on auto cuda version resolving. --- trains_agent/helper/package/pytorch.py | 52 ++++++++++++++++++-------- 1 file changed, 36 insertions(+), 16 deletions(-) diff --git a/trains_agent/helper/package/pytorch.py b/trains_agent/helper/package/pytorch.py index bb5adc1..7b5a277 100644 --- a/trains_agent/helper/package/pytorch.py +++ b/trains_agent/helper/package/pytorch.py @@ -166,7 +166,7 @@ class SimplePytorchRequirement(SimpleSubstitution): class PytorchRequirement(SimpleSubstitution): name = "torch" - packages = ("torch", "torchvision", "torchaudio") + packages = ("torch", "torchvision", "torchaudio", "torchcsprng", "torchtext") def __init__(self, *args, **kwargs): os_name = kwargs.pop("os_override", None) @@ -244,6 +244,7 @@ class PytorchRequirement(SimpleSubstitution): py_ver = self.python_major_minor_str.replace('.', '') url = None last_v = None + closest_v = None # search for our package for l in links_parser.links: parts = l.split('/')[-1].split('-') @@ -253,28 +254,40 @@ class PytorchRequirement(SimpleSubstitution): continue # version (ignore +cpu +cu92 etc. + is %2B in the file link) # version ignore .postX suffix (treat as regular version) + # noinspection PyBroadException try: v = str(parts[1].split('%')[0].split('+')[0]) except Exception: continue + if len(parts) < 3 or not parts[2].endswith(py_ver): + continue + if len(parts) < 5 or platform_wheel not in parts[4]: + continue + # update the closest matched version (from above) + if not closest_v: + closest_v = v + elif SimpleVersion.compare_versions( + version_a=closest_v, op='>=', version_b=v, num_parts=3) and \ + SimpleVersion.compare_versions( + version_a=v, op='>=', version_b=req.specs[0][1], num_parts=3): + closest_v = v + # check if this an actual match if not req.compare_version(v) or \ (last_v and SimpleVersion.compare_versions(last_v, '>', v, ignore_sub_versions=False)): continue - if not parts[2].endswith(py_ver): - continue - if platform_wheel not in parts[4]: - continue + url = '/'.join(torch_url.split('/')[:-1] + l.split('/')) last_v = v # if we found an exact match, use it + # noinspection PyBroadException try: if req.specs[0][0] == '==' and \ SimpleVersion.compare_versions(req.specs[0][1], '==', v, ignore_sub_versions=False): break - except: + except Exception: pass - return url + return url, last_v or closest_v def get_url_for_platform(self, req): # check if package is already installed with system packages @@ -307,23 +320,28 @@ class PytorchRequirement(SimpleSubstitution): # assert op == "==" torch_url, torch_url_key = SimplePytorchRequirement.get_torch_page(self.cuda_version) - url = self._get_link_from_torch_page(req, torch_url) + url, closest_matched_version = self._get_link_from_torch_page(req, torch_url) if not url and self.config.get("agent.package_manager.torch_nightly", None): torch_url, torch_url_key = SimplePytorchRequirement.get_torch_page(self.cuda_version, nightly=True) - url = self._get_link_from_torch_page(req, torch_url) + url, closest_matched_version = self._get_link_from_torch_page(req, torch_url) # try one more time, with a lower cuda version (never fallback to CPU): while not url and torch_url_key > 0: previous_cuda_key = torch_url_key + print('Warning, could not locate PyTorch {} matching CUDA version {}, best candidate {}\n'.format( + req, previous_cuda_key, closest_matched_version)) + url, closest_matched_version = self._get_link_from_torch_page(req, torch_url) + if url: + break torch_url, torch_url_key = SimplePytorchRequirement.get_torch_page(int(torch_url_key)-1) # never fallback to CPU if torch_url_key < 1: - print('Warning! Could not locate PyTorch version {} matching CUDA version {}'.format( - req, previous_cuda_key)) - raise ValueError('Could not locate PyTorch version {} matching CUDA version {}'.format( - req, self.cuda_version)) - print('Warning! Could not locate PyTorch version {} matching CUDA version {}, ' - 'trying CUDA version {}'.format(req, previous_cuda_key, torch_url_key)) - url = self._get_link_from_torch_page(req, torch_url) + print( + 'Error! Could not locate PyTorch version {} matching CUDA version {}'.format( + req, previous_cuda_key)) + raise ValueError( + 'Could not locate PyTorch version {} matching CUDA version {}'.format(req, self.cuda_version)) + else: + print('Trying PyTorch CUDA version {} support'.format(torch_url_key)) if not url: url = PytorchWheel( @@ -335,6 +353,8 @@ class PytorchRequirement(SimpleSubstitution): if url: # normalize url (sometimes we will get ../ which we should not... url = '/'.join(url.split('/')[:3]) + urllib.parse.quote(str(furl(url).path.normalize())) + # print found + print('Found PyTorch version {} matching CUDA version {}'.format(req, torch_url_key)) self.log.debug("checking url: %s", url) return url, requests.head(url, timeout=10).ok