diff --git a/docs/trains.conf b/docs/trains.conf index f8df36a..56f68c2 100644 --- a/docs/trains.conf +++ b/docs/trains.conf @@ -55,6 +55,10 @@ agent { # additional conda channels to use when installing with conda package manager conda_channels: ["pytorch", "conda-forge", ] + + # set to True to support torch nightly build installation, + # notice: torch nightly builds are ephemeral and are deleted from time to time + torch_nightly: false, }, # target folder for virtual environments builds, created when executing experiment diff --git a/trains_agent/backend_api/config/default/agent.conf b/trains_agent/backend_api/config/default/agent.conf index 86792c2..e22da79 100644 --- a/trains_agent/backend_api/config/default/agent.conf +++ b/trains_agent/backend_api/config/default/agent.conf @@ -39,6 +39,10 @@ # additional conda channels to use when installing with conda package manager conda_channels: ["defaults", "conda-forge", "pytorch", ] + + # set to True to support torch nightly build installation, + # notice: torch nightly builds are ephemeral and are deleted from time to time + torch_nightly: false, }, # target folder for virtual environments builds, created when executing experiment diff --git a/trains_agent/helper/package/pytorch.py b/trains_agent/helper/package/pytorch.py index 8b13c57..f492cdc 100644 --- a/trains_agent/helper/package/pytorch.py +++ b/trains_agent/helper/package/pytorch.py @@ -74,6 +74,7 @@ class SimplePytorchRequirement(SimpleSubstitution): packages = ("torch", "torchvision", "torchaudio") page_lookup_template = 'https://download.pytorch.org/whl/cu{}/torch_stable.html' + nightly_page_lookup_template = 'https://download.pytorch.org/whl/nightly/cu{}/torch_nightly.html' torch_page_lookup = { 0: 'https://download.pytorch.org/whl/cpu/torch_stable.html', 80: 'https://download.pytorch.org/whl/cu80/torch_stable.html', @@ -115,11 +116,23 @@ class SimplePytorchRequirement(SimpleSubstitution): package_manager.add_extra_install_flags(('-f', extra_url)) @classmethod - def get_torch_page(cls, cuda_version): + def get_torch_page(cls, cuda_version, nightly=False): try: cuda = int(cuda_version) except: cuda = 0 + + if nightly: + # then try the nightly builds, it might be there... + torch_url = cls.nightly_page_lookup_template.format(cuda) + try: + if requests.get(torch_url, timeout=10).ok: + cls.torch_page_lookup[cuda] = torch_url + return cls.torch_page_lookup[cuda], cuda + except Exception: + pass + return + # first check if key is valid if cuda in cls.torch_page_lookup: return cls.torch_page_lookup[cuda], cuda @@ -180,6 +193,8 @@ class PytorchRequirement(SimpleSubstitution): except PytorchResolutionError as e: self.log.warn("will not be able to install pytorch wheels: %s", e.args[0]) + self._original_req = [] + @property def is_conda(self): return self.package_manager == "conda" @@ -242,6 +257,13 @@ class PytorchRequirement(SimpleSubstitution): continue url = '/'.join(torch_url.split('/')[:-1] + l.split('/')) last_v = v + # if we found an exact match, use it + try: + if req.specs[0][0] == '==' and \ + SimpleVersion.compare_versions(req.specs[0][1], '==', v, ignore_sub_versions=False): + break + except: + pass return url @@ -273,6 +295,9 @@ class PytorchRequirement(SimpleSubstitution): torch_url, torch_url_key = SimplePytorchRequirement.get_torch_page(self.cuda_version) url = self._get_link_from_torch_page(req, torch_url) + if not url and self.config.get("agent.package_manager.torch_nightly"): + torch_url, torch_url_key = SimplePytorchRequirement.get_torch_page(self.cuda_version, nightly=True) + url = self._get_link_from_torch_page(req, torch_url) # try one more time, with a lower cuda version (never fallback to CPU): while not url and torch_url_key > 0: previous_cuda_key = torch_url_key @@ -363,7 +388,10 @@ class PytorchRequirement(SimpleSubstitution): def replace(self, req): try: - return self._replace(req) + new_req = self._replace(req) + if new_req: + self._original_req.append((req, new_req)) + return new_req except Exception as e: message = "Exception when trying to resolve python wheel" self.log.debug(message, exc_info=True) @@ -378,13 +406,13 @@ class PytorchRequirement(SimpleSubstitution): except: pass - try: - result = self._table_lookup(req) - except Exception as e: - exc = e - else: - self.log.debug('Replacing requirement "%s" with %r', req, result) - return result + # try: + # result = self._table_lookup(req) + # except Exception as e: + # exc = e + # else: + # self.log.debug('Replacing requirement "%s" with %r', req, result) + # return result self.log.debug( "Could not find Pytorch wheel in table, trying manually constructing URL" @@ -399,7 +427,7 @@ class PytorchRequirement(SimpleSubstitution): if result: self.log.debug("URL not found: {}".format(result)) exc = PytorchResolutionError( - "Was not able to find pytorch wheel URL: {}".format(exc) + "Could not find pytorch wheel URL for: {}".format(req) ) # cancel exception chaining six.raise_from(exc, None)