diff --git a/trains_agent/helper/package/pytorch.py b/trains_agent/helper/package/pytorch.py index 676e8be..10b682f 100644 --- a/trains_agent/helper/package/pytorch.py +++ b/trains_agent/helper/package/pytorch.py @@ -249,9 +249,18 @@ class PytorchRequirement(SimpleSubstitution): torch_url, torch_url_key = SimplePytorchRequirement.get_torch_page(self.cuda_version) url = self._get_link_from_torch_page(req, torch_url) - # try one more time, with a lower cuda version: - if not url: + # try one more time, with a lower cuda version (never fallback to CPU): + while not url and torch_url_key > 0: + previous_cuda_key = torch_url_key torch_url, torch_url_key = SimplePytorchRequirement.get_torch_page(int(torch_url_key)-1) + # never fallback to CPU + if torch_url_key < 1: + print('Warning! Could not locate PyTorch version {} matching CUDA version {}'.format( + req, previous_cuda_key)) + raise ValueError('Could not locate PyTorch version {} matching CUDA version {}'.format( + req, self.cuda_version)) + print('Warning! Could not locate PyTorch version {} matching CUDA version {}, trying CUDA version {}'.format( + req, previous_cuda_key, torch_url_key)) url = self._get_link_from_torch_page(req, torch_url) if not url: diff --git a/trains_agent/helper/package/requirements.py b/trains_agent/helper/package/requirements.py index 0bd3338..8a0c048 100644 --- a/trains_agent/helper/package/requirements.py +++ b/trains_agent/helper/package/requirements.py @@ -255,9 +255,10 @@ class RequirementsManager(object): try: return self._replace_one(req) except FatalSpecsResolutionError: + warning('could not resolve python wheel replacement for {}'.format(req)) raise except Exception: - warning('could not find installed CUDA/CuDNN version for {}, ' + warning('could not resolve python wheel replacement for {}, ' 'using original requirements line: {}'.format(req, i)) return None