Improve PyTorch CUDA auto package resolving

This commit is contained in:
allegroai 2019-10-28 21:54:11 +02:00
parent 5e6a809efd
commit 1765788e80
2 changed files with 13 additions and 3 deletions

View File

@ -249,9 +249,18 @@ class PytorchRequirement(SimpleSubstitution):
torch_url, torch_url_key = SimplePytorchRequirement.get_torch_page(self.cuda_version)
url = self._get_link_from_torch_page(req, torch_url)
# try one more time, with a lower cuda version:
if not url:
# try one more time, with a lower cuda version (never fallback to CPU):
while not url and torch_url_key > 0:
previous_cuda_key = torch_url_key
torch_url, torch_url_key = SimplePytorchRequirement.get_torch_page(int(torch_url_key)-1)
# never fallback to CPU
if torch_url_key < 1:
print('Warning! Could not locate PyTorch version {} matching CUDA version {}'.format(
req, previous_cuda_key))
raise ValueError('Could not locate PyTorch version {} matching CUDA version {}'.format(
req, self.cuda_version))
print('Warning! Could not locate PyTorch version {} matching CUDA version {}, trying CUDA version {}'.format(
req, previous_cuda_key, torch_url_key))
url = self._get_link_from_torch_page(req, torch_url)
if not url:

View File

@ -255,9 +255,10 @@ class RequirementsManager(object):
try:
return self._replace_one(req)
except FatalSpecsResolutionError:
warning('could not resolve python wheel replacement for {}'.format(req))
raise
except Exception:
warning('could not find installed CUDA/CuDNN version for {}, '
warning('could not resolve python wheel replacement for {}, '
'using original requirements line: {}'.format(req, i))
return None