Fix torch CUDA 11.1 support

This commit is contained in:
allegroai 2020-11-26 01:14:36 +02:00
parent a8c543ef7b
commit 58eb5fbd5f

View File

@ -82,6 +82,8 @@ class SimplePytorchRequirement(SimpleSubstitution):
92: 'https://download.pytorch.org/whl/cu92/torch_stable.html', 92: 'https://download.pytorch.org/whl/cu92/torch_stable.html',
100: 'https://download.pytorch.org/whl/cu100/torch_stable.html', 100: 'https://download.pytorch.org/whl/cu100/torch_stable.html',
101: 'https://download.pytorch.org/whl/cu101/torch_stable.html', 101: 'https://download.pytorch.org/whl/cu101/torch_stable.html',
102: 'https://download.pytorch.org/whl/cu102/torch_stable.html',
110: 'https://download.pytorch.org/whl/cu110/torch_stable.html',
} }
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
@ -117,20 +119,24 @@ class SimplePytorchRequirement(SimpleSubstitution):
@classmethod @classmethod
def get_torch_page(cls, cuda_version, nightly=False): def get_torch_page(cls, cuda_version, nightly=False):
# noinspection PyBroadException
try: try:
cuda = int(cuda_version) cuda = int(cuda_version)
except: except Exception:
cuda = 0 cuda = 0
if nightly: if nightly:
# then try the nightly builds, it might be there... for c in range(cuda, max(-1, cuda-15), -1):
torch_url = cls.nightly_page_lookup_template.format(cuda) # then try the nightly builds, it might be there...
try: torch_url = cls.nightly_page_lookup_template.format(c)
if requests.get(torch_url, timeout=10).ok: # noinspection PyBroadException
cls.torch_page_lookup[cuda] = torch_url try:
return cls.torch_page_lookup[cuda], cuda if requests.get(torch_url, timeout=10).ok:
except Exception: print('Torch nightly CUDA {} download page found'.format(c))
pass cls.torch_page_lookup[c] = torch_url
return cls.torch_page_lookup[c], c
except Exception:
pass
return return
# first check if key is valid # first check if key is valid
@ -138,13 +144,16 @@ class SimplePytorchRequirement(SimpleSubstitution):
return cls.torch_page_lookup[cuda], cuda return cls.torch_page_lookup[cuda], cuda
# then try a new cuda version page # then try a new cuda version page
torch_url = cls.page_lookup_template.format(cuda) for c in range(cuda, max(-1, cuda-15), -1):
try: torch_url = cls.page_lookup_template.format(c)
if requests.get(torch_url, timeout=10).ok: # noinspection PyBroadException
cls.torch_page_lookup[cuda] = torch_url try:
return cls.torch_page_lookup[cuda], cuda if requests.get(torch_url, timeout=10).ok:
except Exception: print('Torch CUDA {} download page found'.format(c))
pass cls.torch_page_lookup[c] = torch_url
return cls.torch_page_lookup[c], c
except Exception:
pass
keys = sorted(cls.torch_page_lookup.keys(), reverse=True) keys = sorted(cls.torch_page_lookup.keys(), reverse=True)
for k in keys: for k in keys: