From 58eb5fbd5f1ba5581de4e0b07ceb96feb1d88046 Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Thu, 26 Nov 2020 01:14:36 +0200 Subject: [PATCH] Fix torch CUDA 11.1 support --- trains_agent/helper/package/pytorch.py | 41 ++++++++++++++++---------- 1 file changed, 25 insertions(+), 16 deletions(-) diff --git a/trains_agent/helper/package/pytorch.py b/trains_agent/helper/package/pytorch.py index 7150b25..bb5adc1 100644 --- a/trains_agent/helper/package/pytorch.py +++ b/trains_agent/helper/package/pytorch.py @@ -82,6 +82,8 @@ class SimplePytorchRequirement(SimpleSubstitution): 92: 'https://download.pytorch.org/whl/cu92/torch_stable.html', 100: 'https://download.pytorch.org/whl/cu100/torch_stable.html', 101: 'https://download.pytorch.org/whl/cu101/torch_stable.html', + 102: 'https://download.pytorch.org/whl/cu102/torch_stable.html', + 110: 'https://download.pytorch.org/whl/cu110/torch_stable.html', } def __init__(self, *args, **kwargs): @@ -117,20 +119,24 @@ class SimplePytorchRequirement(SimpleSubstitution): @classmethod def get_torch_page(cls, cuda_version, nightly=False): + # noinspection PyBroadException try: cuda = int(cuda_version) - except: + except Exception: cuda = 0 if nightly: - # then try the nightly builds, it might be there... - torch_url = cls.nightly_page_lookup_template.format(cuda) - try: - if requests.get(torch_url, timeout=10).ok: - cls.torch_page_lookup[cuda] = torch_url - return cls.torch_page_lookup[cuda], cuda - except Exception: - pass + for c in range(cuda, max(-1, cuda-15), -1): + # then try the nightly builds, it might be there... + torch_url = cls.nightly_page_lookup_template.format(c) + # noinspection PyBroadException + try: + if requests.get(torch_url, timeout=10).ok: + print('Torch nightly CUDA {} download page found'.format(c)) + cls.torch_page_lookup[c] = torch_url + return cls.torch_page_lookup[c], c + except Exception: + pass return # first check if key is valid @@ -138,13 +144,16 @@ class SimplePytorchRequirement(SimpleSubstitution): return cls.torch_page_lookup[cuda], cuda # then try a new cuda version page - torch_url = cls.page_lookup_template.format(cuda) - try: - if requests.get(torch_url, timeout=10).ok: - cls.torch_page_lookup[cuda] = torch_url - return cls.torch_page_lookup[cuda], cuda - except Exception: - pass + for c in range(cuda, max(-1, cuda-15), -1): + torch_url = cls.page_lookup_template.format(c) + # noinspection PyBroadException + try: + if requests.get(torch_url, timeout=10).ok: + print('Torch CUDA {} download page found'.format(c)) + cls.torch_page_lookup[c] = torch_url + return cls.torch_page_lookup[c], c + except Exception: + pass keys = sorted(cls.torch_page_lookup.keys(), reverse=True) for k in keys: